Mirror of https://github.com/botastic/SoftGroup.git (synced 2025-10-16 11:45:42 +00:00)

Commit af2c982653 (parent e5e813ab67): refactor train point-wise net
@@ -31,12 +31,13 @@ model:
data:
  train:
    data_root: 'dataset/s3dis/preprocess'
    prefix: 'val'
    prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
    suffix: '_inst_nostuff.pth'
    voxel_cfg:
      scale: 50
      spatial_shape: [128, 512]
      max_npoint: 250000
      min_npoint: 5000
  test:
    data_root: 'dataset/s3dis/preprocess'
    prefix: 'Area_5'
@@ -45,10 +46,17 @@ data:
      scale: 50
      spatial_shape: [128, 512]
      max_npoint: 250000
      min_npoint: 5000

data_loader:
  batch_size: 4
  num_workers: 4

optim:
  lr: 0.001

pretrain: 'hais_ckpt.pth'
resume: ''

DATA:
  data_root: dataset
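This config is consumed by the rewritten train.py further down: the YAML is loaded and wrapped in a Munch so sections can be read as attributes and passed straight to the dataset and optimizer. A minimal sketch, assuming an illustrative config path (not a path from this commit):

import yaml
from munch import Munch

# Hypothetical path, for illustration only; any file with the layout above works.
cfg = Munch.fromDict(yaml.safe_load(open('configs/softgroup_s3dis.yaml', 'r')))
# cfg.data.train carries data_root, prefix (now a list of Areas), suffix and voxel_cfg,
# which are passed verbatim as keyword arguments, e.g. S3DISDataset(**cfg.data.train).
print(cfg.data.train.prefix)  # ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']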
data/__init__.py (new file)
@@ -0,0 +1,4 @@
from .s3dis import S3DISDataset
from .scannetv2 import ScanNetDataset

__all__ = ['S3DISDataset', 'ScanNetDataset']
@@ -141,10 +141,18 @@ class CustomDataset(Dataset):
    def transform_train(self, xyz, rgb, label, instance_label):
        xyz_middle = self.dataAugment(xyz, True, True, True)
        xyz = xyz_middle * self.voxel_cfg.scale
        xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
        xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)
        xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50)
        xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50, 160 * self.voxel_cfg.scale / 50)
        xyz -= xyz.min(0)
        xyz, valid_idxs = self.crop(xyz)
        max_tries = 5
        while (max_tries > 0):
            xyz_offset, valid_idxs = self.crop(xyz)
            if valid_idxs.sum() >= self.voxel_cfg.min_npoint:
                xyz = xyz_offset
                break
            max_tries -= 1
        if valid_idxs.sum() < self.voxel_cfg.min_npoint:
            return None
        xyz = xyz[valid_idxs]
        xyz_middle = xyz_middle[valid_idxs]
        rgb = rgb[valid_idxs]
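The retry loop above replaces the single unconditional crop: a scene is re-cropped up to five times until at least voxel_cfg.min_npoint points survive, otherwise transform_train returns None and the sample is dropped later. A minimal standalone sketch of that control flow, with a stand-in crop function (the real one is S3DISDataset.crop below):

def crop_with_retry(xyz, crop_fn, min_npoint, max_tries=5):
    # Try several random crops; keep the first one that retains enough points.
    while max_tries > 0:
        xyz_offset, valid_idxs = crop_fn(xyz)
        if valid_idxs.sum() >= min_npoint:
            return xyz_offset, valid_idxs
        max_tries -= 1
    return None  # the caller (transform_train / collate_fn) drops this sample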
@@ -173,7 +181,7 @@ class CustomDataset(Dataset):
        inst_cls = inst_infos["instance_cls"]
        loc = torch.from_numpy(xyz).long()
        loc_float = torch.from_numpy(xyz_middle)
        feat = torch.from_numpy(rgb)
        feat = torch.from_numpy(rgb).float()
        if self.training:
            feat += torch.randn(3) * 0.1
        label = torch.from_numpy(label)
@@ -197,26 +205,28 @@ class CustomDataset(Dataset):
        batch_offsets = [0]

        total_inst_num = 0
        for i, data in enumerate(batch):
        batch_id = 0
        for data in batch:
            if data is None:
                continue
            (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info,
             inst_pointnum, inst_cls) = data

            instance_label[np.where(instance_label != -100)] += total_inst_num
            total_inst_num += inst_num

            # merge the scene to the batch
            batch_offsets.append(batch_offsets[-1] + loc.size(0))

            scan_ids.append(scan_id)
            locs.append(torch.cat([loc.new_full((loc.size(0), 1), i), loc], 1))
            locs.append(torch.cat([loc.new_full((loc.size(0), 1), batch_id), loc], 1))
            locs_float.append(loc_float)
            feats.append(feat)
            labels.append(label)
            instance_labels.append(instance_label)

            instance_infos.append(inst_info)
            instance_pointnum.extend(inst_pointnum)
            instance_cls.extend(inst_cls)
            batch_id += 1
        assert batch_id > 0, 'empty batch'
        if batch_id < len(batch):
            print(f'batch is truncated from size {len(batch)} to {batch_id}')

        # merge all the scenes in the batch
        batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int (B+1)
@@ -226,18 +236,14 @@ class CustomDataset(Dataset):
        feats = torch.cat(feats, 0)  # float (N, C)
        labels = torch.cat(labels, 0).long()  # long (N)
        instance_labels = torch.cat(instance_labels, 0).long()  # long (N)

        instance_infos = torch.cat(instance_infos,
                                   0).to(torch.float32)  # float (N, 9) (meanxyz, minxyz, maxxyz)
        instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)
        instance_cls = torch.tensor(instance_cls, dtype=torch.long)  # long (total_nInst)

        spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
                                None)  # long (3)

        # voxelize
        spatial_shape = np.clip(
            locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
        voxel_locs, p2v_map, v2p_map = softgroup_ops.voxelization_idx(locs, 1)

        return {
            'scan_ids': scan_ids,
            'locs': locs,
@@ -252,5 +258,6 @@ class CustomDataset(Dataset):
            'instance_pointnum': instance_pointnum,
            'instance_cls': instance_cls,
            'offsets': batch_offsets,
            'spatial_shape': spatial_shape
            'spatial_shape': spatial_shape,
            'batch_size': batch_id,
        }
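Besides skipping None samples and reporting the surviving count as batch['batch_size'], collate_fn shifts every scene's instance ids by total_inst_num so they stay unique across the batch. A small standalone sketch of that relabeling; the instance counts here are derived from the labels purely for illustration, whereas the real code unpacks inst_num from each sample:

import numpy as np

# Two scenes with 2 and 3 instances respectively; -100 marks unlabeled points.
scene_a = np.array([0, 1, -100, 0])
scene_b = np.array([0, 2, 1, -100])

total_inst_num = 0
for inst in (scene_a, scene_b):
    inst_num = inst[inst != -100].max() + 1          # per-scene instance count (inst_num in the diff)
    inst[np.where(inst != -100)] += total_inst_num   # shift ids into a batch-wide range
    total_inst_num += inst_num

print(scene_a)  # ids 0, 1 are unchanged for the first scene
print(scene_b)  # ids become 2, 4, 3; -100 stays -100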
@@ -16,9 +16,15 @@ class S3DISDataset(CustomDataset):
                   "bookcase", "sofa", "board", "clutter")

    def get_filenames(self):
        filenames = sorted(glob(osp.join(self.data_root, self.prefix + '*' + self.suffix)))
        assert len(filenames) > 0, 'Empty dataset.'
        return filenames
        if isinstance(self.prefix, str):
            self.prefix = [self.prefix]
        filenames_all = []
        for p in self.prefix:
            filenames = glob(osp.join(self.data_root, p + '*' + self.suffix))
            assert len(filenames) > 0, f'Empty {p}'
            filenames_all.extend(filenames)
        filenames_all.sort()
        return filenames_all

    def load(self, filename):
        # TODO make file load results consistent
@@ -35,18 +41,19 @@ class S3DISDataset(CustomDataset):

    def crop(self, xyz, step=64):
        xyz_offset = xyz.copy()
        valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz < self.full_scale[1]).sum(1) == 3)
        valid_idxs = (xyz_offset.min(1) >= 0) * (
            (xyz < self.voxel_cfg.spatial_shape[1]).sum(1) == 3)

        full_scale = np.array([self.full_scale[1]] * 3)
        spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
        room_range = xyz.max(0) - xyz.min(0)
        while (valid_idxs.sum() > self.max_npoint):
        while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
            step_temp = step
            if valid_idxs.sum() > 1e6:
                step_temp = step * 2
            offset = np.clip(full_scale - room_range + 0.001, None, 0) * np.random.rand(3)
            offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
            xyz_offset = xyz + offset
            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < full_scale).sum(1) == 3)
            full_scale[:2] -= step_temp
            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
            spatial_shape[:2] -= step_temp

        return xyz_offset, valid_idxs
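get_filenames now accepts a prefix list, which is what lets the training split above cover Areas 1-4 and 6 while testing stays on 'Area_5'. A small sketch of the matching it performs; the directory and file name are placeholders that merely follow the '_inst_nostuff.pth' convention:

from glob import glob
import os.path as osp

data_root = 'dataset/s3dis/preprocess'
prefix = ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']  # train split from the config
suffix = '_inst_nostuff.pth'

filenames_all = []
for p in prefix:
    # e.g. dataset/s3dis/preprocess/Area_1_office_1_inst_nostuff.pth (illustrative name)
    filenames_all.extend(glob(osp.join(data_root, p + '*' + suffix)))
filenames_all.sort()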
@@ -6,29 +6,36 @@ import torch


class ResidualBlock(SparseModule):

    def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
        super().__init__()

        if in_channels == out_channels:
            self.i_branch = spconv.SparseSequential(
                nn.Identity()
            )
            self.i_branch = spconv.SparseSequential(nn.Identity())
        else:
            self.i_branch = spconv.SparseSequential(
                spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False)
            )
                spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False))

        self.conv_branch = spconv.SparseSequential(
            norm_fn(in_channels),
            nn.ReLU(),
            spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key),
            norm_fn(out_channels),
            nn.ReLU(),
            spconv.SubMConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key)
        )
            norm_fn(in_channels), nn.ReLU(),
            spconv.SubMConv3d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
                indice_key=indice_key), norm_fn(out_channels), nn.ReLU(),
            spconv.SubMConv3d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
                indice_key=indice_key))

    def forward(self, input):
        identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, input.batch_size)
        identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
                                           input.batch_size)
        output = self.conv_branch(input)
        output.features += self.i_branch(identity).features

@@ -36,41 +43,59 @@ class ResidualBlock(SparseModule):


class UBlock(nn.Module):

    def __init__(self, nPlanes, norm_fn, block_reps, block, indice_key_id=1):

        super().__init__()

        self.nPlanes = nPlanes

        blocks = {'block{}'.format(i): block(nPlanes[0], nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id)) for i in range(block_reps)}
        blocks = {
            'block{}'.format(i):
            block(nPlanes[0], nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id))
            for i in range(block_reps)
        }
        blocks = OrderedDict(blocks)
        self.blocks = spconv.SparseSequential(blocks)

        if len(nPlanes) > 1:
            self.conv = spconv.SparseSequential(
                norm_fn(nPlanes[0]),
                nn.ReLU(),
                spconv.SparseConv3d(nPlanes[0], nPlanes[1], kernel_size=2, stride=2, bias=False, indice_key='spconv{}'.format(indice_key_id))
            )
                norm_fn(nPlanes[0]), nn.ReLU(),
                spconv.SparseConv3d(
                    nPlanes[0],
                    nPlanes[1],
                    kernel_size=2,
                    stride=2,
                    bias=False,
                    indice_key='spconv{}'.format(indice_key_id)))

            self.u = UBlock(nPlanes[1:], norm_fn, block_reps, block, indice_key_id=indice_key_id+1)
            self.u = UBlock(
                nPlanes[1:], norm_fn, block_reps, block, indice_key_id=indice_key_id + 1)

            self.deconv = spconv.SparseSequential(
                norm_fn(nPlanes[1]),
                nn.ReLU(),
                spconv.SparseInverseConv3d(nPlanes[1], nPlanes[0], kernel_size=2, bias=False, indice_key='spconv{}'.format(indice_key_id))
            )
                norm_fn(nPlanes[1]), nn.ReLU(),
                spconv.SparseInverseConv3d(
                    nPlanes[1],
                    nPlanes[0],
                    kernel_size=2,
                    bias=False,
                    indice_key='spconv{}'.format(indice_key_id)))

            blocks_tail = {}
            for i in range(block_reps):
                blocks_tail['block{}'.format(i)] = block(nPlanes[0] * (2 - i), nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id))
                blocks_tail['block{}'.format(i)] = block(
                    nPlanes[0] * (2 - i),
                    nPlanes[0],
                    norm_fn,
                    indice_key='subm{}'.format(indice_key_id))
            blocks_tail = OrderedDict(blocks_tail)
            self.blocks_tail = spconv.SparseSequential(blocks_tail)

    def forward(self, input):

        output = self.blocks(input)
        identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape, output.batch_size)
        identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape,
                                           output.batch_size)
        if len(self.nPlanes) > 1:
            output_decoder = self.conv(output)
            output_decoder = self.u(output_decoder)
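The ResidualBlock/UBlock changes in this hunk are formatting only (yapf-style line wrapping); behavior is unchanged. For context, a hedged sketch of how such a recursive sparse U-Net is typically assembled, assuming the classes above are in scope and using illustrative channel values and a functools.partial batch-norm factory (none of these numbers come from this commit):

import functools
import torch.nn as nn

channels = 32  # illustrative base width
norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
# Channel list per encoder level; UBlock recurses on nPlanes[1:] until one level remains.
block_channels = [channels * (i + 1) for i in range(7)]

unet = UBlock(block_channels, norm_fn, block_reps=2, block=ResidualBlock, indice_key_id=1)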
@@ -25,8 +25,7 @@ class SoftGroup(nn.Module):
                 grouping_cfg=None,
                 instance_voxel_cfg=None,
                 test_cfg=None,
                 fixed_modules=[],
                 pretrained=None):
                 fixed_modules=[]):
        super().__init__()
        self.channels = channels
        self.num_blocks = num_blocks
@@ -77,6 +76,9 @@ class SoftGroup(nn.Module):
            nn.Linear(channels, channels), nn.ReLU(), nn.Linear(channels, instance_classes + 1))
        self.score_linear = nn.Linear(channels, instance_classes + 1)

        self.semantic_loss = nn.CrossEntropyLoss(ignore_index=ignore_label)
        self.offset_loss = nn.L1Loss(reduction='sum')

        self.apply(self.set_bn_init)
        nn.init.normal_(self.score_linear.weight, 0, 0.01)
        nn.init.constant_(self.score_linear.bias, 0)
@@ -100,6 +102,52 @@ class SoftGroup(nn.Module):
        else:
            return self.forward_test(batch)

    def forward_train(self, batch):
        coords = batch['locs'].cuda()
        voxel_coords = batch['voxel_locs'].cuda()
        p2v_map = batch['p2v_map'].cuda()
        v2p_map = batch['v2p_map'].cuda()
        coords_float = batch['locs_float'].cuda()
        feats = batch['feats'].cuda()
        semantic_labels = batch['labels'].cuda()
        instance_labels = batch['instance_labels'].cuda()
        instance_info = batch['instance_info'].cuda()
        # instance_pointnum = batch['instance_pointnum'].cuda()
        # instance_cls = batch['instance_cls'].cuda()
        # batch_offsets = batch['offsets'].cuda()
        spatial_shape = batch['spatial_shape']
        batch_size = batch['batch_size']

        feats = torch.cat((feats, coords_float), 1)
        voxel_feats = softgroup_ops.voxelization(feats, v2p_map)

        losses = {}
        pt_offset_labels = instance_info[:, :3] - coords_float
        input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
        semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
            input, p2v_map, coords_float)  # TODO check name for map
        point_wise_loss = self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
                                               instance_labels, pt_offset_labels)
        losses.update(point_wise_loss)
        loss = sum(v[0] for v in losses.values())
        losses['loss'] = (loss, coords.size(0))
        return loss, losses

    def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
                        pt_offset_labels):
        losses = {}
        semantic_loss = self.semantic_loss(semantic_scores, semantic_labels)
        losses['semantic_loss'] = (semantic_loss, semantic_scores.size(0))

        pos_inds = instance_labels != self.ignore_label
        if pos_inds.sum() == 0:
            offset_loss = 0 * pt_offsets.sum()
        else:
            offset_loss = self.offset_loss(pt_offsets[pos_inds],
                                           pt_offset_labels[pos_inds]) / pos_inds.sum()
        losses['offset_loss'] = (offset_loss, pos_inds.sum())
        return losses
    def forward_test(self, batch):
        coords = batch['locs'].cuda()
        voxel_coords = batch['voxel_locs'].cuda()
@@ -124,7 +172,8 @@ class SoftGroup(nn.Module):
        input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 1)
        batch_idxs = coords[:, 0].int()
        semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
            input, p2v_map, coords_float, x4_split=self.test_cfg.x4_split)  # TODO check name for map
            input, p2v_map, coords_float,
            x4_split=self.test_cfg.x4_split)  # TODO check name for map
        proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
                                                                batch_idxs, coords_float,
                                                                self.grouping_cfg)
@@ -150,7 +199,6 @@ class SoftGroup(nn.Module):
        output_feats = output.features[input_map.long()]

        semantic_scores = self.semantic_linear(output_feats)
        semantic_scores = semantic_scores.softmax(dim=-1)
        pt_offsets = self.offset_linear(output_feats)
        return semantic_scores, pt_offsets, output_feats, coords

@@ -194,6 +242,7 @@ class SoftGroup(nn.Module):
        proposals_idx_list = []
        proposals_offset_list = []
        batch_size = batch_idxs.max() + 1
        semantic_scores = semantic_scores.softmax(dim=-1)
        semantic_preds = semantic_scores.max(1)[1]  # TODO remove this

        radius = self.grouping_cfg.radius
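The new point_wise_loss returns each term as a (value, count) pair so train.py can feed AverageMeters, and forward_train sums the values into a single scalar. A minimal standalone sketch of the same computation with plain tensors, assuming ignore_label is -100 as in the collate code above:

import torch
import torch.nn as nn

def point_wise_loss_sketch(semantic_scores, pt_offsets, semantic_labels, instance_labels,
                           pt_offset_labels, ignore_label=-100):
    losses = {}
    # Semantic branch: cross-entropy over per-point class scores.
    ce = nn.CrossEntropyLoss(ignore_index=ignore_label)
    losses['semantic_loss'] = (ce(semantic_scores, semantic_labels), semantic_scores.size(0))

    # Offset branch: summed L1 over points that belong to an instance, normalized by their count.
    pos_inds = instance_labels != ignore_label
    if pos_inds.sum() == 0:
        offset_loss = 0 * pt_offsets.sum()  # keeps the graph connected when nothing is labeled
    else:
        l1 = nn.L1Loss(reduction='sum')
        offset_loss = l1(pt_offsets[pos_inds], pt_offset_labels[pos_inds]) / pos_inds.sum()
    losses['offset_loss'] = (offset_loss, pos_inds.sum())

    loss = sum(v[0] for v in losses.values())  # scalar used for backward()
    return loss, losses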
train.py
@@ -7,6 +7,13 @@ import numpy as np
from util.config import cfg

import torch.distributed as dist
import argparse
from munch import Munch
import yaml
from model.softgroup import SoftGroup
from util import utils
from data import S3DISDataset
from torch.utils.data import DataLoader


def init():
@@ -31,8 +38,9 @@ def init():
    torch.manual_seed(cfg.manual_seed)
    torch.cuda.manual_seed_all(cfg.manual_seed)


# epoch counts from 1 to N
def train_epoch(train_loader, model, model_fn, optimizer, epoch):
def train_epoch(epoch, train_loader, model, optimizer):
    iter_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    am_dict = {}
@@ -41,8 +49,8 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
    start_epoch = time.time()
    end = time.time()

    if train_loader.sampler is not None and cfg.dist == True:
        train_loader.sampler.set_epoch(epoch)
    # if train_loader.sampler is not None and cfg.dist == True:
    #     train_loader.sampler.set_epoch(epoch)

    for i, batch in enumerate(train_loader):

@@ -53,15 +61,16 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
        data_time.update(time.time() - end)
        torch.cuda.empty_cache()

        loss, log_vars = model(batch, return_loss=True)

        # adjust learning rate
        utils.cosine_lr_after_step(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.epochs)

        # utils.cosine_lr_after_step(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.epochs)

        # prepare input and forward
        loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
        # loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)

        # meter_dict
        for k, v in meter_dict.items():
        for k, v in log_vars.items():
            if k not in am_dict.keys():
                am_dict[k] = utils.AverageMeter()
            am_dict[k].update(v[0], v[1])
@@ -73,7 +82,7 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):

        # time and print
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = cfg.epochs * len(train_loader)
        max_iter = 500 * len(train_loader)
        remain_iter = max_iter - current_iter

        iter_time.update(time.time() - end)
@@ -84,24 +93,38 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if cfg.local_rank == 0 and i % 10 == 0:
        if i % 1 == 0:
            lr = optimizer.param_groups[0]['lr']
            sys.stdout.write(
                "epoch: {}/{} iter: {}/{} lr: {:.5f} loss: {:.4f}({:.4f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}\n".format
                (epoch, cfg.epochs, i + 1, len(train_loader), lr, am_dict['loss'].val, am_dict['loss'].avg,
                 data_time.val, data_time.avg, iter_time.val, iter_time.avg, remain_time=remain_time))

            log_str = "epoch: {}/{} iter: {}/{} lr: {:.5f} loss: {:.4f}({:.4f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}".format(
                epoch,
                500,
                i + 1,
                len(train_loader),
                lr,
                am_dict['loss'].val,
                am_dict['loss'].avg,
                data_time.val,
                data_time.avg,
                iter_time.val,
                iter_time.avg,
                remain_time=remain_time)
            for k, v in am_dict.items():
                log_str += f' {k}: {v.avg:.4f}'
            print(log_str)

        if (i == len(train_loader) - 1): print()


    logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))
    logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.epochs,
                                                                     am_dict['loss'].avg,
                                                                     time.time() - start_epoch))

    if cfg.local_rank == 0:
        utils.checkpoint_save(model, optimizer, cfg.exp_path, cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq, use_cuda)
        utils.checkpoint_save(model, optimizer, cfg.exp_path,
                              cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq, use_cuda)

        for k in am_dict.keys():
            if k in visual_dict.keys():
                writer.add_scalar(k+'_train', am_dict[k].avg, epoch)
                writer.add_scalar(k + '_train', am_dict[k].avg, epoch)

def eval_epoch(val_loader, model, model_fn, epoch):
@@ -114,28 +137,76 @@ def eval_epoch(val_loader, model, model_fn, epoch):
    for i, batch in enumerate(val_loader):

        # prepare input and forward
        loss, preds, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
        loss, preds, visual_dict, meter_dict = model_fn(
            batch, model, epoch, semantic_only=cfg.semantic_only)

        for k, v in meter_dict.items():
            if k not in am_dict.keys():
                am_dict[k] = utils.AverageMeter()
            am_dict[k].update(v[0], v[1])
        sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
        sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(
            i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
        if (i == len(val_loader) - 1): print()

    logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))
    logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(
        epoch, cfg.epochs, am_dict['loss'].avg,
        time.time() - start_epoch))

    for k in am_dict.keys():
        if k in visual_dict.keys():
            writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)


def get_args():
    parser = argparse.ArgumentParser('SoftGroup')
    parser.add_argument('config', type=str, help='path to config file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.enabled = False  # TODO remove this
    test_seed = 123
    random.seed(test_seed)
    np.random.seed(test_seed)
    torch.manual_seed(test_seed)
    torch.cuda.manual_seed_all(test_seed)

    args = get_args()
    cfg = Munch.fromDict(yaml.safe_load(open(args.config, 'r')))

    model = SoftGroup(**cfg.model)
    print(f'Load pretrained state dict from {cfg.pretrain}')
    # model = utils.load_checkpoint(model, cfg.pretrain)
    model.cuda()

    dataset = S3DISDataset(**cfg.data.train)
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        collate_fn=dataset.collate_fn,
        num_workers=4,
        sampler=None,
        shuffle=True,
        drop_last=True,
        pin_memory=True)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.optim.lr)

    # train and val
    for epoch in range(1, 500 + 1):
        train_epoch(epoch, dataloader, model, optimizer)

        if utils.is_multiple(epoch, cfg.save_freq) or utils.is_power2(epoch):
            eval_epoch(dataset.val_data_loader, model, model_fn, epoch)

    import pdb
    pdb.set_trace()

    if cfg.dist == True:
        raise NotImplementedError
        # num_gpus = torch.cuda.device_count()
        # dist.init_process_group(backend='nccl', rank=cfg.local_rank,
        #                         world_size=num_gpus)
        # torch.cuda.set_device(cfg.local_rank)

@@ -164,15 +235,18 @@ if __name__ == '__main__':
    assert use_cuda
    model = model.cuda()

    logger.info('#classifier parameters: {}'.format(sum([x.nelement() for x in model.parameters()])))
    logger.info('#classifier parameters: {}'.format(
        sum([x.nelement() for x in model.parameters()])))

    # optimizer
    if cfg.optim == 'Adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr)
    elif cfg.optim == 'SGD':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)

        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay)

    model_fn = model_fn_decorator()

@@ -197,23 +271,20 @@ if __name__ == '__main__':
    else:
        raise NotImplementedError("Not yet supported")


    # resume from the latest epoch, or specify the epoch to restore
    start_epoch = utils.checkpoint_restore(cfg, model, optimizer, cfg.exp_path,
                                           cfg.config.split('/')[-1][:-5], use_cuda)

    start_epoch = utils.checkpoint_restore(cfg, model, optimizer, cfg.exp_path,
                                           cfg.config.split('/')[-1][:-5], use_cuda)

    if cfg.dist:
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(cfg.local_rank),
            device_ids=[cfg.local_rank],
            output_device=cfg.local_rank,
            find_unused_parameters=True)

            model.cuda(cfg.local_rank),
            device_ids=[cfg.local_rank],
            output_device=cfg.local_rank,
            find_unused_parameters=True)

    # train and val
    for epoch in range(start_epoch, cfg.epochs + 1):
    for epoch in range(1, cfg.epochs + 1):
        train_epoch(dataset.train_data_loader, model, model_fn, optimizer, epoch)

        if utils.is_multiple(epoch, cfg.save_freq) or utils.is_power2(epoch):
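Both train_epoch and eval_epoch accumulate the (value, count) pairs returned by the model into utils.AverageMeter objects. That class is not part of this diff; a minimal sketch of the interface the loops above rely on (update(val, n) plus .val and .avg attributes) could look like this:

class AverageMeter:
    # Running average over a stream of (value, count) updates.
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val              # most recent value (printed as "loss: val(avg)")
        self.sum += val * n         # weighted by the number of points/samples
        self.count += n
        self.avg = self.sum / self.count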