mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
10 lines · 296 B · Python
import torch.optim
|
|
|
|
|
|
def build_optimizer(model, optim_cfg):
    """Build a ``torch.optim`` optimizer over the trainable parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` (typically an ``nn.Module``).
            Only parameters with ``requires_grad=True`` are handed to the
            optimizer; the others are skipped.
        optim_cfg: Mapping with a ``'type'`` key naming an optimizer class in
            ``torch.optim`` (e.g. ``'SGD'``, ``'Adam'``). Every other key is
            forwarded as a keyword argument to that class. The mapping itself
            is not modified (a copy is popped from).

    Returns:
        torch.optim.Optimizer: the constructed optimizer instance.

    Raises:
        ValueError: if ``'type'`` is missing, or does not name an
            ``Optimizer`` subclass in ``torch.optim``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if 'type' not in optim_cfg:
        raise ValueError(
            f"optim_cfg must contain a 'type' key, got keys {list(optim_cfg)}")

    # Pop from a copy so the caller's config keeps its 'type' entry.
    _optim_cfg = optim_cfg.copy()
    optim_type = _optim_cfg.pop('type')

    # getattr alone would also return non-optimizer attributes (e.g. the
    # ``lr_scheduler`` submodule) and raises TypeError for non-string names,
    # so require an actual Optimizer subclass.
    optim_cls = (getattr(torch.optim, optim_type, None)
                 if isinstance(optim_type, str) else None)
    if not (isinstance(optim_cls, type)
            and issubclass(optim_cls, torch.optim.Optimizer)):
        raise ValueError(
            f'Unknown optimizer type {optim_type!r}: expected the name of a '
            f'torch.optim.Optimizer subclass')

    trainable_params = (p for p in model.parameters() if p.requires_grad)
    return optim_cls(trainable_params, **_optim_cfg)
|