Mirror of https://github.com/botastic/SoftGroup.git (synced 2025-10-16 11:45:42 +00:00)
Support mixed-precision training
This commit is contained in:
parent: c620cfc435
commit: 3475ab88b9
@@ -66,6 +66,7 @@ optimizer:
   type: 'Adam'
   lr: 0.004
 
+fp16: False
 epochs: 128
 step_epoch: 50
 save_freq: 4
configs/softgroup_scannet_backbone_fp16.yaml (new file, +74 lines)
@@ -0,0 +1,74 @@
+model:
+  channels: 32
+  num_blocks: 7
+  semantic_classes: 20
+  instance_classes: 18
+  sem2ins_classes: []
+  semantic_only: True
+  ignore_label: -100
+  grouping_cfg:
+    score_thr: 0.2
+    radius: 0.04
+    mean_active: 300
+    class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
+                          8331., 3948., 3166., 5629., 11719.,
+                          1003., 3317., 4912., 10221., 3889.,
+                          4136., 2120., 945., 3967., 2589.]
+    npoint_thr: 0.05  # absolute if class_numpoint == -1, relative if class_numpoint != -1
+    ignore_classes: [0, 1]
+  instance_voxel_cfg:
+    scale: 50
+    spatial_shape: 20
+  train_cfg:
+    max_proposal_num: 200
+    pos_iou_thr: 0.5
+  test_cfg:
+    x4_split: False
+    cls_score_thr: 0.001
+    mask_score_thr: -0.5
+    min_npoint: 100
+  fixed_modules: []
+
+data:
+  train:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'train'
+    suffix: '_inst_nostuff.pth'
+    training: True
+    repeat: 4
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+  test:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'val'
+    suffix: '_inst_nostuff.pth'
+    training: False
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+
+dataloader:
+  train:
+    batch_size: 4
+    num_workers: 4
+  test:
+    batch_size: 1
+    num_workers: 1
+
+optimizer:
+  type: 'Adam'
+  lr: 0.004
+
+fp16: True
+epochs: 128
+step_epoch: 50
+save_freq: 4
+pretrain: ''
+work_dir: ''
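This backbone config trains only the point-wise branch (semantic_only: True) and turns on mixed precision with fp16: True. As a hedged sketch (not part of this commit; train.py uses its own config object, so yaml.safe_load here is only a stand-in), the flag can be inspected like this:

import yaml  # assumed helper; the project may wrap configs differently

with open('configs/softgroup_scannet_backbone_fp16.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['fp16'])                    # True -> autocast/GradScaler get enabled in train.py
print(cfg['model']['semantic_only'])  # True -> backbone-only (semantic + offset) training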
configs/softgroup_scannet_fp16.yaml (new file, +74 lines)
@@ -0,0 +1,74 @@
+model:
+  channels: 32
+  num_blocks: 7
+  semantic_classes: 20
+  instance_classes: 18
+  sem2ins_classes: []
+  semantic_only: False
+  ignore_label: -100
+  grouping_cfg:
+    score_thr: 0.2
+    radius: 0.04
+    mean_active: 300
+    class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
+                          8331., 3948., 3166., 5629., 11719.,
+                          1003., 3317., 4912., 10221., 3889.,
+                          4136., 2120., 945., 3967., 2589.]
+    npoint_thr: 0.05  # absolute if class_numpoint == -1, relative if class_numpoint != -1
+    ignore_classes: [0, 1]
+  instance_voxel_cfg:
+    scale: 50
+    spatial_shape: 20
+  train_cfg:
+    max_proposal_num: 200
+    pos_iou_thr: 0.5
+  test_cfg:
+    x4_split: False
+    cls_score_thr: 0.001
+    mask_score_thr: -0.5
+    min_npoint: 100
+  fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
+
+data:
+  train:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'train'
+    suffix: '_inst_nostuff.pth'
+    training: True
+    repeat: 4
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+  test:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'val'
+    suffix: '_inst_nostuff.pth'
+    training: False
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+
+dataloader:
+  train:
+    batch_size: 4
+    num_workers: 4
+  test:
+    batch_size: 1
+    num_workers: 1
+
+optimizer:
+  type: 'Adam'
+  lr: 0.004
+
+fp16: True
+epochs: 128
+step_epoch: 50
+save_freq: 4
+pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth'
+work_dir: ''
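This second config fine-tunes the full instance-segmentation model: pretrain points at a backbone checkpoint and fixed_modules lists the backbone submodules that stay frozen. A hypothetical sketch of what such freezing amounts to is below (the actual logic lives in the SoftGroup model code, not in this diff; the toy ModuleDict only illustrates the idea):

import torch

# Toy stand-in for the real model: two named submodules, one of which gets frozen.
model = torch.nn.ModuleDict({
    'unet': torch.nn.Linear(8, 8),        # pretend backbone piece
    'cls_linear': torch.nn.Linear(8, 4),  # pretend refinement head
})

fixed_modules = ['unet']                  # cf. fixed_modules in the config above
for name in fixed_modules:
    module = model[name]
    module.eval()                         # keep BatchNorm-style statistics fixed
    for param in module.parameters():
        param.requires_grad = False       # exclude frozen parts from gradient updates

print([n for n, p in model.named_parameters() if p.requires_grad])  # only 'cls_linear.*'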
@@ -70,7 +70,7 @@ class CustomDataset(Dataset):
         return x + g(x) * mag
 
-    def getInstanceInfo(self, xyz, instance_label, label):
+    def getInstanceInfo(self, xyz, instance_label, semantic_label):
         pt_mean = np.ones((xyz.shape[0], 3), dtype=np.float32) * -100.0
         instance_pointnum = []
         instance_cls = []
@@ -80,8 +80,8 @@ class CustomDataset(Dataset):
             xyz_i = xyz[inst_idx_i]
             pt_mean[inst_idx_i] = xyz_i.mean(0)
             instance_pointnum.append(inst_idx_i[0].size)
-            cls_loc = inst_idx_i[0][0]
-            instance_cls.append(label[cls_loc])
+            cls_idx = inst_idx_i[0][0]
+            instance_cls.append(semantic_label[cls_idx])
         pt_offset_label = pt_mean - xyz
         return instance_num, instance_pointnum, instance_cls, pt_offset_label
 
@@ -122,7 +122,7 @@ class CustomDataset(Dataset):
             j += 1
         return instance_label
 
-    def transform_train(self, xyz, rgb, label, instance_label):
+    def transform_train(self, xyz, rgb, semantic_label, instance_label):
         xyz_middle = self.dataAugment(xyz, True, True, True)
         xyz = xyz_middle * self.voxel_cfg.scale
         xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50)
@@ -140,17 +140,17 @@ class CustomDataset(Dataset):
         xyz = xyz[valid_idxs]
         xyz_middle = xyz_middle[valid_idxs]
         rgb = rgb[valid_idxs]
-        label = label[valid_idxs]
+        semantic_label = semantic_label[valid_idxs]
         instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
-        return xyz, xyz_middle, rgb, label, instance_label
+        return xyz, xyz_middle, rgb, semantic_label, instance_label
 
-    def transform_test(self, xyz, rgb, label, instance_label):
+    def transform_test(self, xyz, rgb, semantic_label, instance_label):
         xyz_middle = self.dataAugment(xyz, False, True, True)
         xyz = xyz_middle * self.voxel_cfg.scale
         xyz -= xyz.min(0)
         valid_idxs = np.ones(xyz.shape[0], dtype=bool)
         instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)  # TODO remove this
-        return xyz, xyz_middle, rgb, label, instance_label
+        return xyz, xyz_middle, rgb, semantic_label, instance_label
 
     def __getitem__(self, index):
         filename = self.filenames[index]
@@ -159,26 +159,26 @@ class CustomDataset(Dataset):
         data = self.transform_train(*data) if self.training else self.transform_test(*data)
         if data is None:
             return None
-        xyz, xyz_middle, rgb, label, instance_label = data
-        info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), label)
+        xyz, xyz_middle, rgb, semantic_label, instance_label = data
+        info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), semantic_label)
         inst_num, inst_pointnum, inst_cls, pt_offset_label = info
-        loc = torch.from_numpy(xyz).long()
-        loc_float = torch.from_numpy(xyz_middle)
+        coord = torch.from_numpy(xyz).long()
+        coord_float = torch.from_numpy(xyz_middle)
         feat = torch.from_numpy(rgb).float()
         if self.training:
             feat += torch.randn(3) * 0.1
-        label = torch.from_numpy(label)
+        semantic_label = torch.from_numpy(semantic_label)
         instance_label = torch.from_numpy(instance_label)
         pt_offset_label = torch.from_numpy(pt_offset_label)
-        return (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum,
-                inst_cls, pt_offset_label)
+        return (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num,
+                inst_pointnum, inst_cls, pt_offset_label)
 
     def collate_fn(self, batch):
         scan_ids = []
-        locs = []
-        locs_float = []
+        coords = []
+        coords_float = []
         feats = []
-        labels = []
+        semantic_labels = []
         instance_labels = []
 
         instance_pointnum = []  # (total_nInst), int
@@ -190,15 +190,15 @@ class CustomDataset(Dataset):
         for data in batch:
             if data is None:
                 continue
-            (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum,
-             inst_cls, pt_offset_label) = data
+            (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num,
+             inst_pointnum, inst_cls, pt_offset_label) = data
             instance_label[np.where(instance_label != -100)] += total_inst_num
             total_inst_num += inst_num
             scan_ids.append(scan_id)
-            locs.append(torch.cat([loc.new_full((loc.size(0), 1), batch_id), loc], 1))
-            locs_float.append(loc_float)
+            coords.append(torch.cat([coord.new_full((coord.size(0), 1), batch_id), coord], 1))
+            coords_float.append(coord_float)
             feats.append(feat)
-            labels.append(label)
+            semantic_labels.append(semantic_label)
             instance_labels.append(instance_label)
             instance_pointnum.extend(inst_pointnum)
             instance_cls.extend(inst_cls)
@@ -209,29 +209,29 @@ class CustomDataset(Dataset):
             self.logger.info(f'batch is truncated from size {len(batch)} to {batch_id}')
 
         # merge all the scenes in the batch
-        locs = torch.cat(locs, 0)  # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
-        batch_idxs = locs[:, 0].int()
-        locs_float = torch.cat(locs_float, 0).to(torch.float32)  # float (N, 3)
+        coords = torch.cat(coords, 0)  # long (N, 1 + 3), the batch item idx is put in coords[:, 0]
+        batch_idxs = coords[:, 0].int()
+        coords_float = torch.cat(coords_float, 0).to(torch.float32)  # float (N, 3)
         feats = torch.cat(feats, 0)  # float (N, C)
-        labels = torch.cat(labels, 0).long()  # long (N)
+        semantic_labels = torch.cat(semantic_labels, 0).long()  # long (N)
         instance_labels = torch.cat(instance_labels, 0).long()  # long (N)
         instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)
         instance_cls = torch.tensor(instance_cls, dtype=torch.long)  # long (total_nInst)
         pt_offset_labels = torch.cat(pt_offset_labels).float()
 
         spatial_shape = np.clip(
-            locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
-        voxel_locs, v2p_map, p2v_map = voxelization_idx(locs, 1)
+            coords.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
+        voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, 1)
         return {
             'scan_ids': scan_ids,
-            'locs': locs,
+            'coords': coords,
             'batch_idxs': batch_idxs,
-            'voxel_locs': voxel_locs,
+            'voxel_coords': voxel_coords,
             'p2v_map': p2v_map,
             'v2p_map': v2p_map,
-            'locs_float': locs_float,
+            'coords_float': coords_float,
             'feats': feats,
-            'labels': labels,
+            'semantic_labels': semantic_labels,
             'instance_labels': instance_labels,
             'instance_pointnum': instance_pointnum,
             'instance_cls': instance_cls,
@@ -26,21 +26,21 @@ class S3DISDataset(CustomDataset):
 
     def load(self, filename):
         # TODO make file load results consistent
-        xyz, rgb, label, instance_label, _, _ = torch.load(filename)
+        xyz, rgb, semantic_label, instance_label, _, _ = torch.load(filename)
         # subsample data
         if self.training:
             N = xyz.shape[0]
             inds = np.random.choice(N, int(N * 0.25), replace=False)
             xyz = xyz[inds]
             rgb = rgb[inds]
-            label = label[inds]
+            semantic_label = semantic_label[inds]
            instance_label = self.getCroppedInstLabel(instance_label, inds)
-        return xyz, rgb, label, instance_label
+        return xyz, rgb, semantic_label, instance_label
 
     def crop(self, xyz, step=64):
         return super().crop(xyz, step=step)
 
-    def transform_test(self, xyz, rgb, label, instance_label):
+    def transform_test(self, xyz, rgb, semantic_label, instance_label):
         # devide into 4 piecies
         inds = np.arange(xyz.shape[0])
         piece_1 = inds[::4]
@@ -64,37 +64,37 @@ class S3DISDataset(CustomDataset):
         rgb = np.concatenate(rgb_list, 0)
         valid_idxs = np.ones(xyz.shape[0], dtype=bool)
         instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)  # TODO remove this
-        return xyz, xyz_middle, rgb, label, instance_label
+        return xyz, xyz_middle, rgb, semantic_label, instance_label
 
     def collate_fn(self, batch):
         if self.training:
             return super().collate_fn(batch)
 
         # assume 1 scan only
-        (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum, inst_cls,
-         pt_offset_label) = batch[0]
+        (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num, inst_pointnum,
+         inst_cls, pt_offset_label) = batch[0]
         scan_ids = [scan_id]
-        locs = loc.long()
-        batch_idxs = torch.zeros_like(loc[:, 0].int())
-        locs_float = loc_float.float()
+        coords = coord.long()
+        batch_idxs = torch.zeros_like(coord[:, 0].int())
+        coords_float = coord_float.float()
         feats = feat.float()
-        labels = label.long()
+        semantic_labels = semantic_label.long()
         instance_labels = instance_label.long()
         instance_pointnum = torch.tensor([inst_pointnum], dtype=torch.int)
         instance_cls = torch.tensor([inst_cls], dtype=torch.long)
         pt_offset_labels = pt_offset_label.float()
-        spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
+        spatial_shape = np.clip((coords.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
                                 None)
-        voxel_locs, v2p_map, p2v_map = voxelization_idx(locs, 4)
+        voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, 4)
         return {
             'scan_ids': scan_ids,
             'batch_idxs': batch_idxs,
-            'voxel_locs': voxel_locs,
+            'voxel_coords': voxel_coords,
             'p2v_map': p2v_map,
             'v2p_map': v2p_map,
-            'locs_float': locs_float,
+            'coords_float': coords_float,
             'feats': feats,
-            'labels': labels,
+            'semantic_labels': semantic_labels,
             'instance_labels': instance_labels,
             'instance_pointnum': instance_pointnum,
             'instance_cls': instance_cls,
@@ -7,8 +7,8 @@ class ScanNetDataset(CustomDataset):
                'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
                'bathtub', 'otherfurniture')
 
-    def getInstanceInfo(self, xyz, instance_label, label):
-        ret = super().getInstanceInfo(xyz, instance_label, label)
+    def getInstanceInfo(self, xyz, instance_label, semantic_label):
+        ret = super().getInstanceInfo(xyz, instance_label, semantic_label)
         instance_num, instance_pointnum, instance_cls, pt_offset_label = ret
         instance_cls = [x - 2 if x != -100 else x for x in instance_cls]
         return instance_num, instance_pointnum, instance_cls, pt_offset_label
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from ..lib.softgroup_ops import (ballquery_batch_p, bfs_cluster, get_mask_iou_on_cluster,
                                  get_mask_iou_on_pred, get_mask_label, global_avg_pool, sec_max,
                                  sec_min, voxelization, voxelization_idx)
+from ..util import force_fp32
 from .blocks import MLP, ResidualBlock, UBlock
 
 
@@ -80,25 +81,13 @@ class SoftGroup(nn.Module):
 
     def forward(self, batch, return_loss=False):
         if return_loss:
-            return self.forward_train(batch)
+            return self.forward_train(**batch)
         else:
-            return self.forward_test(batch)
+            return self.forward_test(**batch)
 
-    def forward_train(self, batch):
-        batch_idxs = batch['batch_idxs'].cuda()
-        voxel_coords = batch['voxel_locs'].cuda()
-        p2v_map = batch['p2v_map'].cuda()
-        v2p_map = batch['v2p_map'].cuda()
-        coords_float = batch['locs_float'].cuda()
-        feats = batch['feats'].cuda()
-        semantic_labels = batch['labels'].cuda()
-        instance_labels = batch['instance_labels'].cuda()
-        instance_pointnum = batch['instance_pointnum'].cuda()
-        instance_cls = batch['instance_cls'].cuda()
-        pt_offset_labels = batch['pt_offset_labels'].cuda()
-        spatial_shape = batch['spatial_shape']
-        batch_size = batch['batch_size']
-
+    def forward_train(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
+                      semantic_labels, instance_labels, instance_pointnum, instance_cls,
+                      pt_offset_labels, spatial_shape, batch_size, **kwargs):
         losses = {}
         feats = torch.cat((feats, coords_float), 1)
         voxel_feats = voxelization(feats, p2v_map)
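With this change, forward unpacks the collated batch dict straight into keyword arguments, so the renamed dataset keys ('voxel_coords', 'coords_float', 'semantic_labels', ...) must match the parameter names of forward_train/forward_test, while **kwargs silently absorbs keys a given path does not need (for example 'scan_ids' or 'coords' during training). A minimal illustration of the pattern (toy values, not SoftGroup code):

def forward_train(batch_idxs, feats, semantic_labels, **kwargs):
    # extra keys such as 'scan_ids' simply land in kwargs
    return len(feats), sorted(kwargs)

batch = {'batch_idxs': [0, 0, 1], 'feats': [1.0, 2.0, 3.0],
         'semantic_labels': [2, 3, 5], 'scan_ids': ['scene0000_00']}
print(forward_train(**batch))  # (3, ['scan_ids'])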
@@ -155,6 +144,7 @@ class SoftGroup(nn.Module):
         losses['offset_loss'] = (offset_loss, pos_inds.sum())
         return losses
 
+    @force_fp32(apply_to=('cls_scores', 'mask_scores', 'iou_scores'))
     def instance_loss(self, cls_scores, mask_scores, iou_scores, proposals_idx, proposals_offset,
                       instance_labels, instance_pointnum, instance_cls, instance_batch_idxs):
         losses = {}
@@ -208,18 +198,9 @@ class SoftGroup(nn.Module):
         losses['iou_score_loss'] = (iou_score_loss, iou_score_weight.sum())
         return losses
 
-    def forward_test(self, batch):
-        batch_idxs = batch['batch_idxs'].cuda()
-        voxel_coords = batch['voxel_locs'].cuda()
-        p2v_map = batch['p2v_map'].cuda()
-        v2p_map = batch['v2p_map'].cuda()
-        coords_float = batch['locs_float'].cuda()
-        feats = batch['feats'].cuda()
-        labels = batch['labels'].cuda()
-        instance_labels = batch['instance_labels'].cuda()
-        spatial_shape = batch['spatial_shape']
-        batch_size = batch['batch_size']
-
+    def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
+                     semantic_labels, instance_labels, spatial_shape, batch_size, scan_ids,
+                     **kwargs):
         feats = torch.cat((feats, coords_float), 1)
         voxel_feats = voxelization(feats, p2v_map)
         input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
@@ -227,7 +208,8 @@ class SoftGroup(nn.Module):
             input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
         semantic_preds = semantic_scores.max(1)[1]
         ret = dict(
-            semantic_preds=semantic_preds.cpu().numpy(), semantic_labels=labels.cpu().numpy())
+            semantic_preds=semantic_preds.cpu().numpy(),
+            semantic_labels=semantic_labels.cpu().numpy())
         if not self.semantic_only:
             proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
                                                                     batch_idxs, coords_float,
@@ -236,10 +218,9 @@ class SoftGroup(nn.Module):
                                                             output_feats, coords_float,
                                                             **self.instance_voxel_cfg)
             _, cls_scores, iou_scores, mask_scores = self.forward_instance(inst_feats, inst_map)
-            pred_instances = self.get_instances(batch['scan_ids'][0], proposals_idx,
-                                                semantic_scores, cls_scores, iou_scores,
-                                                mask_scores)
-            gt_instances = self.get_gt_instances(labels, instance_labels)
+            pred_instances = self.get_instances(scan_ids[0], proposals_idx, semantic_scores,
+                                                cls_scores, iou_scores, mask_scores)
+            gt_instances = self.get_gt_instances(semantic_labels, instance_labels)
             ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances))
         return ret
 
@@ -289,6 +270,7 @@ class SoftGroup(nn.Module):
             x_new[p] = x_split[i]
         return x_new
 
+    @force_fp32(apply_to=('semantic_scores, pt_offsets'))
     def forward_grouping(self,
                          semantic_scores,
                          pt_offsets,
@@ -350,6 +332,7 @@ class SoftGroup(nn.Module):
 
         return instance_batch_idxs, cls_scores, iou_scores, mask_scores
 
+    @force_fp32(apply_to=('semantic_scores', 'cls_scores', 'iou_scores', 'mask_scores'))
     def get_instances(self, scan_id, proposals_idx, semantic_scores, cls_scores, iou_scores,
                       mask_scores):
         num_instances = cls_scores.size(0)
@@ -402,19 +385,21 @@ class SoftGroup(nn.Module):
             instances.append(pred)
         return instances
 
-    def get_gt_instances(self, labels, instance_labels):
+    def get_gt_instances(self, semantic_labels, instance_labels):
         """Get gt instances for evaluation."""
         # convert to evaluation format 0: ignore, 1->N: valid
         label_shift = self.semantic_classes - self.instance_classes
-        labels = labels - label_shift + 1
-        labels[labels < 0] = 0
+        semantic_labels = semantic_labels - label_shift + 1
+        semantic_labels[semantic_labels < 0] = 0
         instance_labels += 1
         ignore_inds = instance_labels < 0
-        gt_ins = labels * 1000 + instance_labels
+        # scannet encoding rule
+        gt_ins = semantic_labels * 1000 + instance_labels
         gt_ins[ignore_inds] = 0
         gt_ins = gt_ins.cpu().numpy()
         return gt_ins
 
+    @force_fp32(apply_to='feats')
     def clusters_voxelization(self,
                               clusters_idx,
                               clusters_offset,
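The added comment refers to the ScanNet benchmark encoding, in which one integer packs the (shifted) semantic label and the 1-based instance id. A tiny worked example with illustrative values:

semantic_label, instance_id = 5, 3   # values after the label shift / +1 above
gt_id = semantic_label * 1000 + instance_id
print(gt_id)  # 5003 -> semantic class 5, instance 3; ignored points stay 0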
@@ -466,6 +451,7 @@ class SoftGroup(nn.Module):
         assert batch_offsets[-1] == batch_idxs.shape[0]
         return batch_offsets
 
+    @force_fp32(apply_to=('x'))
     def global_pool(self, x, expand=False):
         indices = x.indices[:, 0]
         batch_counts = torch.bincount(indices)
@@ -1,4 +1,5 @@
 from .dist import get_dist_info, init_dist
+from .fp16 import force_fp32
 from .logger import get_root_logger
 from .optim import build_optimizer
 from .utils import *
softgroup/util/fp16.py (new file, +66 lines)
@@ -0,0 +1,66 @@
+# Simplfied from mmcv.
+# Directly use torch.cuda.amp.autocast for mix-precision and support sparse tensor
+import functools
+from collections import abc
+from inspect import getfullargspec
+
+import spconv.pytorch as spconv
+import torch
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+    if isinstance(inputs, torch.Tensor):
+        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
+    elif isinstance(inputs, spconv.SparseConvTensor):
+        if inputs.features.dtype == src_type:
+            features = inputs.features.to(dst_type)
+            inputs = inputs.replace_feature(features)
+        return inputs
+    elif isinstance(inputs, abc.Mapping):
+        return type(inputs)({k: cast_tensor_type(v, src_type, dst_type) for k, v in inputs.items()})
+    elif isinstance(inputs, abc.Iterable):
+        return type(inputs)(cast_tensor_type(item, src_type, dst_type) for item in inputs)
+    else:
+        return inputs
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+
+    def force_fp32_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@force_fp32 can only be used to decorate the '
+                                'method of nn.Module')
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be casted
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(cast_tensor_type(args[i], torch.half, torch.float))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = dict()
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(arg_value, torch.half, torch.float)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            with torch.cuda.amp.autocast(enabled=False):
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp16:
+                output = cast_tensor_type(output, torch.float, torch.half)
+            return output
+
+        return new_func
+
+    return force_fp32_wrapper
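force_fp32 casts the selected arguments from half back to float and runs the wrapped method with autocast disabled, which is how the decorated loss/score computations above stay in fp32 while the sparse backbone runs in half precision; the SparseConvTensor branch casts only the tensor's features. A hedged usage sketch on a toy module (not from the repo; the import path follows the util/__init__.py change above, and a CUDA device is assumed):

import torch
from softgroup.util import force_fp32


class ToyHead(torch.nn.Module):

    @force_fp32(apply_to=('scores',))
    def loss(self, scores, targets):
        # autocast is disabled in here and `scores` arrives as float32
        return torch.nn.functional.mse_loss(scores, targets)


head = ToyHead().cuda()
targets = torch.rand(8, device='cuda')
with torch.cuda.amp.autocast(enabled=True):
    scores = torch.rand(8, device='cuda', dtype=torch.half)
    loss = head.loss(scores, targets)  # 'scores' is cast half -> float by the decorator
print(loss.dtype)  # torch.float32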
train.py (15 lines changed)
@@ -38,9 +38,9 @@ if __name__ == '__main__':
         init_dist()
 
     # work_dir & logger
-    if args.work_dir is not None:
+    if args.work_dir:
         cfg.work_dir = args.work_dir
-    elif cfg.get('work_dir', None) is None:
+    else:
         cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
     os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
@@ -48,6 +48,7 @@ if __name__ == '__main__':
     logger = get_root_logger(log_file=log_file)
     logger.info(f'Config:\n{cfg_txt}')
     logger.info(f'Distributed: {args.dist}')
+    logger.info(f'Mix precision training: {cfg.fp16}')
     shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
     writer = SummaryWriter(cfg.work_dir)
 
@@ -55,6 +56,7 @@ if __name__ == '__main__':
     model = SoftGroup(**cfg.model).cuda()
     if args.dist:
         model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
+    scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16)
 
     # data
     train_set = build_dataset(cfg.data.train, logger)
@@ -91,7 +93,9 @@ if __name__ == '__main__':
             data_time.update(time.time() - end)
 
             cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
-            loss, log_vars = model(batch, return_loss=True)
+
+            with torch.cuda.amp.autocast(enabled=cfg.fp16):
+                loss, log_vars = model(batch, return_loss=True)
 
             # meter_dict
             for k, v in log_vars.items():
@@ -101,8 +105,9 @@ if __name__ == '__main__':
 
             # backward
             optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
 
             # time and print
             current_iter = (epoch - 1) * len(train_loader) + i
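The training loop now follows the standard torch.cuda.amp recipe gated by cfg.fp16: autocast wraps the forward pass, and a GradScaler scales the loss before backward, unscales the gradients for the optimizer step, and skips steps whose gradients overflow. A self-contained sketch of that pattern with a toy model (illustrative names, not train.py; assumes a CUDA device):

import torch

fp16 = True
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.004)
scaler = torch.cuda.amp.GradScaler(enabled=fp16)

for step in range(10):
    x = torch.randn(8, 16, device='cuda')
    target = torch.randn(8, 4, device='cuda')

    with torch.cuda.amp.autocast(enabled=fp16):  # forward in mixed precision
        loss = torch.nn.functional.mse_loss(model(x), target)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration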