separate train val func

Thang Vu 2022-04-15 05:31:38 +00:00
parent f483165c2a
commit efe2ddb6b1

train.py (195 lines changed)

@@ -26,17 +26,111 @@ def get_args():
     parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
     parser.add_argument('--resume', type=str, help='path to resume from')
     parser.add_argument('--work_dir', type=str, help='working directory')
+    parser.add_argument('--skip_validate', action='store_true', help='skip validation')
     args = parser.parse_args()
     return args
 
-if __name__ == '__main__':
+
+def train(epoch, model, optimizer, scaler, train_loader, cfg, logger, writer):
+    model.train()
+    iter_time = AverageMeter(True)
+    data_time = AverageMeter(True)
+    meter_dict = {}
+    end = time.time()
+    if train_loader.sampler is not None and cfg.dist:
+        train_loader.sampler.set_epoch(epoch)
+    for i, batch in enumerate(train_loader, start=1):
+        data_time.update(time.time() - end)
+        cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
+        with torch.cuda.amp.autocast(enabled=cfg.fp16):
+            loss, log_vars = model(batch, return_loss=True)
+
+        # meter_dict
+        for k, v in log_vars.items():
+            if k not in meter_dict.keys():
+                meter_dict[k] = AverageMeter()
+            meter_dict[k].update(v)
+
+        # backward
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        # time and print
+        remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
+        iter_time.update(time.time() - end)
+        end = time.time()
+        remain_time = remain_iter * iter_time.avg
+        remain_time = str(datetime.timedelta(seconds=int(remain_time)))
+        lr = optimizer.param_groups[0]['lr']
+
+        if is_multiple(i, 10):
+            log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
+            log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
+                f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
+            for k, v in meter_dict.items():
+                log_str += f', {k}: {v.val:.4f}'
+            logger.info(log_str)
+    writer.add_scalar('train/learning_rate', lr, epoch)
+    for k, v in meter_dict.items():
+        writer.add_scalar(f'train/{k}', v.avg, epoch)
+    checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
+
+
+def validate(epoch, model, val_loader, cfg, logger, writer):
+    logger.info('Validation')
+    results = []
+    all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
+    all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
+    _, world_size = get_dist_info()
+    progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
+    val_set = val_loader.dataset
+    with torch.no_grad():
+        model = model.eval()
+        for i, batch in enumerate(val_loader):
+            result = model(batch)
+            results.append(result)
+            progress_bar.update(world_size)
+        progress_bar.close()
+        results = collect_results_gpu(results, len(val_set))
+    if is_main_process():
+        for res in results:
+            all_sem_preds.append(res['semantic_preds'])
+            all_sem_labels.append(res['semantic_labels'])
+            all_offset_preds.append(res['offset_preds'])
+            all_offset_labels.append(res['offset_labels'])
+            all_inst_labels.append(res['instance_labels'])
+            if not cfg.model.semantic_only:
+                all_pred_insts.append(res['pred_instances'])
+                all_gt_insts.append(res['gt_instances'])
+        if not cfg.model.semantic_only:
+            logger.info('Evaluate instance segmentation')
+            scannet_eval = ScanNetEval(val_set.CLASSES)
+            eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+            writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
+            writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
+            writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
+        logger.info('Evaluate semantic segmentation and offset MAE')
+        miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+        acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+        mae = evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
+                                  cfg.model.ignore_label, logger)
+        writer.add_scalar('val/mIoU', miou, epoch)
+        writer.add_scalar('val/Acc', acc, epoch)
+        writer.add_scalar('val/Offset MAE', mae, epoch)
+
+
+def main():
     args = get_args()
     cfg_txt = open(args.config, 'r').read()
     cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
     if args.dist:
         init_dist()
     cfg.dist = args.dist
 
     # work_dir & logger
     if args.work_dir:
@@ -81,96 +175,11 @@ if __name__ == '__main__':
     # train and val
     logger.info('Training')
     for epoch in range(start_epoch, cfg.epochs + 1):
-        model.train()
-        iter_time = AverageMeter(True)
-        data_time = AverageMeter(True)
-        meter_dict = {}
-        end = time.time()
-        if train_loader.sampler is not None and args.dist:
-            train_loader.sampler.set_epoch(epoch)
-        for i, batch in enumerate(train_loader, start=1):
-            data_time.update(time.time() - end)
-            cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
-            with torch.cuda.amp.autocast(enabled=cfg.fp16):
-                loss, log_vars = model(batch, return_loss=True)
-
-            # meter_dict
-            for k, v in log_vars.items():
-                if k not in meter_dict.keys():
-                    meter_dict[k] = AverageMeter()
-                meter_dict[k].update(v)
-
-            # backward
-            optimizer.zero_grad()
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-
-            # time and print
-            remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
-            iter_time.update(time.time() - end)
-            end = time.time()
-            remain_time = remain_iter * iter_time.avg
-            remain_time = str(datetime.timedelta(seconds=int(remain_time)))
-            lr = optimizer.param_groups[0]['lr']
-
-            if is_multiple(i, 10):
-                log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
-                log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
-                    f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
-                for k, v in meter_dict.items():
-                    log_str += f', {k}: {v.val:.4f}'
-                logger.info(log_str)
-        writer.add_scalar('train/learning_rate', lr, epoch)
-        for k, v in meter_dict.items():
-            writer.add_scalar(f'train/{k}', v.avg, epoch)
-        checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
-
-        # validation
-        if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
-            logger.info('Validation')
-            results = []
-            all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
-            all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
-            _, world_size = get_dist_info()
-            progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
-            with torch.no_grad():
-                model = model.eval()
-                for i, batch in enumerate(val_loader):
-                    result = model(batch)
-                    results.append(result)
-                    progress_bar.update(world_size)
-                progress_bar.close()
-                results = collect_results_gpu(results, len(val_set))
-            if is_main_process():
-                for res in results:
-                    all_sem_preds.append(res['semantic_preds'])
-                    all_sem_labels.append(res['semantic_labels'])
-                    all_offset_preds.append(res['offset_preds'])
-                    all_offset_labels.append(res['offset_labels'])
-                    all_inst_labels.append(res['instance_labels'])
-                    if not cfg.model.semantic_only:
-                        all_pred_insts.append(res['pred_instances'])
-                        all_gt_insts.append(res['gt_instances'])
-                if not cfg.model.semantic_only:
-                    logger.info('Evaluate instance segmentation')
-                    scannet_eval = ScanNetEval(val_set.CLASSES)
-                    eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
-                    writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
-                    writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
-                    writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
-                logger.info('Evaluate semantic segmentation and offset MAE')
-                miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                              logger)
-                acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                            logger)
-                mae = evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
-                                          cfg.model.ignore_label, logger)
-                writer.add_scalar('val/mIoU', miou, epoch)
-                writer.add_scalar('val/Acc', acc, epoch)
-                writer.add_scalar('val/Offset MAE', mae, epoch)
+        train(epoch, model, optimizer, scaler, train_loader, cfg, logger, writer)
+        if not args.skip_validate and (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
+            validate(epoch, model, val_loader, cfg, logger, writer)
         writer.flush()
+
+
+if __name__ == '__main__':
+    main()
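
Note: helpers such as `AverageMeter`, `is_multiple`, `is_power2`, `cosine_lr_after_step`, and the evaluation functions are imported from elsewhere in the repository and are not part of this diff. As a rough illustration of the learning-rate schedule that `train()` calls with `cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)`, a minimal sketch with that signature might look like the following; the `clip` floor and the exact decay curve are assumptions, not the repository's verified implementation:

import math

# Minimal sketch, not the repository's code: hold base_lr until step_epoch,
# then cosine-decay toward a small floor (clip) over the remaining epochs.
def cosine_lr_after_step(optimizer, base_lr, epoch, step_epoch, total_epochs, clip=1e-6):
    if epoch < step_epoch:
        lr = base_lr
    else:
        progress = (epoch - step_epoch) / (total_epochs - step_epoch)
        lr = clip + 0.5 * (base_lr - clip) * (1 + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Under this reading, passing `epoch - 1` as `train()` does keeps the first training epochs at exactly `base_lr` until `step_epoch` is reached.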