fix load checkpoint ddp

Thang Vu 2022-04-11 02:00:22 +00:00
parent 40572e2834
commit c68be64a35


@@ -86,7 +86,10 @@ def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
 def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):
-    state_dict = torch.load(checkpoint)
+    if hasattr(model, 'module'):
+        model = model.module
+    device = torch.cuda.current_device()
+    state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))
     src_state_dict = state_dict['net']
     target_state_dict = model.state_dict()
     skip_keys = []
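
For context, a minimal self-contained sketch of the pattern this commit applies when loading a checkpoint under DistributedDataParallel. The function name load_ddp_checkpoint is illustrative, not this repository's API; the 'net' key comes from the diff above.

import torch

def load_ddp_checkpoint(path, model, strict=False):
    # DistributedDataParallel wraps the model, so unwrap it first; otherwise
    # the checkpoint keys would need a 'module.' prefix to match.
    if hasattr(model, 'module'):
        model = model.module
    # Remap every saved storage onto this rank's GPU. A bare torch.load()
    # restores tensors to the device they were saved from (usually cuda:0),
    # so under DDP every rank would pile its copy onto GPU 0.
    device = torch.cuda.current_device()
    state_dict = torch.load(path, map_location=lambda storage, loc: storage.cuda(device))
    # The diff reads the model weights from the checkpoint's 'net' entry.
    model.load_state_dict(state_dict['net'], strict=strict)
    return state_dict

Passing map_location=f'cuda:{device}' achieves the same remapping; the lambda form used in the commit is the older, equivalent spelling.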