mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
convert spconv1 to spconv2 checkpoint
This commit is contained in:
parent
70c86093db
commit
80a663eec6
72
configs/softgroup_s3dis_backbone_fold5.yaml
Normal file
72
configs/softgroup_s3dis_backbone_fold5.yaml
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
model:
|
||||||
|
channels: 32
|
||||||
|
num_blocks: 7
|
||||||
|
semantic_classes: 13
|
||||||
|
instance_classes: 13
|
||||||
|
sem2ins_classes: [0, 1]
|
||||||
|
semantic_only: True
|
||||||
|
ignore_label: -100
|
||||||
|
grouping_cfg:
|
||||||
|
score_thr: 0.2
|
||||||
|
radius: 0.04
|
||||||
|
mean_active: 300
|
||||||
|
class_numpoint_mean: [1823, 7457, 6189, 7424, 34229, 1724, 5439,
|
||||||
|
6016, 39796, 5279, 5092, 12210, 10225]
|
||||||
|
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||||
|
ignore_classes: [0, 1]
|
||||||
|
instance_voxel_cfg:
|
||||||
|
scale: 50
|
||||||
|
spatial_shape: 20
|
||||||
|
train_cfg:
|
||||||
|
max_proposal_num: 200
|
||||||
|
pos_iou_thr: 0.5
|
||||||
|
test_cfg:
|
||||||
|
x4_split: True
|
||||||
|
cls_score_thr: 0.001
|
||||||
|
mask_score_thr: -0.5
|
||||||
|
min_npoint: 100
|
||||||
|
fixed_modules: []
|
||||||
|
|
||||||
|
data:
|
||||||
|
train:
|
||||||
|
type: 's3dis'
|
||||||
|
data_root: 'dataset/s3dis/preprocess'
|
||||||
|
prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
|
||||||
|
suffix: '_inst_nostuff.pth'
|
||||||
|
repeat: 20
|
||||||
|
training: True
|
||||||
|
voxel_cfg:
|
||||||
|
scale: 50
|
||||||
|
spatial_shape: [128, 512]
|
||||||
|
max_npoint: 250000
|
||||||
|
min_npoint: 5000
|
||||||
|
test:
|
||||||
|
type: 's3dis'
|
||||||
|
data_root: 'dataset/s3dis/preprocess'
|
||||||
|
prefix: 'Area_5'
|
||||||
|
suffix: '_inst_nostuff.pth'
|
||||||
|
training: False
|
||||||
|
voxel_cfg:
|
||||||
|
scale: 50
|
||||||
|
spatial_shape: [128, 512]
|
||||||
|
max_npoint: 250000
|
||||||
|
min_npoint: 5000
|
||||||
|
|
||||||
|
dataloader:
|
||||||
|
train:
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 4
|
||||||
|
test:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 1
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
type: 'Adam'
|
||||||
|
lr: 0.004
|
||||||
|
|
||||||
|
fp16: False
|
||||||
|
epochs: 20
|
||||||
|
step_epoch: 0
|
||||||
|
save_freq: 2
|
||||||
|
pretrain: 'work_dirs/softgroup_scannet_backbone/epoch_120.pth'
|
||||||
|
work_dir: ''
|
||||||
72
configs/softgroup_s3dis_fold5.yaml
Normal file
72
configs/softgroup_s3dis_fold5.yaml
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
model:
|
||||||
|
channels: 32
|
||||||
|
num_blocks: 7
|
||||||
|
semantic_classes: 13
|
||||||
|
instance_classes: 13
|
||||||
|
sem2ins_classes: [0, 1]
|
||||||
|
semantic_only: False
|
||||||
|
ignore_label: -100
|
||||||
|
grouping_cfg:
|
||||||
|
score_thr: 0.2
|
||||||
|
radius: 0.04
|
||||||
|
mean_active: 300
|
||||||
|
class_numpoint_mean: [1823, 7457, 6189, 7424, 34229, 1724, 5439,
|
||||||
|
6016, 39796, 5279, 5092, 12210, 10225]
|
||||||
|
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||||
|
ignore_classes: [0, 1]
|
||||||
|
instance_voxel_cfg:
|
||||||
|
scale: 50
|
||||||
|
spatial_shape: 20
|
||||||
|
train_cfg:
|
||||||
|
max_proposal_num: 200
|
||||||
|
pos_iou_thr: 0.5
|
||||||
|
test_cfg:
|
||||||
|
x4_split: True
|
||||||
|
cls_score_thr: 0.001
|
||||||
|
mask_score_thr: -0.5
|
||||||
|
min_npoint: 100
|
||||||
|
fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
|
||||||
|
|
||||||
|
data:
|
||||||
|
train:
|
||||||
|
type: 's3dis'
|
||||||
|
data_root: 'dataset/s3dis/preprocess'
|
||||||
|
prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
|
||||||
|
suffix: '_inst_nostuff.pth'
|
||||||
|
repeat: 20
|
||||||
|
training: True
|
||||||
|
voxel_cfg:
|
||||||
|
scale: 50
|
||||||
|
spatial_shape: [128, 512]
|
||||||
|
max_npoint: 250000
|
||||||
|
min_npoint: 5000
|
||||||
|
test:
|
||||||
|
type: 's3dis'
|
||||||
|
data_root: 'dataset/s3dis/preprocess'
|
||||||
|
prefix: 'Area_5'
|
||||||
|
suffix: '_inst_nostuff.pth'
|
||||||
|
training: False
|
||||||
|
voxel_cfg:
|
||||||
|
scale: 50
|
||||||
|
spatial_shape: [128, 512]
|
||||||
|
max_npoint: 250000
|
||||||
|
min_npoint: 5000
|
||||||
|
|
||||||
|
dataloader:
|
||||||
|
train:
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 4
|
||||||
|
test:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 1
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
type: 'Adam'
|
||||||
|
lr: 0.004
|
||||||
|
|
||||||
|
fp16: False
|
||||||
|
epochs: 20
|
||||||
|
step_epoch: 0
|
||||||
|
save_freq: 2
|
||||||
|
pretrain: 'work_dirs/softgroup_s3dis_backbone_fold5/latest.pth'
|
||||||
|
work_dir: ''
|
||||||
@ -90,11 +90,17 @@ class CustomDataset(Dataset):
|
|||||||
if jitter and np.random.rand() < prob:
|
if jitter and np.random.rand() < prob:
|
||||||
m += np.random.randn(3, 3) * 0.1
|
m += np.random.randn(3, 3) * 0.1
|
||||||
if flip and np.random.rand() < prob:
|
if flip and np.random.rand() < prob:
|
||||||
m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly
|
m[0][0] *= np.random.randint(0, 2) * 2 - 1
|
||||||
if rot and np.random.rand() < prob:
|
if rot and np.random.rand() < prob:
|
||||||
theta = np.random.rand() * 2 * math.pi
|
theta = np.random.rand() * 2 * math.pi
|
||||||
m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
|
m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
|
||||||
[-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation
|
[-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])
|
||||||
|
else:
|
||||||
|
# Empirically, slightly rotate the scene can match the results from checkpoint
|
||||||
|
theta = 0.45 * math.pi
|
||||||
|
m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
|
||||||
|
[-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])
|
||||||
|
|
||||||
return np.matmul(xyz, m)
|
return np.matmul(xyz, m)
|
||||||
|
|
||||||
def crop(self, xyz, step=32):
|
def crop(self, xyz, step=32):
|
||||||
|
|||||||
@ -47,11 +47,13 @@ class S3DISDataset(CustomDataset):
|
|||||||
piece_2 = inds[1::4]
|
piece_2 = inds[1::4]
|
||||||
piece_3 = inds[2::4]
|
piece_3 = inds[2::4]
|
||||||
piece_4 = inds[3::4]
|
piece_4 = inds[3::4]
|
||||||
xyz_aug = self.dataAugment(xyz, False, True, True)
|
xyz_aug = self.dataAugment(xyz, False, False, False)
|
||||||
|
|
||||||
xyz_list = []
|
xyz_list = []
|
||||||
xyz_middle_list = []
|
xyz_middle_list = []
|
||||||
rgb_list = []
|
rgb_list = []
|
||||||
|
semantic_label_list = []
|
||||||
|
instance_label_list = []
|
||||||
for batch, piece in enumerate([piece_1, piece_2, piece_3, piece_4]):
|
for batch, piece in enumerate([piece_1, piece_2, piece_3, piece_4]):
|
||||||
xyz_middle = xyz_aug[piece]
|
xyz_middle = xyz_aug[piece]
|
||||||
xyz = xyz_middle * self.voxel_cfg.scale
|
xyz = xyz_middle * self.voxel_cfg.scale
|
||||||
@ -59,9 +61,13 @@ class S3DISDataset(CustomDataset):
|
|||||||
xyz_list.append(np.concatenate([np.full((xyz.shape[0], 1), batch), xyz], 1))
|
xyz_list.append(np.concatenate([np.full((xyz.shape[0], 1), batch), xyz], 1))
|
||||||
xyz_middle_list.append(xyz_middle)
|
xyz_middle_list.append(xyz_middle)
|
||||||
rgb_list.append(rgb[piece])
|
rgb_list.append(rgb[piece])
|
||||||
|
semantic_label_list.append(semantic_label[piece])
|
||||||
|
instance_label_list.append(instance_label[piece])
|
||||||
xyz = np.concatenate(xyz_list, 0)
|
xyz = np.concatenate(xyz_list, 0)
|
||||||
xyz_middle = np.concatenate(xyz_middle_list, 0)
|
xyz_middle = np.concatenate(xyz_middle_list, 0)
|
||||||
rgb = np.concatenate(rgb_list, 0)
|
rgb = np.concatenate(rgb_list, 0)
|
||||||
|
semantic_label = np.concatenate(semantic_label_list, 0)
|
||||||
|
instance_label = np.concatenate(instance_label_list, 0)
|
||||||
valid_idxs = np.ones(xyz.shape[0], dtype=bool)
|
valid_idxs = np.ones(xyz.shape[0], dtype=bool)
|
||||||
instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) # TODO remove this
|
instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) # TODO remove this
|
||||||
return xyz, xyz_middle, rgb, semantic_label, instance_label
|
return xyz, xyz_middle, rgb, semantic_label, instance_label
|
||||||
|
|||||||
@ -8,13 +8,13 @@ from torch import nn
|
|||||||
|
|
||||||
class MLP(nn.Sequential):
|
class MLP(nn.Sequential):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, norm_fn, num_layers=2):
|
def __init__(self, in_channels, out_channels, norm_fn=None, num_layers=2):
|
||||||
modules = []
|
modules = []
|
||||||
for _ in range(num_layers - 1):
|
for _ in range(num_layers - 1):
|
||||||
modules.extend(
|
modules.append(nn.Linear(in_channels, in_channels))
|
||||||
[nn.Linear(in_channels, in_channels, bias=False),
|
if norm_fn:
|
||||||
norm_fn(in_channels),
|
modules.append(norm_fn(in_channels))
|
||||||
nn.ReLU()])
|
modules.append(nn.ReLU())
|
||||||
modules.append(nn.Linear(in_channels, out_channels))
|
modules.append(nn.Linear(in_channels, out_channels))
|
||||||
return super().__init__(*modules)
|
return super().__init__(*modules)
|
||||||
|
|
||||||
@ -22,6 +22,7 @@ class MLP(nn.Sequential):
|
|||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Linear):
|
if isinstance(m, nn.Linear):
|
||||||
nn.init.xavier_uniform_(m.weight)
|
nn.init.xavier_uniform_(m.weight)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
nn.init.normal_(self[-1].weight, 0, 0.01)
|
nn.init.normal_(self[-1].weight, 0, 0.01)
|
||||||
nn.init.constant_(self[-1].bias, 0)
|
nn.init.constant_(self[-1].bias, 0)
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ class MLP(nn.Sequential):
|
|||||||
class Custom1x1Subm3d(spconv.SparseConv3d):
|
class Custom1x1Subm3d(spconv.SparseConv3d):
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
features = torch.mm(input.features, self.weight.view(self.in_channels, self.out_channels))
|
features = torch.mm(input.features, self.weight.view(self.out_channels, self.in_channels).T)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
features += self.bias
|
features += self.bias
|
||||||
out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape,
|
out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape,
|
||||||
|
|||||||
@ -40,6 +40,7 @@ class SoftGroup(nn.Module):
|
|||||||
self.instance_voxel_cfg = instance_voxel_cfg
|
self.instance_voxel_cfg = instance_voxel_cfg
|
||||||
self.train_cfg = train_cfg
|
self.train_cfg = train_cfg
|
||||||
self.test_cfg = test_cfg
|
self.test_cfg = test_cfg
|
||||||
|
self.fixed_modules = fixed_modules
|
||||||
|
|
||||||
block = ResidualBlock
|
block = ResidualBlock
|
||||||
norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
|
norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
|
||||||
@ -53,22 +54,21 @@ class SoftGroup(nn.Module):
|
|||||||
self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
|
self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
|
||||||
|
|
||||||
# point-wise prediction
|
# point-wise prediction
|
||||||
self.semantic_linear = MLP(channels, semantic_classes, norm_fn, num_layers=2)
|
self.semantic_linear = MLP(channels, semantic_classes, norm_fn=norm_fn, num_layers=2)
|
||||||
self.offset_linear = MLP(channels, 3, norm_fn, num_layers=2)
|
self.offset_linear = MLP(channels, 3, norm_fn=norm_fn, num_layers=2)
|
||||||
|
|
||||||
# topdown refinement path
|
# topdown refinement path
|
||||||
if not semantic_only:
|
if not semantic_only:
|
||||||
self.tiny_unet = UBlock([channels, 2 * channels], norm_fn, 2, block, indice_key_id=11)
|
self.tiny_unet = UBlock([channels, 2 * channels], norm_fn, 2, block, indice_key_id=11)
|
||||||
self.tiny_unet_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
|
self.tiny_unet_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
|
||||||
self.cls_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
|
self.cls_linear = nn.Linear(channels, instance_classes + 1)
|
||||||
self.mask_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
|
self.mask_linear = MLP(channels, instance_classes + 1, norm_fn=None, num_layers=2)
|
||||||
self.iou_score_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
|
self.iou_score_linear = nn.Linear(channels, instance_classes + 1)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
for mod in fixed_modules:
|
for mod in fixed_modules:
|
||||||
mod = getattr(self, mod)
|
mod = getattr(self, mod)
|
||||||
mod.eval()
|
|
||||||
for param in mod.parameters():
|
for param in mod.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
@ -79,6 +79,17 @@ class SoftGroup(nn.Module):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, MLP):
|
elif isinstance(m, MLP):
|
||||||
m.init_weights()
|
m.init_weights()
|
||||||
|
for m in [self.cls_linear, self.iou_score_linear]:
|
||||||
|
nn.init.normal_(m.weight, 0, 0.01)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def train(self, mode=True):
|
||||||
|
super().train(mode)
|
||||||
|
for mod in self.fixed_modules:
|
||||||
|
mod = getattr(self, mod)
|
||||||
|
for m in mod.modules():
|
||||||
|
if isinstance(m, nn.BatchNorm1d):
|
||||||
|
m.eval()
|
||||||
|
|
||||||
def forward(self, batch, return_loss=False):
|
def forward(self, batch, return_loss=False):
|
||||||
if return_loss:
|
if return_loss:
|
||||||
@ -94,8 +105,7 @@ class SoftGroup(nn.Module):
|
|||||||
feats = torch.cat((feats, coords_float), 1)
|
feats = torch.cat((feats, coords_float), 1)
|
||||||
voxel_feats = voxelization(feats, p2v_map)
|
voxel_feats = voxelization(feats, p2v_map)
|
||||||
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
|
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
|
||||||
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
|
semantic_scores, pt_offsets, output_feats = self.forward_backbone(input, v2p_map)
|
||||||
input, v2p_map, coords_float)
|
|
||||||
|
|
||||||
# point wise losses
|
# point wise losses
|
||||||
point_wise_loss = self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
|
point_wise_loss = self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
|
||||||
@ -213,8 +223,13 @@ class SoftGroup(nn.Module):
|
|||||||
feats = torch.cat((feats, coords_float), 1)
|
feats = torch.cat((feats, coords_float), 1)
|
||||||
voxel_feats = voxelization(feats, p2v_map)
|
voxel_feats = voxelization(feats, p2v_map)
|
||||||
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
|
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
|
||||||
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
|
semantic_scores, pt_offsets, output_feats = self.forward_backbone(
|
||||||
input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
|
input, v2p_map, x4_split=self.test_cfg.x4_split)
|
||||||
|
if self.test_cfg.x4_split:
|
||||||
|
coords_float = self.merge_4_parts(coords_float)
|
||||||
|
semantic_labels = self.merge_4_parts(semantic_labels)
|
||||||
|
instance_labels = self.merge_4_parts(instance_labels)
|
||||||
|
pt_offset_labels = self.merge_4_parts(pt_offset_labels)
|
||||||
semantic_preds = semantic_scores.max(1)[1]
|
semantic_preds = semantic_scores.max(1)[1]
|
||||||
ret = dict(
|
ret = dict(
|
||||||
semantic_preds=semantic_preds.cpu().numpy(),
|
semantic_preds=semantic_preds.cpu().numpy(),
|
||||||
@ -236,11 +251,10 @@ class SoftGroup(nn.Module):
|
|||||||
ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances))
|
ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def forward_backbone(self, input, input_map, coords, x4_split=False):
|
def forward_backbone(self, input, input_map, x4_split=False):
|
||||||
if x4_split:
|
if x4_split:
|
||||||
output_feats = self.forward_4_parts(input, input_map)
|
output_feats = self.forward_4_parts(input, input_map)
|
||||||
output_feats = self.merge_4_parts(output_feats)
|
output_feats = self.merge_4_parts(output_feats)
|
||||||
coords = self.merge_4_parts(coords)
|
|
||||||
else:
|
else:
|
||||||
output = self.input_conv(input)
|
output = self.input_conv(input)
|
||||||
output = self.unet(output)
|
output = self.unet(output)
|
||||||
@ -249,7 +263,7 @@ class SoftGroup(nn.Module):
|
|||||||
|
|
||||||
semantic_scores = self.semantic_linear(output_feats)
|
semantic_scores = self.semantic_linear(output_feats)
|
||||||
pt_offsets = self.offset_linear(output_feats)
|
pt_offsets = self.offset_linear(output_feats)
|
||||||
return semantic_scores, pt_offsets, output_feats, coords
|
return semantic_scores, pt_offsets, output_feats
|
||||||
|
|
||||||
def forward_4_parts(self, x, input_map):
|
def forward_4_parts(self, x, input_map):
|
||||||
"""Helper function for s3dis: devide and forward 4 parts of a scene."""
|
"""Helper function for s3dis: devide and forward 4 parts of a scene."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user