mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
fix load checkpoint ddp
This commit is contained in:
parent
40572e2834
commit
c68be64a35
@ -86,7 +86,10 @@ def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):
|
||||
state_dict = torch.load(checkpoint)
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
device = torch.cuda.current_device()
|
||||
state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))
|
||||
src_state_dict = state_dict['net']
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user