diff --git a/softgroup/util/utils.py b/softgroup/util/utils.py index 8557ffc..81374b5 100644 --- a/softgroup/util/utils.py +++ b/softgroup/util/utils.py @@ -86,7 +86,10 @@ def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16): def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False): - state_dict = torch.load(checkpoint) + if hasattr(model, 'module'): + model = model.module + device = torch.cuda.current_device() + state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device)) src_state_dict = state_dict['net'] target_state_dict = model.state_dict() skip_keys = []