From c68be64a35580343e42d844f52c6c9bd8b068171 Mon Sep 17 00:00:00 2001 From: Thang Vu Date: Mon, 11 Apr 2022 02:00:22 +0000 Subject: [PATCH] fix load checkpoint ddp --- softgroup/util/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/softgroup/util/utils.py b/softgroup/util/utils.py index 8557ffc..81374b5 100644 --- a/softgroup/util/utils.py +++ b/softgroup/util/utils.py @@ -86,7 +86,10 @@ def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16): def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False): - state_dict = torch.load(checkpoint) + if hasattr(model, 'module'): + model = model.module + device = torch.cuda.current_device() + state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device)) src_state_dict = state_dict['net'] target_state_dict = model.state_dict() skip_keys = []