From efe2ddb6b13655015d95362ace44138d6a0a482e Mon Sep 17 00:00:00 2001
From: Thang Vu
Date: Fri, 15 Apr 2022 05:31:38 +0000
Subject: [PATCH] separate train val func

---
 train.py | 195 +++++++++++++++++++++++++++++--------------------------
 1 file changed, 102 insertions(+), 93 deletions(-)

diff --git a/train.py b/train.py
index 827fa2e..5ddf9d2 100644
--- a/train.py
+++ b/train.py
@@ -26,17 +26,111 @@ def get_args():
     parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
     parser.add_argument('--resume', type=str, help='path to resume from')
     parser.add_argument('--work_dir', type=str, help='working directory')
+    parser.add_argument('--skip_validate', action='store_true', help='skip validation')
     args = parser.parse_args()
     return args


-if __name__ == '__main__':
+def train(epoch, model, optimizer, scaler, train_loader, cfg, logger, writer):
+    model.train()
+    iter_time = AverageMeter(True)
+    data_time = AverageMeter(True)
+    meter_dict = {}
+    end = time.time()
+
+    if train_loader.sampler is not None and cfg.dist:
+        train_loader.sampler.set_epoch(epoch)
+
+    for i, batch in enumerate(train_loader, start=1):
+        data_time.update(time.time() - end)
+        cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
+        with torch.cuda.amp.autocast(enabled=cfg.fp16):
+            loss, log_vars = model(batch, return_loss=True)
+
+        # meter_dict
+        for k, v in log_vars.items():
+            if k not in meter_dict.keys():
+                meter_dict[k] = AverageMeter()
+            meter_dict[k].update(v)
+
+        # backward
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        # time and print
+        remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
+        iter_time.update(time.time() - end)
+        end = time.time()
+        remain_time = remain_iter * iter_time.avg
+        remain_time = str(datetime.timedelta(seconds=int(remain_time)))
+        lr = optimizer.param_groups[0]['lr']
+
+        if is_multiple(i, 10):
+            log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
+            log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
+                f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
+            for k, v in meter_dict.items():
+                log_str += f', {k}: {v.val:.4f}'
+            logger.info(log_str)
+    writer.add_scalar('train/learning_rate', lr, epoch)
+    for k, v in meter_dict.items():
+        writer.add_scalar(f'train/{k}', v.avg, epoch)
+    checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
+
+
+def validate(epoch, model, val_loader, cfg, logger, writer):
+    logger.info('Validation')
+    results = []
+    all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
+    all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
+    _, world_size = get_dist_info()
+    progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
+    val_set = val_loader.dataset
+    with torch.no_grad():
+        model = model.eval()
+        for i, batch in enumerate(val_loader):
+            result = model(batch)
+            results.append(result)
+            progress_bar.update(world_size)
+        progress_bar.close()
+        results = collect_results_gpu(results, len(val_set))
+    if is_main_process():
+        for res in results:
+            all_sem_preds.append(res['semantic_preds'])
+            all_sem_labels.append(res['semantic_labels'])
+            all_offset_preds.append(res['offset_preds'])
+            all_offset_labels.append(res['offset_labels'])
+            all_inst_labels.append(res['instance_labels'])
+            if not cfg.model.semantic_only:
+                all_pred_insts.append(res['pred_instances'])
+                all_gt_insts.append(res['gt_instances'])
+        if not cfg.model.semantic_only:
+            logger.info('Evaluate instance segmentation')
+            scannet_eval = ScanNetEval(val_set.CLASSES)
+            eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+            writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
+            writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
+            writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
+        logger.info('Evaluate semantic segmentation and offset MAE')
+        miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+        acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+        mae = evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
+                                  cfg.model.ignore_label, logger)
+        writer.add_scalar('val/mIoU', miou, epoch)
+        writer.add_scalar('val/Acc', acc, epoch)
+        writer.add_scalar('val/Offset MAE', mae, epoch)
+
+
+def main():
     args = get_args()
     cfg_txt = open(args.config, 'r').read()
     cfg = Munch.fromDict(yaml.safe_load(cfg_txt))

     if args.dist:
         init_dist()
+    cfg.dist = args.dist

     # work_dir & logger
     if args.work_dir:
@@ -81,96 +175,11 @@ if __name__ == '__main__':
     # train and val
     logger.info('Training')
     for epoch in range(start_epoch, cfg.epochs + 1):
-        model.train()
-        iter_time = AverageMeter(True)
-        data_time = AverageMeter(True)
-        meter_dict = {}
-        end = time.time()
-
-        if train_loader.sampler is not None and args.dist:
-            train_loader.sampler.set_epoch(epoch)
-
-        for i, batch in enumerate(train_loader, start=1):
-            data_time.update(time.time() - end)
-
-            cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
-
-            with torch.cuda.amp.autocast(enabled=cfg.fp16):
-                loss, log_vars = model(batch, return_loss=True)
-
-            # meter_dict
-            for k, v in log_vars.items():
-                if k not in meter_dict.keys():
-                    meter_dict[k] = AverageMeter()
-                meter_dict[k].update(v)
-
-            # backward
-            optimizer.zero_grad()
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-
-            # time and print
-            remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
-            iter_time.update(time.time() - end)
-            end = time.time()
-            remain_time = remain_iter * iter_time.avg
-            remain_time = str(datetime.timedelta(seconds=int(remain_time)))
-            lr = optimizer.param_groups[0]['lr']
-
-            if is_multiple(i, 10):
-                log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
-                log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
-                    f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
-                for k, v in meter_dict.items():
-                    log_str += f', {k}: {v.val:.4f}'
-                logger.info(log_str)
-        writer.add_scalar('train/learning_rate', lr, epoch)
-        for k, v in meter_dict.items():
-            writer.add_scalar(f'train/{k}', v.avg, epoch)
-        checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
-
-        # validation
-        if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
-            logger.info('Validation')
-            results = []
-            all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
-            all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
-            _, world_size = get_dist_info()
-            progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
-            with torch.no_grad():
-                model = model.eval()
-                for i, batch in enumerate(val_loader):
-                    result = model(batch)
-                    results.append(result)
-                    progress_bar.update(world_size)
-                progress_bar.close()
-                results = collect_results_gpu(results, len(val_set))
-            if is_main_process():
-                for res in results:
-                    all_sem_preds.append(res['semantic_preds'])
-                    all_sem_labels.append(res['semantic_labels'])
-                    all_offset_preds.append(res['offset_preds'])
-                    all_offset_labels.append(res['offset_labels'])
-                    all_inst_labels.append(res['instance_labels'])
-                    if not cfg.model.semantic_only:
-                        all_pred_insts.append(res['pred_instances'])
-                        all_gt_insts.append(res['gt_instances'])
-                if not cfg.model.semantic_only:
-                    logger.info('Evaluate instance segmentation')
-                    scannet_eval = ScanNetEval(val_set.CLASSES)
-                    eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
-                    writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
-                    writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
-                    writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
-                logger.info('Evaluate semantic segmentation and offset MAE')
-                miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                              logger)
-                acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                            logger)
-                mae = evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
-                                          cfg.model.ignore_label, logger)
-                writer.add_scalar('val/mIoU', miou, epoch)
-                writer.add_scalar('val/Acc', acc, epoch)
-                writer.add_scalar('val/Offset MAE', mae, epoch)
+        train(epoch, model, optimizer, scaler, train_loader, cfg, logger, writer)
+        if not args.skip_validate and (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
+            validate(epoch, model, val_loader, cfg, logger, writer)
         writer.flush()
+
+
+if __name__ == '__main__':
+    main()
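
Note on the new validation gate: validate() now runs only on epochs that are a
multiple of cfg.save_freq or a power of two, and the new --skip_validate flag
disables it entirely. Below is a minimal sketch of the resulting schedule,
assuming is_multiple and is_power2 (imported elsewhere in train.py, not shown
in this patch) carry their conventional meanings:

    # Hypothetical stand-ins for the helpers train.py imports; the patch does
    # not define them, so these are assumed equivalents, not the real ones.
    def is_multiple(num, multiple):
        return num > 0 and num % multiple == 0

    def is_power2(num):
        return num > 0 and (num & (num - 1)) == 0

    # Epochs that would trigger validate() in a 12-epoch run with save_freq=4:
    save_freq, epochs = 4, 12
    val_epochs = [e for e in range(1, epochs + 1)
                  if is_multiple(e, save_freq) or is_power2(e)]
    print(val_epochs)  # [1, 2, 4, 8, 12]

With --skip_validate the gate is bypassed for training-only runs;
checkpoint_save() is still called at the end of every train() epoch,
independent of validation.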