mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
update stpls3d
This commit is contained in:
parent
fe9b225f40
commit
365a9d8900
@ -4,10 +4,11 @@ model:
|
||||
semantic_classes: 15
|
||||
instance_classes: 14
|
||||
sem2ins_classes: []
|
||||
semantic_only: True
|
||||
semantic_only: False
|
||||
semantic_weight: [1.0, 1.0, 44.0, 21.9, 1.8, 25.1, 31.5, 21.8, 24.0, 54.4, 114.4,
|
||||
81.2, 43.6, 9.7, 22.4]
|
||||
ignore_label: -100
|
||||
with_coords: False
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.9
|
||||
@ -20,13 +21,15 @@ model:
|
||||
scale: 3
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
max_proposal_num: 300
|
||||
pos_iou_thr: 0.5
|
||||
match_low_quality: True
|
||||
min_pos_thr: 0.1
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
min_npoint: 15
|
||||
fixed_modules: []
|
||||
|
||||
data:
|
||||
@ -56,7 +59,7 @@ data:
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 12
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
@ -64,16 +67,18 @@ dataloader:
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.002 # TODO change to 4 gpu
|
||||
lr: 0.004
|
||||
|
||||
save_cfg:
|
||||
semantic: True
|
||||
offset: True
|
||||
instance: True
|
||||
|
||||
eval_min_npoint: 10
|
||||
|
||||
fp16: False
|
||||
epochs: 20
|
||||
epochs: 108
|
||||
step_epoch: 20
|
||||
save_freq: 4
|
||||
pretrain: ''
|
||||
pretrain: './work_dirs/softgroup_stpls3d/latest.pth'
|
||||
work_dir: ''
|
||||
|
||||
79
configs/softgroup_stpls3d_backbone.yaml
Normal file
79
configs/softgroup_stpls3d_backbone.yaml
Normal file
@ -0,0 +1,79 @@
|
||||
model:
|
||||
channels: 16
|
||||
num_blocks: 7
|
||||
semantic_classes: 15
|
||||
instance_classes: 14
|
||||
sem2ins_classes: []
|
||||
semantic_only: True
|
||||
semantic_weight: [1.0, 1.0, 44.0, 21.9, 1.8, 25.1, 31.5, 21.8, 24.0, 54.4, 114.4,
|
||||
81.2, 43.6, 9.7, 22.4]
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.9
|
||||
mean_active: 3
|
||||
class_numpoint_mean: [-1., 10408., 58., 124., 1351., 162., 430., 1090., 451., 26., 43.,
|
||||
61., 39., 109., 1239]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0]
|
||||
instance_voxel_cfg:
|
||||
scale: 3
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: []
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 'stpls3d'
|
||||
data_root: 'dataset/Synthetic_v3_InstanceSegmentation'
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
repeat: 4
|
||||
voxel_cfg:
|
||||
scale: 3
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 'stpls3d'
|
||||
data_root: 'dataset/Synthetic_v3_InstanceSegmentation'
|
||||
prefix: 'val'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 3
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 12
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.002 # TODO change to 4 gpu
|
||||
|
||||
save_cfg:
|
||||
semantic: True
|
||||
offset: True
|
||||
instance: True
|
||||
|
||||
fp16: False
|
||||
epochs: 20
|
||||
step_epoch: 20
|
||||
save_freq: 4
|
||||
pretrain: ''
|
||||
work_dir: ''
|
||||
@ -3,9 +3,9 @@ from .custom import CustomDataset
|
||||
|
||||
class STPLS3DDataset(CustomDataset):
|
||||
|
||||
CLASSES = ('ground', 'building', 'low vegetation', 'medium vegetation', 'high vegetation',
|
||||
'vehicle', 'truck', 'aircraft', 'militaryVehicle', 'bike', 'motorcycle',
|
||||
'light pole', 'street sign', 'clutter', 'fence')
|
||||
CLASSES = ('building', 'low vegetation', 'med. vegetation', 'high vegetation', 'vehicle',
|
||||
'truck', 'aircraft', 'militaryVehicle', 'bike', 'motorcycle', 'light pole',
|
||||
'street sign', 'clutter', 'fence')
|
||||
|
||||
def getInstanceInfo(self, xyz, instance_label, semantic_label):
|
||||
ret = super().getInstanceInfo(xyz, instance_label, semantic_label)
|
||||
|
||||
@ -12,7 +12,7 @@ from .instance_eval_util import get_instances
|
||||
|
||||
class ScanNetEval(object):
|
||||
|
||||
def __init__(self, class_labels, iou_type=None, use_label=True):
|
||||
def __init__(self, class_labels, min_npoint=None, iou_type=None, use_label=True):
|
||||
self.valid_class_labels = class_labels
|
||||
self.valid_class_ids = np.arange(len(class_labels)) + 1
|
||||
self.id2label = {}
|
||||
@ -22,7 +22,10 @@ class ScanNetEval(object):
|
||||
self.id2label[self.valid_class_ids[i]] = self.valid_class_labels[i]
|
||||
|
||||
self.ious = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
|
||||
self.min_region_sizes = np.array([100])
|
||||
if min_npoint:
|
||||
self.min_region_sizes = np.array([min_npoint])
|
||||
else:
|
||||
self.min_region_sizes = np.array([100])
|
||||
self.distance_threshes = np.array([float('inf')])
|
||||
self.distance_confs = np.array([-float('inf')])
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ class SoftGroup(nn.Module):
|
||||
semantic_weight=None,
|
||||
sem2ins_classes=[],
|
||||
ignore_label=-100,
|
||||
with_coords=True,
|
||||
grouping_cfg=None,
|
||||
instance_voxel_cfg=None,
|
||||
train_cfg=None,
|
||||
@ -38,6 +39,7 @@ class SoftGroup(nn.Module):
|
||||
self.semantic_weight = semantic_weight
|
||||
self.sem2ins_classes = sem2ins_classes
|
||||
self.ignore_label = ignore_label
|
||||
self.with_coords = with_coords
|
||||
self.grouping_cfg = grouping_cfg
|
||||
self.instance_voxel_cfg = instance_voxel_cfg
|
||||
self.train_cfg = train_cfg
|
||||
@ -48,9 +50,10 @@ class SoftGroup(nn.Module):
|
||||
norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
|
||||
|
||||
# backbone
|
||||
in_channels = 6 if with_coords else 3
|
||||
self.input_conv = spconv.SparseSequential(
|
||||
spconv.SubMConv3d(
|
||||
6, channels, kernel_size=3, padding=1, bias=False, indice_key='subm1'))
|
||||
in_channels, channels, kernel_size=3, padding=1, bias=False, indice_key='subm1'))
|
||||
block_channels = [channels * (i + 1) for i in range(num_blocks)]
|
||||
self.unet = UBlock(block_channels, norm_fn, 2, block, indice_key_id=1)
|
||||
self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
|
||||
@ -105,7 +108,8 @@ class SoftGroup(nn.Module):
|
||||
semantic_labels, instance_labels, instance_pointnum, instance_cls,
|
||||
pt_offset_labels, spatial_shape, batch_size, **kwargs):
|
||||
losses = {}
|
||||
feats = torch.cat((feats, coords_float), 1)
|
||||
if self.with_coords:
|
||||
feats = torch.cat((feats, coords_float), 1)
|
||||
voxel_feats = voxelization(feats, p2v_map)
|
||||
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
|
||||
semantic_scores, pt_offsets, output_feats = self.forward_backbone(input, v2p_map)
|
||||
@ -175,14 +179,30 @@ class SoftGroup(nn.Module):
|
||||
fg_instance_cls = instance_cls[fg_inds]
|
||||
fg_ious_on_cluster = ious_on_cluster[:, fg_inds]
|
||||
|
||||
# assign proposal to gt idx. -1: negative, 0 -> num_gts - 1: positive
|
||||
num_proposals = fg_ious_on_cluster.size(0)
|
||||
num_gts = fg_ious_on_cluster.size(1)
|
||||
assigned_gt_inds = fg_ious_on_cluster.new_full((num_proposals, ), -1, dtype=torch.long)
|
||||
|
||||
# overlap > thr on fg instances are positive samples
|
||||
max_iou, gt_inds = fg_ious_on_cluster.max(1)
|
||||
max_iou, argmax_iou = fg_ious_on_cluster.max(1)
|
||||
pos_inds = max_iou >= self.train_cfg.pos_iou_thr
|
||||
pos_gt_inds = gt_inds[pos_inds]
|
||||
assigned_gt_inds[pos_inds] = argmax_iou[pos_inds]
|
||||
|
||||
# allow low-quality proposals with best iou to be as positive sample
|
||||
# in case pos_iou_thr is too high to achieve
|
||||
match_low_quality = getattr(self.train_cfg, 'match_low_quality', False)
|
||||
min_pos_thr = getattr(self.train_cfg, 'min_pos_thr', 0)
|
||||
if match_low_quality:
|
||||
gt_max_iou, gt_argmax_iou = fg_ious_on_cluster.max(0)
|
||||
for i in range(num_gts):
|
||||
if gt_max_iou[i] >= min_pos_thr:
|
||||
assigned_gt_inds[gt_argmax_iou[i]] = i
|
||||
|
||||
# compute cls loss. follow detection convention: 0 -> K - 1 are fg, K is bg
|
||||
labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), self.instance_classes)
|
||||
labels[pos_inds] = fg_instance_cls[pos_gt_inds]
|
||||
labels = fg_instance_cls.new_full((num_proposals, ), self.instance_classes)
|
||||
pos_inds = assigned_gt_inds >= 0
|
||||
labels[pos_inds] = fg_instance_cls[assigned_gt_inds[pos_inds]]
|
||||
cls_loss = F.cross_entropy(cls_scores, labels)
|
||||
losses['cls_loss'] = cls_loss
|
||||
|
||||
@ -227,7 +247,8 @@ class SoftGroup(nn.Module):
|
||||
def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
|
||||
semantic_labels, instance_labels, pt_offset_labels, spatial_shape, batch_size,
|
||||
scan_ids, **kwargs):
|
||||
feats = torch.cat((feats, coords_float), 1)
|
||||
if self.with_coords:
|
||||
feats = torch.cat((feats, coords_float), 1)
|
||||
voxel_feats = voxelization(feats, p2v_map)
|
||||
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
|
||||
semantic_scores, pt_offsets, output_feats = self.forward_backbone(
|
||||
|
||||
@ -116,7 +116,8 @@ def main():
|
||||
gt_insts.append(res['gt_instances'])
|
||||
if not cfg.model.semantic_only:
|
||||
logger.info('Evaluate instance segmentation')
|
||||
scannet_eval = ScanNetEval(dataset.CLASSES)
|
||||
eval_min_npoint = getattr(cfg, 'eval_min_npoint', None)
|
||||
scannet_eval = ScanNetEval(dataset.CLASSES, eval_min_npoint)
|
||||
scannet_eval.evaluate(pred_insts, gt_insts)
|
||||
logger.info('Evaluate semantic segmentation and offset MAE')
|
||||
ignore_label = cfg.model.ignore_label
|
||||
|
||||
@ -108,7 +108,8 @@ def validate(epoch, model, val_loader, cfg, logger, writer):
|
||||
all_gt_insts.append(res['gt_instances'])
|
||||
if not cfg.model.semantic_only:
|
||||
logger.info('Evaluate instance segmentation')
|
||||
scannet_eval = ScanNetEval(val_set.CLASSES)
|
||||
eval_min_npoint = getattr(cfg, 'eval_min_npoint', None)
|
||||
scannet_eval = ScanNetEval(val_set.CLASSES, eval_min_npoint)
|
||||
eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts)
|
||||
writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
|
||||
writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user