mirror of https://github.com/botastic/SoftGroup.git (synced 2025-10-16 11:45:42 +00:00)
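"""Training script for SoftGroup.

Reads a YAML config, builds the model, datasets and optimizer, runs the
training loop with TensorBoard logging, and calls checkpoint_save() at the
end of every epoch.

Typical invocation (script and config file names below are placeholders):

    python train.py configs/softgroup.yaml --work_dir work_dirs/softgroup
    python train.py configs/softgroup.yaml --resume work_dirs/softgroup/latest.pth
"""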
import argparse
import datetime
import os
import os.path as osp
import random
import shutil
import sys
import time

import numpy as np
import torch
import yaml
from munch import Munch
from tensorboardX import SummaryWriter

from data import build_dataloader, build_dataset
from model.softgroup import SoftGroup
from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
                            get_max_memory, get_root_logger, load_checkpoint)


def eval_epoch(val_loader, model, model_fn, epoch):
    # NOTE: relies on the module-level cfg, logger and writer created in the
    # __main__ block below; this helper is not called anywhere in this script.
    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    am_dict = {}

    with torch.no_grad():
        model.eval()
        start_epoch = time.time()
        for i, batch in enumerate(val_loader):

            # prepare input and forward
            loss, preds, visual_dict, meter_dict = model_fn(
                batch, model, epoch, semantic_only=cfg.semantic_only)

            # accumulate the (value, count) pairs reported for this batch
            for k, v in meter_dict.items():
                if k not in am_dict.keys():
                    am_dict[k] = AverageMeter()
                am_dict[k].update(v[0], v[1])
            sys.stdout.write('\riter: {}/{} loss: {:.4f}({:.4f})'.format(
                i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))

        logger.info('epoch: {}/{}, val loss: {:.4f}, time: {}s'.format(
            epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))

        # write the averaged meters that the model also tracks for visualization
        for k in am_dict.keys():
            if k in visual_dict.keys():
                writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)


def get_args():
    parser = argparse.ArgumentParser('SoftGroup')
    parser.add_argument('config', type=str, help='path to config file')
    parser.add_argument('--resume', type=str, help='path to resume from')
    parser.add_argument('--work_dir', type=str, help='working directory')
    args = parser.parse_args()
    return args


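# The config is loaded into a Munch below, so nested YAML keys are read as
# attributes. Keys this script accesses directly (a reference sketch, not an
# exhaustive schema):
#
#   model       -> kwargs for SoftGroup(**cfg.model)
#   data        -> data.train / data.test passed to build_dataset
#   dataloader  -> dataloader.train / dataloader.test passed to build_dataloader
#   optimizer   -> passed to build_optimizer; optimizer.lr also drives the LR schedule
#   epochs, step_epoch, save_freq, pretrain, work_dir (optional)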
if __name__ == '__main__':
    # TODO: remove this debugging setup (cudnn disabled, fixed seed)
    torch.backends.cudnn.enabled = False
    test_seed = 123
    random.seed(test_seed)
    np.random.seed(test_seed)
    torch.manual_seed(test_seed)
    torch.cuda.manual_seed_all(test_seed)

    args = get_args()
    with open(args.config, 'r') as f:
        cfg_txt = f.read()
    cfg = Munch.fromDict(yaml.safe_load(cfg_txt))

    # work_dir & logger
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
    os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)
    logger.info(f'Config:\n{cfg_txt}')
    shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
    writer = SummaryWriter(cfg.work_dir)

    # model
    model = SoftGroup(**cfg.model).cuda()

    # data
    train_set = build_dataset(cfg.data.train, logger)
    val_set = build_dataset(cfg.data.test, logger)
    train_loader = build_dataloader(train_set, training=True, **cfg.dataloader.train)
    # val_loader is built but only consumed by eval_epoch(), which is not called here
    val_loader = build_dataloader(val_set, training=False, **cfg.dataloader.test)

    # optim
    optimizer = build_optimizer(model, cfg.optimizer)

    # pretrain, resume
    start_epoch = 1
    if args.resume:
        logger.info(f'Resume from {args.resume}')
        start_epoch = load_checkpoint(args.resume, logger, model, optimizer=optimizer)
    elif cfg.pretrain:
        logger.info(f'Load pretrain from {cfg.pretrain}')
        load_checkpoint(cfg.pretrain, logger, model)

    # train and val
    logger.info('Training')
    for epoch in range(start_epoch, cfg.epochs + 1):
        model.train()
        iter_time = AverageMeter()
        data_time = AverageMeter()
        meter_dict = {}
        end = time.time()

        for i, batch in enumerate(train_loader, start=1):
            data_time.update(time.time() - end)

            # the schedule depends only on the epoch index, so every iteration
            # of an epoch uses the same learning rate
            cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
            loss, log_vars = model(batch, return_loss=True)

            # meter_dict: accumulate the (value, count) pairs reported by the model
            for k, v in log_vars.items():
                if k not in meter_dict.keys():
                    meter_dict[k] = AverageMeter()
                meter_dict[k].update(v[0], v[1])

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # time and print
            current_iter = (epoch - 1) * len(train_loader) + i
            max_iter = cfg.epochs * len(train_loader)
            remain_iter = max_iter - current_iter
            iter_time.update(time.time() - end)
            end = time.time()
            remain_time = remain_iter * iter_time.avg
            remain_time = str(datetime.timedelta(seconds=int(remain_time)))
            lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('learning_rate', lr, current_iter)
            for k, v in meter_dict.items():
                writer.add_scalar(k, v.val, current_iter)
            if i % 10 == 0:
                log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
                log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, ' \
                           f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
                for k, v in meter_dict.items():
                    log_str += f', {k}: {v.val:.4f}'
                logger.info(log_str)

        # end-of-epoch checkpoint
        checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)