refactor train point wise net

This commit is contained in:
Thang Vu 2022-04-07 07:38:38 +00:00
parent e5e813ab67
commit af2c982653
7 changed files with 265 additions and 94 deletions

View File

@ -31,12 +31,13 @@ model:
data:
train:
data_root: 'dataset/s3dis/preprocess'
prefix: 'val'
prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
suffix: '_inst_nostuff.pth'
voxel_cfg:
scale: 50
spatial_shape: [128, 512]
max_npoint: 250000
min_npoint: 5000
test:
data_root: 'dataset/s3dis/preprocess'
prefix: 'Area_5'
@ -45,10 +46,17 @@ data:
scale: 50
spatial_shape: [128, 512]
max_npoint: 250000
min_npoint: 5000
data_loader:
batch_size: 4
num_workers: 4
optim:
lr: 0.001
pretrain: 'hais_ckpt.pth'
resume: ''
DATA:
data_root: dataset

4
data/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from .s3dis import S3DISDataset
from .scannetv2 import ScanNetDataset
__all__ = ['S3DISDataset', 'ScanNetDataset']

View File

@ -141,10 +141,18 @@ class CustomDataset(Dataset):
def transform_train(self, xyz, rgb, label, instance_label):
xyz_middle = self.dataAugment(xyz, True, True, True)
xyz = xyz_middle * self.voxel_cfg.scale
xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)
xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50)
xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50, 160 * self.voxel_cfg.scale / 50)
xyz -= xyz.min(0)
xyz, valid_idxs = self.crop(xyz)
max_tries = 5
while (max_tries > 0):
xyz_offset, valid_idxs = self.crop(xyz)
if valid_idxs.sum() >= self.voxel_cfg.min_npoint:
xyz = xyz_offset
break
max_tries -= 1
if valid_idxs.sum() < self.voxel_cfg.min_npoint:
return None
xyz = xyz[valid_idxs]
xyz_middle = xyz_middle[valid_idxs]
rgb = rgb[valid_idxs]
@ -173,7 +181,7 @@ class CustomDataset(Dataset):
inst_cls = inst_infos["instance_cls"]
loc = torch.from_numpy(xyz).long()
loc_float = torch.from_numpy(xyz_middle)
feat = torch.from_numpy(rgb)
feat = torch.from_numpy(rgb).float()
if self.training:
feat += torch.randn(3) * 0.1
label = torch.from_numpy(label)
@ -197,26 +205,28 @@ class CustomDataset(Dataset):
batch_offsets = [0]
total_inst_num = 0
for i, data in enumerate(batch):
batch_id = 0
for data in batch:
if data is None:
continue
(scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info,
inst_pointnum, inst_cls) = data
instance_label[np.where(instance_label != -100)] += total_inst_num
total_inst_num += inst_num
# merge the scene to the batch
batch_offsets.append(batch_offsets[-1] + loc.size(0))
scan_ids.append(scan_id)
locs.append(torch.cat([loc.new_full((loc.size(0), 1), i), loc], 1))
locs.append(torch.cat([loc.new_full((loc.size(0), 1), batch_id), loc], 1))
locs_float.append(loc_float)
feats.append(feat)
labels.append(label)
instance_labels.append(instance_label)
instance_infos.append(inst_info)
instance_pointnum.extend(inst_pointnum)
instance_cls.extend(inst_cls)
batch_id += 1
assert batch_id > 0, 'empty batch'
if batch_id < len(batch):
print(f'batch is truncated from size {len(batch)} to {batch_id}')
# merge all the scenes in the batch
batch_offsets = torch.tensor(batch_offsets, dtype=torch.int) # int (B+1)
@ -226,18 +236,14 @@ class CustomDataset(Dataset):
feats = torch.cat(feats, 0) # float (N, C)
labels = torch.cat(labels, 0).long() # long (N)
instance_labels = torch.cat(instance_labels, 0).long() # long (N)
instance_infos = torch.cat(instance_infos,
0).to(torch.float32) # float (N, 9) (meanxyz, minxyz, maxxyz)
instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int) # int (total_nInst)
instance_cls = torch.tensor(instance_cls, dtype=torch.long) # long (total_nInst)
spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
None) # long (3)
# voxelize
spatial_shape = np.clip(
locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
voxel_locs, p2v_map, v2p_map = softgroup_ops.voxelization_idx(locs, 1)
return {
'scan_ids': scan_ids,
'locs': locs,
@ -252,5 +258,6 @@ class CustomDataset(Dataset):
'instance_pointnum': instance_pointnum,
'instance_cls': instance_cls,
'offsets': batch_offsets,
'spatial_shape': spatial_shape
'spatial_shape': spatial_shape,
'batch_size': batch_id,
}

View File

@ -16,9 +16,15 @@ class S3DISDataset(CustomDataset):
"bookcase", "sofa", "board", "clutter")
def get_filenames(self):
    """Return the sorted list of data files selected by ``self.prefix``.

    ``self.prefix`` may be a single string or a list of strings. For every
    prefix the pattern ``<data_root>/<prefix>*<suffix>`` is globbed, and the
    matches of all prefixes are merged and sorted.

    The earlier single-prefix statements (glob on ``self.prefix + '*'``
    followed by an early ``return filenames``) are dropped here: they ran
    before the multi-prefix handling and made it unreachable.

    Returns:
        list[str]: sorted paths of all matching files.

    Raises:
        AssertionError: if any prefix matches no file.
    """
    # Normalise a lone prefix to a list so both config styles share one path.
    if isinstance(self.prefix, str):
        self.prefix = [self.prefix]
    filenames_all = []
    for p in self.prefix:
        filenames = glob(osp.join(self.data_root, p + '*' + self.suffix))
        # Fail per prefix so a misspelt area name is reported by name.
        assert len(filenames) > 0, f'Empty {p}'
        filenames_all.extend(filenames)
    filenames_all.sort()
    return filenames_all
def load(self, filename):
# TODO make file load results consistent
@ -35,18 +41,19 @@ class S3DISDataset(CustomDataset):
def crop(self, xyz, step=64):
    """Randomly crop a scene until it holds at most ``max_npoint`` points.

    The crop window starts as a cube of side ``voxel_cfg.spatial_shape[1]``.
    While too many points fall inside it, the scene is shifted by a random
    non-positive offset and the window's first two axes shrink by ``step``
    (``2 * step`` while more than 1e6 points are still inside).

    All limits are read from ``self.voxel_cfg``. The leftover statements
    that used ``self.full_scale`` / ``self.max_npoint`` and duplicated the
    loop header are removed so only one consistent implementation remains.

    Args:
        xyz: (N, 3) numpy array of point coordinates.
        step: amount by which the window's first two axes shrink per try.

    Returns:
        tuple: ``(xyz_offset, valid_idxs)`` - the shifted coordinates
        (a copy of ``xyz`` if no cropping was needed) and a boolean mask of
        the points inside the final window. The mask can be all False if the
        window shrinks to nothing before the limit is met.
    """
    limit = self.voxel_cfg.spatial_shape[1]
    max_npoint = self.voxel_cfg.max_npoint

    xyz_offset = xyz.copy()
    valid_idxs = (xyz_offset.min(1) >= 0) & ((xyz < limit).sum(1) == 3)

    spatial_shape = np.array([limit] * 3)
    room_range = xyz.max(0) - xyz.min(0)
    while valid_idxs.sum() > max_npoint:
        # Very dense scenes shrink twice as fast.
        step_temp = step * 2 if valid_idxs.sum() > 1e6 else step
        # Clipping at 0 keeps the shift non-positive: axes on which the room
        # already fits inside the window are not moved at all.
        offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
        xyz_offset = xyz + offset
        valid_idxs = (xyz_offset.min(1) >= 0) & ((xyz_offset < spatial_shape).sum(1) == 3)
        # Only the first two axes are reduced; the third keeps its extent.
        spatial_shape[:2] -= step_temp
    return xyz_offset, valid_idxs

View File

@ -6,29 +6,36 @@ import torch
class ResidualBlock(SparseModule):
def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
super().__init__()
if in_channels == out_channels:
self.i_branch = spconv.SparseSequential(
nn.Identity()
)
self.i_branch = spconv.SparseSequential(nn.Identity())
else:
self.i_branch = spconv.SparseSequential(
spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False)
)
spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False))
self.conv_branch = spconv.SparseSequential(
norm_fn(in_channels),
nn.ReLU(),
spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key),
norm_fn(out_channels),
nn.ReLU(),
spconv.SubMConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key)
)
norm_fn(in_channels), nn.ReLU(),
spconv.SubMConv3d(
in_channels,
out_channels,
kernel_size=3,
padding=1,
bias=False,
indice_key=indice_key), norm_fn(out_channels), nn.ReLU(),
spconv.SubMConv3d(
out_channels,
out_channels,
kernel_size=3,
padding=1,
bias=False,
indice_key=indice_key))
def forward(self, input):
identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, input.batch_size)
identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
input.batch_size)
output = self.conv_branch(input)
output.features += self.i_branch(identity).features
@ -36,41 +43,59 @@ class ResidualBlock(SparseModule):
class UBlock(nn.Module):
def __init__(self, nPlanes, norm_fn, block_reps, block, indice_key_id=1):
super().__init__()
self.nPlanes = nPlanes
blocks = {'block{}'.format(i): block(nPlanes[0], nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id)) for i in range(block_reps)}
blocks = {
'block{}'.format(i):
block(nPlanes[0], nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id))
for i in range(block_reps)
}
blocks = OrderedDict(blocks)
self.blocks = spconv.SparseSequential(blocks)
if len(nPlanes) > 1:
self.conv = spconv.SparseSequential(
norm_fn(nPlanes[0]),
nn.ReLU(),
spconv.SparseConv3d(nPlanes[0], nPlanes[1], kernel_size=2, stride=2, bias=False, indice_key='spconv{}'.format(indice_key_id))
)
norm_fn(nPlanes[0]), nn.ReLU(),
spconv.SparseConv3d(
nPlanes[0],
nPlanes[1],
kernel_size=2,
stride=2,
bias=False,
indice_key='spconv{}'.format(indice_key_id)))
self.u = UBlock(nPlanes[1:], norm_fn, block_reps, block, indice_key_id=indice_key_id+1)
self.u = UBlock(
nPlanes[1:], norm_fn, block_reps, block, indice_key_id=indice_key_id + 1)
self.deconv = spconv.SparseSequential(
norm_fn(nPlanes[1]),
nn.ReLU(),
spconv.SparseInverseConv3d(nPlanes[1], nPlanes[0], kernel_size=2, bias=False, indice_key='spconv{}'.format(indice_key_id))
)
norm_fn(nPlanes[1]), nn.ReLU(),
spconv.SparseInverseConv3d(
nPlanes[1],
nPlanes[0],
kernel_size=2,
bias=False,
indice_key='spconv{}'.format(indice_key_id)))
blocks_tail = {}
for i in range(block_reps):
blocks_tail['block{}'.format(i)] = block(nPlanes[0] * (2 - i), nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id))
blocks_tail['block{}'.format(i)] = block(
nPlanes[0] * (2 - i),
nPlanes[0],
norm_fn,
indice_key='subm{}'.format(indice_key_id))
blocks_tail = OrderedDict(blocks_tail)
self.blocks_tail = spconv.SparseSequential(blocks_tail)
def forward(self, input):
output = self.blocks(input)
identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape, output.batch_size)
identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape,
output.batch_size)
if len(self.nPlanes) > 1:
output_decoder = self.conv(output)
output_decoder = self.u(output_decoder)

View File

@ -25,8 +25,7 @@ class SoftGroup(nn.Module):
grouping_cfg=None,
instance_voxel_cfg=None,
test_cfg=None,
fixed_modules=[],
pretrained=None):
fixed_modules=[]):
super().__init__()
self.channels = channels
self.num_blocks = num_blocks
@ -77,6 +76,9 @@ class SoftGroup(nn.Module):
nn.Linear(channels, channels), nn.ReLU(), nn.Linear(channels, instance_classes + 1))
self.score_linear = nn.Linear(channels, instance_classes + 1)
self.semantic_loss = nn.CrossEntropyLoss(ignore_index=ignore_label)
self.offset_loss = nn.L1Loss(reduction='sum')
self.apply(self.set_bn_init)
nn.init.normal_(self.score_linear.weight, 0, 0.01)
nn.init.constant_(self.score_linear.bias, 0)
@ -100,6 +102,52 @@ class SoftGroup(nn.Module):
else:
return self.forward_test(batch)
def forward_train(self, batch):
    """Run one training forward pass and collect the point-wise losses.

    Tensors are read from ``batch`` and moved to the GPU, point features are
    concatenated with the float coordinates and voxelized, the backbone is
    run, and the semantic / offset losses are computed.

    Args:
        batch: dict produced by the dataset collate function. The keys read
            here are 'locs', 'voxel_locs', 'p2v_map', 'v2p_map',
            'locs_float', 'feats', 'labels', 'instance_labels',
            'instance_info', 'spatial_shape' and 'batch_size'.

    Returns:
        tuple: ``(loss, losses)`` where ``losses`` maps each loss name
        (plus 'loss' for the total) to a ``(value, count)`` pair.
    """

    def on_gpu(key):
        return batch[key].cuda()

    point_coords = on_gpu('locs')
    voxel_coords = on_gpu('voxel_locs')
    p2v_map = on_gpu('p2v_map')
    v2p_map = on_gpu('v2p_map')
    coords_float = on_gpu('locs_float')
    point_feats = on_gpu('feats')
    semantic_labels = on_gpu('labels')
    instance_labels = on_gpu('instance_labels')
    instance_info = on_gpu('instance_info')
    spatial_shape = batch['spatial_shape']
    batch_size = batch['batch_size']

    # Append the float coordinates to the per-point features, then voxelize.
    point_feats = torch.cat((point_feats, coords_float), 1)
    voxel_feats = softgroup_ops.voxelization(point_feats, v2p_map)

    # Offset target: first three columns of instance_info minus the point
    # position (presumably the instance mean xyz - confirm against collate).
    pt_offset_labels = instance_info[:, :3] - coords_float

    sparse_input = spconv.SparseConvTensor(
        voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
    semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
        sparse_input, p2v_map, coords_float)  # TODO check name for map

    losses = dict(
        self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
                             instance_labels, pt_offset_labels))

    # Total loss is the plain sum of the individual loss values.
    loss = 0
    for entry in losses.values():
        loss = loss + entry[0]
    losses['loss'] = (loss, point_coords.size(0))
    return loss, losses
def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
                    pt_offset_labels):
    """Compute the semantic and offset losses for a batch of points.

    Args:
        semantic_scores: per-point class scores passed to ``self.semantic_loss``.
        pt_offsets: predicted per-point offsets.
        semantic_labels: per-point semantic targets.
        instance_labels: per-point instance ids; points equal to
            ``self.ignore_label`` are excluded from the offset loss.
        pt_offset_labels: target per-point offsets.

    Returns:
        dict: 'semantic_loss' and 'offset_loss', each a ``(value, count)``
        pair where ``count`` is the number of points the value is averaged
        over.
    """
    losses = {}
    semantic_loss = self.semantic_loss(semantic_scores, semantic_labels)
    losses['semantic_loss'] = (semantic_loss, semantic_scores.size(0))

    pos_inds = instance_labels != self.ignore_label
    num_pos = pos_inds.sum()
    if num_pos == 0:
        # No labelled instance in this batch: emit a zero that is still tied
        # to the prediction tensor so backward() keeps working. The original
        # referenced the undefined name `pt_offset` here and raised NameError.
        offset_loss = 0 * pt_offsets.sum()
    else:
        # self.offset_loss returns a sum here; divide to get the mean per
        # labelled point.
        offset_loss = self.offset_loss(pt_offsets[pos_inds],
                                       pt_offset_labels[pos_inds]) / num_pos
    losses['offset_loss'] = (offset_loss, num_pos)
    return losses
def forward_test(self, batch):
coords = batch['locs'].cuda()
voxel_coords = batch['voxel_locs'].cuda()
@ -124,7 +172,8 @@ class SoftGroup(nn.Module):
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 1)
batch_idxs = coords[:, 0].int()
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, p2v_map, coords_float, x4_split=self.test_cfg.x4_split) # TODO check name for map
input, p2v_map, coords_float,
x4_split=self.test_cfg.x4_split) # TODO check name for map
proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
batch_idxs, coords_float,
self.grouping_cfg)
@ -150,7 +199,6 @@ class SoftGroup(nn.Module):
output_feats = output.features[input_map.long()]
semantic_scores = self.semantic_linear(output_feats)
semantic_scores = semantic_scores.softmax(dim=-1)
pt_offsets = self.offset_linear(output_feats)
return semantic_scores, pt_offsets, output_feats, coords
@ -194,6 +242,7 @@ class SoftGroup(nn.Module):
proposals_idx_list = []
proposals_offset_list = []
batch_size = batch_idxs.max() + 1
semantic_scores = semantic_scores.softmax(dim=-1)
semantic_preds = semantic_scores.max(1)[1] # TODO remove this
radius = self.grouping_cfg.radius

145
train.py
View File

@ -7,6 +7,13 @@ import numpy as np
from util.config import cfg
import torch.distributed as dist
import argparse
from munch import Munch
import yaml
from model.softgroup import SoftGroup
from util import utils
from data import S3DISDataset
from torch.utils.data import DataLoader
def init():
@ -31,8 +38,9 @@ def init():
torch.manual_seed(cfg.manual_seed)
torch.cuda.manual_seed_all(cfg.manual_seed)
# epoch counts from 1 to N
def train_epoch(train_loader, model, model_fn, optimizer, epoch):
def train_epoch(epoch, train_loader, model, optimizer):
iter_time = utils.AverageMeter()
data_time = utils.AverageMeter()
am_dict = {}
@ -41,8 +49,8 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
start_epoch = time.time()
end = time.time()
if train_loader.sampler is not None and cfg.dist == True:
train_loader.sampler.set_epoch(epoch)
# if train_loader.sampler is not None and cfg.dist == True:
# train_loader.sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
@ -53,15 +61,16 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
data_time.update(time.time() - end)
torch.cuda.empty_cache()
loss, log_vars = model(batch, return_loss=True)
# adjust learning rate
utils.cosine_lr_after_step(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
# utils.cosine_lr_after_step(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
# prepare input and forward
loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
# loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
# meter_dict
for k, v in meter_dict.items():
for k, v in log_vars.items():
if k not in am_dict.keys():
am_dict[k] = utils.AverageMeter()
am_dict[k].update(v[0], v[1])
@ -73,7 +82,7 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
# time and print
current_iter = (epoch - 1) * len(train_loader) + i + 1
max_iter = cfg.epochs * len(train_loader)
max_iter = 500 * len(train_loader)
remain_iter = max_iter - current_iter
iter_time.update(time.time() - end)
@ -84,24 +93,38 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
t_h, t_m = divmod(t_m, 60)
remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
if cfg.local_rank == 0 and i % 10 == 0:
if i % 1 == 0:
lr = optimizer.param_groups[0]['lr']
sys.stdout.write(
"epoch: {}/{} iter: {}/{} lr: {:.5f} loss: {:.4f}({:.4f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}\n".format
(epoch, cfg.epochs, i + 1, len(train_loader), lr, am_dict['loss'].val, am_dict['loss'].avg,
data_time.val, data_time.avg, iter_time.val, iter_time.avg, remain_time=remain_time))
log_str = "epoch: {}/{} iter: {}/{} lr: {:.5f} loss: {:.4f}({:.4f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}".format(
epoch,
500,
i + 1,
len(train_loader),
lr,
am_dict['loss'].val,
am_dict['loss'].avg,
data_time.val,
data_time.avg,
iter_time.val,
iter_time.avg,
remain_time=remain_time)
for k, v in am_dict.items():
log_str += f' {k}: {v.avg:.4f}'
print(log_str)
if (i == len(train_loader) - 1): print()
logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))
logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.epochs,
am_dict['loss'].avg,
time.time() - start_epoch))
if cfg.local_rank == 0:
utils.checkpoint_save(model, optimizer, cfg.exp_path, cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq, use_cuda)
utils.checkpoint_save(model, optimizer, cfg.exp_path,
cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq, use_cuda)
for k in am_dict.keys():
if k in visual_dict.keys():
writer.add_scalar(k+'_train', am_dict[k].avg, epoch)
writer.add_scalar(k + '_train', am_dict[k].avg, epoch)
def eval_epoch(val_loader, model, model_fn, epoch):
@ -114,28 +137,76 @@ def eval_epoch(val_loader, model, model_fn, epoch):
for i, batch in enumerate(val_loader):
# prepare input and forward
loss, preds, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
loss, preds, visual_dict, meter_dict = model_fn(
batch, model, epoch, semantic_only=cfg.semantic_only)
for k, v in meter_dict.items():
if k not in am_dict.keys():
am_dict[k] = utils.AverageMeter()
am_dict[k].update(v[0], v[1])
sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(
i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
if (i == len(val_loader) - 1): print()
logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))
logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(
epoch, cfg.epochs, am_dict['loss'].avg,
time.time() - start_epoch))
for k in am_dict.keys():
if k in visual_dict.keys():
writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
def get_args(argv=None):
    """Parse the command-line arguments of the training script.

    Args:
        argv: optional list of argument strings. When omitted (the default)
            argparse falls back to ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        argparse.Namespace: holds ``config``, the path to the config file.
    """
    parser = argparse.ArgumentParser('SoftGroup')
    parser.add_argument('config', type=str, help='path to config file')
    return parser.parse_args(argv)
if __name__ == '__main__':
torch.backends.cudnn.enabled = False
torch.backends.cudnn.enabled = False # TODO remove this
test_seed = 123
random.seed(test_seed)
np.random.seed(test_seed)
torch.manual_seed(test_seed)
torch.cuda.manual_seed_all(test_seed)
args = get_args()
cfg = Munch.fromDict(yaml.safe_load(open(args.config, 'r')))
model = SoftGroup(**cfg.model)
print(f'Load pretrained state dict from {cfg.pretrain}')
# model = utils.load_checkpoint(model, cfg.pretrain)
model.cuda()
dataset = S3DISDataset(**cfg.data.train)
dataloader = DataLoader(
dataset,
batch_size=4,
collate_fn=dataset.collate_fn,
num_workers=4,
sampler=None,
shuffle=True,
drop_last=True,
pin_memory=True)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.optim.lr)
# train and val
for epoch in range(1, 500 + 1):
train_epoch(epoch, dataloader, model, optimizer)
if utils.is_multiple(epoch, cfg.save_freq) or utils.is_power2(epoch):
eval_epoch(dataset.val_data_loader, model, model_fn, epoch)
import pdb
pdb.set_trace()
if cfg.dist == True:
raise NotImplementedError
# num_gpus = torch.cuda.device_count()
# dist.init_process_group(backend='nccl', rank=cfg.local_rank,
# dist.init_process_group(backend='nccl', rank=cfg.local_rank,
# world_size=num_gpus)
# torch.cuda.set_device(cfg.local_rank)
@ -164,15 +235,18 @@ if __name__ == '__main__':
assert use_cuda
model = model.cuda()
logger.info('#classifier parameters: {}'.format(sum([x.nelement() for x in model.parameters()])))
logger.info('#classifier parameters: {}'.format(
sum([x.nelement() for x in model.parameters()])))
# optimizer
if cfg.optim == 'Adam':
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr)
elif cfg.optim == 'SGD':
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.lr,
momentum=cfg.momentum,
weight_decay=cfg.weight_decay)
model_fn = model_fn_decorator()
@ -197,23 +271,20 @@ if __name__ == '__main__':
else:
raise NotImplementedError("Not yet supported")
# resume from the latest epoch, or specify the epoch to restore
start_epoch = utils.checkpoint_restore(cfg, model, optimizer, cfg.exp_path,
cfg.config.split('/')[-1][:-5], use_cuda)
start_epoch = utils.checkpoint_restore(cfg, model, optimizer, cfg.exp_path,
cfg.config.split('/')[-1][:-5], use_cuda)
if cfg.dist:
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(
model.cuda(cfg.local_rank),
device_ids=[cfg.local_rank],
output_device=cfg.local_rank,
find_unused_parameters=True)
model.cuda(cfg.local_rank),
device_ids=[cfg.local_rank],
output_device=cfg.local_rank,
find_unused_parameters=True)
# train and val
for epoch in range(start_epoch, cfg.epochs + 1):
for epoch in range(1, cfg.epochs + 1):
train_epoch(dataset.train_data_loader, model, model_fn, optimizer, epoch)
if utils.is_multiple(epoch, cfg.save_freq) or utils.is_power2(epoch):