SoftGroup/softgroup/util/utils.py

import os
import os.path as osp
import torch
from collections import OrderedDict
from math import cos, pi


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
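
# Usage sketch (illustrative, not part of the original module): track a running
# loss average over an epoch, weighting each update by the batch size. The
# names loss and batch_size are assumed placeholders.
#
#   loss_meter = AverageMeter()
#   loss_meter.update(loss.item(), n=batch_size)
#   print(loss_meter.avg)  # running mean over all samples seen so far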


# Epoch counts from 0 to N-1
def cosine_lr_after_step(optimizer, base_lr, epoch, step_epoch, total_epochs, clip=1e-6):
    if epoch < step_epoch:
        lr = base_lr
    else:
        lr = clip + 0.5 * (base_lr - clip) * \
            (1 + cos(pi * ((epoch - step_epoch) / (total_epochs - step_epoch))))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
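
# Illustrative numbers (assumed values, not from the original code): with
# base_lr=0.004, step_epoch=50, total_epochs=100, the lr stays at 0.004 for the
# first 50 epochs, then follows a half-cosine decay toward clip; e.g. at epoch
# 75 the cosine term is cos(pi * 0.5) = 0, so lr = clip + 0.5 * (0.004 - clip)
# which is roughly 0.002.
#
#   for epoch in range(100):
#       cosine_lr_after_step(optimizer, base_lr=0.004, epoch=epoch,
#                            step_epoch=50, total_epochs=100)
#       # ... run one training epoch ...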


def is_power2(num):
    return num != 0 and ((num & (num - 1)) == 0)


def is_multiple(num, multiple):
    return num != 0 and num % multiple == 0


def weights_to_cpu(state_dict):
    """Copy a model state_dict to CPU.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu


def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
    f = os.path.join(work_dir, f'epoch_{epoch}.pth')
    checkpoint = {
        'net': weights_to_cpu(model.state_dict()),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, f)
    if os.path.exists(f'{work_dir}/latest.pth'):
        os.remove(f'{work_dir}/latest.pth')
    os.system(f'cd {work_dir}; ln -s {osp.basename(f)} latest.pth')

    # remove previous checkpoints unless they are a power of 2 or a multiple of save_freq
    epoch = epoch - 1
    f = os.path.join(work_dir, f'epoch_{epoch}.pth')
    if os.path.isfile(f):
        if not is_multiple(epoch, save_freq) and not is_power2(epoch):
            os.remove(f)
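
# Usage sketch (illustrative; the work_dir value is an assumption): call once
# per epoch. The function writes epoch_{epoch}.pth, repoints the latest.pth
# symlink to it, and deletes the previous epoch's file unless that epoch number
# is a power of 2 or a multiple of save_freq.
#
#   checkpoint_save(epoch, model, optimizer, work_dir='./work_dirs/run1', save_freq=16)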


def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):
    state_dict = torch.load(checkpoint)
    src_state_dict = state_dict['net']
    target_state_dict = model.state_dict()
    skip_keys = []
    # skip mismatched-size tensors in case of pretraining
    for k in src_state_dict.keys():
        if k not in target_state_dict:
            continue
        if src_state_dict[k].size() != target_state_dict[k].size():
            skip_keys.append(k)
    for k in skip_keys:
        del src_state_dict[k]
    missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)
    if skip_keys:
        logger.info(
            f'removed keys in source state_dict due to size mismatch: {", ".join(skip_keys)}')
    if missing_keys:
        logger.info(f'missing keys in source state_dict: {", ".join(missing_keys)}')
    if unexpected_keys:
        logger.info(f'unexpected keys in source state_dict: {", ".join(unexpected_keys)}')

    # load optimizer
    if optimizer is not None:
        assert 'optimizer' in state_dict
        optimizer.load_state_dict(state_dict['optimizer'])

    if 'epoch' in state_dict:
        epoch = state_dict['epoch']
    else:
        epoch = 0
    return epoch + 1
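
# Usage sketch (illustrative; the checkpoint path and logger are assumptions):
# the return value is the next epoch to run, so training can resume where it
# left off. Tensors whose shapes differ from the current model (e.g. a head
# sized for a different number of classes when fine-tuning) are skipped and
# logged rather than raising an error.
#
#   start_epoch = load_checkpoint('work_dirs/run1/latest.pth', logger, model, optimizer)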


def get_max_memory():
    mem = torch.cuda.max_memory_allocated()
    mem_mb = torch.tensor([int(mem) // (1024 * 1024)], dtype=torch.int)
    return mem_mb.item()
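
# Usage sketch (illustrative): report peak GPU memory in MB, e.g. in a
# per-iteration log line.
#
#   logger.info(f'max memory allocated: {get_max_memory()} MB')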