Mirror of https://github.com/botastic/SoftGroup.git, synced 2025-10-16 11:45:42 +00:00
separate train val func

This commit is contained in:
  parent f483165c2a
  commit efe2ddb6b1

train.py (195 changed lines)
@@ -26,17 +26,111 @@ def get_args():
     parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
     parser.add_argument('--resume', type=str, help='path to resume from')
     parser.add_argument('--work_dir', type=str, help='working directory')
+    parser.add_argument('--skip_validate', action='store_true', help='skip validation')
     args = parser.parse_args()
     return args
 
 
-if __name__ == '__main__':
+def train(epoch, model, optimizer, scaler, train_loader, cfg, logger, writer):
+    model.train()
+    iter_time = AverageMeter(True)
+    data_time = AverageMeter(True)
+    meter_dict = {}
+    end = time.time()
+
+    if train_loader.sampler is not None and cfg.dist:
+        train_loader.sampler.set_epoch(epoch)
+
+    for i, batch in enumerate(train_loader, start=1):
+        data_time.update(time.time() - end)
+        cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
+        with torch.cuda.amp.autocast(enabled=cfg.fp16):
+            loss, log_vars = model(batch, return_loss=True)
+
+        # meter_dict
+        for k, v in log_vars.items():
+            if k not in meter_dict.keys():
+                meter_dict[k] = AverageMeter()
+            meter_dict[k].update(v)
+
+        # backward
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        # time and print
+        remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
+        iter_time.update(time.time() - end)
+        end = time.time()
+        remain_time = remain_iter * iter_time.avg
+        remain_time = str(datetime.timedelta(seconds=int(remain_time)))
+        lr = optimizer.param_groups[0]['lr']
+
+        if is_multiple(i, 10):
+            log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
+            log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
+                f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
+            for k, v in meter_dict.items():
+                log_str += f', {k}: {v.val:.4f}'
+            logger.info(log_str)
+    writer.add_scalar('train/learning_rate', lr, epoch)
+    for k, v in meter_dict.items():
+        writer.add_scalar(f'train/{k}', v.avg, epoch)
+    checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
+
+
+def validate(epoch, model, val_loader, cfg, logger, writer):
+    logger.info('Validation')
+    results = []
+    all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
+    all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
+    _, world_size = get_dist_info()
+    progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
+    val_set = val_loader.dataset
+    with torch.no_grad():
+        model = model.eval()
+        for i, batch in enumerate(val_loader):
+            result = model(batch)
+            results.append(result)
+            progress_bar.update(world_size)
+        progress_bar.close()
+        results = collect_results_gpu(results, len(val_set))
+    if is_main_process():
+        for res in results:
+            all_sem_preds.append(res['semantic_preds'])
+            all_sem_labels.append(res['semantic_labels'])
+            all_offset_preds.append(res['offset_preds'])
+            all_offset_labels.append(res['offset_labels'])
+            all_inst_labels.append(res['instance_labels'])
+            if not cfg.model.semantic_only:
+                all_pred_insts.append(res['pred_instances'])
+                all_gt_insts.append(res['gt_instances'])
+        if not cfg.model.semantic_only:
+            logger.info('Evaluate instance segmentation')
+            scannet_eval = ScanNetEval(val_set.CLASSES)
+            eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+            writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
+            writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
+            writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
+        logger.info('Evaluate semantic segmentation and offset MAE')
+        miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+        acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+        mae = evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
+                                  cfg.model.ignore_label, logger)
+        writer.add_scalar('val/mIoU', miou, epoch)
+        writer.add_scalar('val/Acc', acc, epoch)
+        writer.add_scalar('val/Offset MAE', mae, epoch)
+
+
+def main():
     args = get_args()
     cfg_txt = open(args.config, 'r').read()
     cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
 
     if args.dist:
         init_dist()
+    cfg.dist = args.dist
 
     # work_dir & logger
     if args.work_dir:
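Note: the body of the new train() above follows the standard torch.cuda.amp recipe. A minimal, self-contained sketch of that step; the linear model, SGD optimizer, and random tensors are toy stand-ins (not SoftGroup code), and a CUDA device is assumed:

    import torch
    import torch.nn.functional as F

    # Toy stand-ins for SoftGroup's model and optimizer; cfg.fp16 becomes `enabled`.
    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=True)  # main() passes this into train()

    x = torch.randn(4, 8, device='cuda')
    y = torch.randn(4, 1, device='cuda')

    with torch.cuda.amp.autocast(enabled=True):  # fp16 forward where safe
        loss = F.mse_loss(model(x), y)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration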
@@ -81,96 +175,11 @@ if __name__ == '__main__':
     # train and val
     logger.info('Training')
     for epoch in range(start_epoch, cfg.epochs + 1):
-        model.train()
-        iter_time = AverageMeter(True)
-        data_time = AverageMeter(True)
-        meter_dict = {}
-        end = time.time()
-
-        if train_loader.sampler is not None and args.dist:
-            train_loader.sampler.set_epoch(epoch)
-
-        for i, batch in enumerate(train_loader, start=1):
-            data_time.update(time.time() - end)
-            cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
-            with torch.cuda.amp.autocast(enabled=cfg.fp16):
-                loss, log_vars = model(batch, return_loss=True)
-
-            # meter_dict
-            for k, v in log_vars.items():
-                if k not in meter_dict.keys():
-                    meter_dict[k] = AverageMeter()
-                meter_dict[k].update(v)
-
-            # backward
-            optimizer.zero_grad()
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-
-            # time and print
-            remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
-            iter_time.update(time.time() - end)
-            end = time.time()
-            remain_time = remain_iter * iter_time.avg
-            remain_time = str(datetime.timedelta(seconds=int(remain_time)))
-            lr = optimizer.param_groups[0]['lr']
-
-            if is_multiple(i, 10):
-                log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
-                log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
-                    f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
-                for k, v in meter_dict.items():
-                    log_str += f', {k}: {v.val:.4f}'
-                logger.info(log_str)
-        writer.add_scalar('train/learning_rate', lr, epoch)
-        for k, v in meter_dict.items():
-            writer.add_scalar(f'train/{k}', v.avg, epoch)
-        checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
-
-        # validation
-        if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
-            logger.info('Validation')
-            results = []
-            all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
-            all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
-            _, world_size = get_dist_info()
-            progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
-            with torch.no_grad():
-                model = model.eval()
-                for i, batch in enumerate(val_loader):
-                    result = model(batch)
-                    results.append(result)
-                    progress_bar.update(world_size)
-                progress_bar.close()
-                results = collect_results_gpu(results, len(val_set))
-            if is_main_process():
-                for res in results:
-                    all_sem_preds.append(res['semantic_preds'])
-                    all_sem_labels.append(res['semantic_labels'])
-                    all_offset_preds.append(res['offset_preds'])
-                    all_offset_labels.append(res['offset_labels'])
-                    all_inst_labels.append(res['instance_labels'])
-                    if not cfg.model.semantic_only:
-                        all_pred_insts.append(res['pred_instances'])
-                        all_gt_insts.append(res['gt_instances'])
-                if not cfg.model.semantic_only:
-                    logger.info('Evaluate instance segmentation')
-                    scannet_eval = ScanNetEval(val_set.CLASSES)
-                    eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
-                    writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
-                    writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
-                    writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
-                logger.info('Evaluate semantic segmentation and offset MAE')
-                miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                              logger)
-                acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                            logger)
-                mae = evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
-                                          cfg.model.ignore_label, logger)
-                writer.add_scalar('val/mIoU', miou, epoch)
-                writer.add_scalar('val/Acc', acc, epoch)
-                writer.add_scalar('val/Offset MAE', mae, epoch)
+        train(epoch, model, optimizer, scaler, train_loader, cfg, logger, writer)
+        if not args.skip_validate and (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
+            validate(epoch, model, val_loader, cfg, logger, writer)
         writer.flush()
+
+
+if __name__ == '__main__':
+    main()
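Note: train() only calls train_loader.sampler.set_epoch(epoch) when cfg.dist is set. A small sketch of why, using torch's DistributedSampler; num_replicas and rank are passed explicitly here so the snippet runs without init_dist(), whereas in the real script they come from the process group:

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in (1, 2):
        sampler.set_epoch(epoch)  # reseeds the shuffle; without this, every
                                  # epoch yields the same order on every rank
        print(epoch, [batch[0].tolist() for batch in loader])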
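Note: with the new --skip_validate flag, validation is further gated per epoch by is_multiple(epoch, cfg.save_freq) or is_power2(epoch). A sketch with hypothetical helper definitions that match how they are used here (the real helpers live in SoftGroup's utilities):

    def is_multiple(n, m):
        return n % m == 0

    def is_power2(n):
        return n > 0 and (n & (n - 1)) == 0

    save_freq = 16  # stand-in for cfg.save_freq
    validated = [e for e in range(1, 49) if is_multiple(e, save_freq) or is_power2(e)]
    print(validated)  # [1, 2, 4, 8, 16, 32, 48]

So a run such as `python train.py <config> --skip_validate` (the positional config argument is implied by args.config in main()) trains and checkpoints without ever entering validate().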