diff --git a/.gitignore b/.gitignore index c5d7b1c..e616337 100644 --- a/.gitignore +++ b/.gitignore @@ -76,3 +76,7 @@ dataset/s3dis/preprocess dataset/s3dis/val_gt dataset/s3dis/preprocess_sample dataset/s3dis/Stanford3dDataset_v1.2 + +dataset/stpls3d/train +dataset/stpls3d/val +dataset/stpls3d/Synthetic_v3_InstanceSegmentation diff --git a/configs/softgroup_stpls3d.yaml b/configs/softgroup_stpls3d.yaml new file mode 100644 index 0000000..e832bee --- /dev/null +++ b/configs/softgroup_stpls3d.yaml @@ -0,0 +1,84 @@ +model: + channels: 16 + num_blocks: 7 + semantic_classes: 15 + instance_classes: 14 + sem2ins_classes: [] + semantic_only: False + semantic_weight: [1.0, 1.0, 44.0, 21.9, 1.8, 25.1, 31.5, 21.8, 24.0, 54.4, 114.4, + 81.2, 43.6, 9.7, 22.4] + ignore_label: -100 + with_coords: False + grouping_cfg: + score_thr: 0.2 + radius: 0.9 + mean_active: 3 + class_numpoint_mean: [-1., 10408., 58., 124., 1351., 162., 430., 1090., 451., 26., 43., + 61., 39., 109., 1239] + npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1 + ignore_classes: [0] + instance_voxel_cfg: + scale: 3 + spatial_shape: 20 + train_cfg: + max_proposal_num: 300 + pos_iou_thr: 0.5 + match_low_quality: True + min_pos_thr: 0.1 + test_cfg: + x4_split: False + cls_score_thr: 0.001 + mask_score_thr: -0.5 + min_npoint: 15 + fixed_modules: [] + +data: + train: + type: 'stpls3d' + data_root: 'dataset/stpls3d' + prefix: 'train' + suffix: '_inst_nostuff.pth' + training: True + repeat: 4 + voxel_cfg: + scale: 3 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + test: + type: 'stpls3d' + data_root: 'dataset/stpls3d' + prefix: 'val' + suffix: '_inst_nostuff.pth' + training: False + voxel_cfg: + scale: 3 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + +dataloader: + train: + batch_size: 4 + num_workers: 4 + test: + batch_size: 1 + num_workers: 1 + +optimizer: + type: 'Adam' + lr: 0.004 + +save_cfg: + semantic: True + offset: True + 
instance: True + +eval_min_npoint: 10 + +fp16: False +epochs: 108 +step_epoch: 20 +save_freq: 4 +pretrain: './work_dirs/softgroup_stpls3d_backbone/latest.pth' +work_dir: '' diff --git a/configs/softgroup_stpls3d_backbone.yaml b/configs/softgroup_stpls3d_backbone.yaml new file mode 100644 index 0000000..f0a92b8 --- /dev/null +++ b/configs/softgroup_stpls3d_backbone.yaml @@ -0,0 +1,80 @@ +model: + channels: 16 + num_blocks: 7 + semantic_classes: 15 + instance_classes: 14 + sem2ins_classes: [] + semantic_only: True + semantic_weight: [1.0, 1.0, 44.0, 21.9, 1.8, 25.1, 31.5, 21.8, 24.0, 54.4, 114.4, + 81.2, 43.6, 9.7, 22.4] + with_coords: False + ignore_label: -100 + grouping_cfg: + score_thr: 0.2 + radius: 0.9 + mean_active: 3 + class_numpoint_mean: [-1., 10408., 58., 124., 1351., 162., 430., 1090., 451., 26., 43., + 61., 39., 109., 1239] + npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1 + ignore_classes: [0] + instance_voxel_cfg: + scale: 3 + spatial_shape: 20 + train_cfg: + max_proposal_num: 200 + pos_iou_thr: 0.5 + test_cfg: + x4_split: False + cls_score_thr: 0.001 + mask_score_thr: -0.5 + min_npoint: 100 + fixed_modules: [] + +data: + train: + type: 'stpls3d' + data_root: 'dataset/stpls3d' + prefix: 'train' + suffix: '_inst_nostuff.pth' + training: True + repeat: 4 + voxel_cfg: + scale: 3 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + test: + type: 'stpls3d' + data_root: 'dataset/stpls3d' + prefix: 'val' + suffix: '_inst_nostuff.pth' + training: False + voxel_cfg: + scale: 3 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + +dataloader: + train: + batch_size: 4 + num_workers: 4 + test: + batch_size: 1 + num_workers: 1 + +optimizer: + type: 'Adam' + lr: 0.004 + +save_cfg: + semantic: True + offset: True + instance: True + +fp16: False +epochs: 20 +step_epoch: 20 +save_freq: 4 +pretrain: '' +work_dir: '' diff --git a/dataset/stpls3d/prepare_data.sh b/dataset/stpls3d/prepare_data.sh new 
# https://github.com/meidachen/STPLS3D/blob/main/HAIS/data/prepare_data_inst_instance_stpls3d.py
"""Convert STPLS3D text point clouds into per-block ``*_inst_nostuff.pth`` files.

Every input row is read as ``x, y, z, r, g, b, semantic, instance``.
"""
import glob
import json
import math
import os
import random

import numpy as np
import pandas as pd
import torch


def splitPointCloud(cloud, size=50.0, stride=50):
    """Tile a point cloud into square blocks on the XY plane.

    Cells are anchored at the origin, so ``cloud`` is expected to have been
    shifted so that its minimum x/y is zero; points with negative x/y are not
    covered by any cell. Cell bounds are inclusive on both sides, so a point
    lying exactly on a shared edge is returned in both neighbouring blocks.

    Args:
        cloud: 2-D array whose first two columns are x and y.
        size: Edge length of one block.
        stride: Distance between the origins of consecutive blocks.

    Returns:
        A list with one array (a row subset of ``cloud``) per cell, ordered
        x-major. An empty ``cloud`` yields an empty list.
    """
    if len(cloud) == 0:
        return []
    limitMax = np.amax(cloud[:, 0:3], axis=0)
    # Always keep at least one cell per axis: a cloud with zero extent would
    # otherwise produce a cell count of 0 and lose all of its points.
    width = max(int(np.ceil((limitMax[0] - size) / stride)) + 1, 1)
    depth = max(int(np.ceil((limitMax[1] - size) / stride)) + 1, 1)
    blocks = []
    for ix in range(width):
        x = ix * stride
        xcond = (cloud[:, 0] <= x + size) & (cloud[:, 0] >= x)
        for iy in range(depth):
            y = iy * stride
            ycond = (cloud[:, 1] <= y + size) & (cloud[:, 1] >= y)
            blocks.append(cloud[xcond & ycond, :])
    return blocks


def getFiles(files, fileSplit):
    """Select the files whose leading scene number is listed in ``fileSplit``.

    The scene number is the first two characters of the base name when both
    are digits, otherwise the first character. Files whose name does not start
    with a digit (e.g. a stray README) are skipped instead of raising.

    Args:
        files: Iterable of file paths.
        fileSplit: Collection of integer scene numbers to keep.

    Returns:
        The matching paths, in their original order.
    """
    wanted = set(fileSplit)
    res = []
    for filePath in files:
        name = os.path.basename(filePath)
        num = name[:2] if name[:2].isdecimal() else name[:1]
        if num.isdecimal() and int(num) in wanted:
            res.append(filePath)
    return res


def dataAug(file, semanticKeep):
    """Load a scene, rotate it about z by a random angle and keep some classes.

    Args:
        file: Path of a header-less CSV point file.
        semanticKeep: Raw semantic ids (column 6) whose points are retained.

    Returns:
        The rotated points restricted to ``semanticKeep``.
    """
    points = pd.read_csv(file, header=None).values
    angleRadians = math.radians(random.randint(1, 359))
    cosA, sinA = math.cos(angleRadians), math.sin(angleRadians)
    rotationMatrix = np.array([[cosA, -sinA, 0], [sinA, cosA, 0], [0, 0, 1]])
    points[:, :3] = points[:, :3].dot(rotationMatrix)
    # np.isin is the supported spelling of the deprecated np.in1d.
    return points[np.isin(points[:, 6], semanticKeep)]


def _sceneName(filePath, augTime):
    """Return ``<file stem>_<augTime>``.

    ``str.strip('.txt')`` removes any of the characters ``.``, ``t`` and ``x``
    from both ends (``1_test.txt`` -> ``1_tes``), so the extension is removed
    with ``os.path.splitext`` instead.
    """
    stem = os.path.splitext(os.path.basename(filePath))[0]
    return '%s_%d' % (stem, augTime)


def _padShallowBlock(block, zThreshold):
    """Append one ignored point so the block's z-range reaches ``zThreshold``.

    The extra point sits at the block's mean x/y and mean colour, carries the
    -100 label for both semantic and instance, and is placed ``zThreshold``
    above the lowest point.
    """
    filler = [[
        block[:, 0].mean(0), block[:, 1].mean(0), block[:, 2].min(0) + zThreshold,
        block[:, 3].mean(0), block[:, 4].mean(0), block[:, 5].mean(0), -100, -100
    ]]
    return np.append(block, filler, axis=0)


def _remapLabels(block, remapper, instanceEnabled):
    """Build training labels for one block.

    Args:
        block: Points whose last two columns are raw semantic and instance id.
        remapper: Lookup table from raw semantic id to training semantic id.
        instanceEnabled: Boolean lookup table, True for raw semantic ids whose
            points keep their instance id.

    Returns:
        ``(sem_labels, instance_labels, uniqueInstances)`` where instance ids
        are renumbered to ``0..K-1``, ignored points carry -100 and
        ``uniqueInstances`` holds the K original ids.
    """
    rawSemantic = block[:, -2].astype(np.int32)
    sem_labels = remapper[rawSemantic]

    instance_labels = block[:, -1].astype(np.float32)
    instance_labels = np.where(instanceEnabled[rawSemantic], instance_labels, -100)

    # Renumber the surviving instances from 0. The ignore label is filtered
    # explicitly: dropping "the first unique value" removes a real instance
    # whenever a block happens to contain no ignored point.
    valid = instance_labels != -100
    uniqueInstances, compact = np.unique(
        instance_labels[valid].astype(np.int32), return_inverse=True)
    remapped = np.full(instance_labels.shape, -100.0)
    remapped[valid] = compact.reshape(-1)
    return sem_labels, remapped, uniqueInstances


def preparePthFiles(files, split, outPutFolder, AugTimes=0):
    """Cut scenes into 50 x 50 blocks and save each block as a ``.pth`` tuple.

    For every file (and every augmented copy of it) the scene is shifted to
    the origin, tiled with ``splitPointCloud`` and every block with more than
    10000 points is centred and written to ``outPutFolder``. The offsets that
    were subtracted are stored in ``coordShift.json`` so predictions can be
    merged back into one scene.

    Args:
        files: Paths of the header-less CSV scenes.
        split: ``'train'``, ``'val'`` or ``'test'``. Labels are only written
            when ``split != 'test'``; sparse blocks are only skipped for
            ``'train'``.
        outPutFolder: Existing directory receiving the output files.
        AugTimes: Number of randomly rotated copies generated per scene.
    """
    outJsonPath = os.path.join(outPutFolder, 'coordShift.json')
    coordShift = {}
    # used to increase z range if it is smaller than this,
    # to work around spconv crashing during voxelization.
    zThreshold = 6

    # Raw classes 0..14 keep their id; every other id becomes -100.
    remapper = np.full(150, -100.0)
    remapper[:15] = np.arange(15)
    # Raw classes whose points keep their instance id. Class 0 is absent, so
    # its points never form instances; remove an id to ignore that class too.
    instanceEnabled = np.zeros(150, dtype=bool)
    instanceEnabled[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]] = True

    # only augment data for these classes
    semanticKeep = [0, 2, 3, 7, 8, 9, 12, 13]

    counter = 0
    for file in files:
        for AugTime in range(AugTimes + 1):
            if AugTime == 0:
                points = pd.read_csv(file, header=None).values
            else:
                points = dataAug(file, semanticKeep)
            if len(points) == 0:
                # Nothing survived the class filter; min() would raise.
                continue
            name = _sceneName(file, AugTime)

            if split != 'test':
                # NOTE: this single key is overwritten for every scene, so the
                # JSON only keeps the shift of the last one processed.
                sceneMin = points[:, :3].min(0)
                coordShift['globalShift'] = sceneMin.tolist()
                points[:, :3] = points[:, :3] - sceneMin

            for blockNum, block in enumerate(splitPointCloud(points, size=50, stride=50)):
                if len(block) <= 10000:
                    continue
                blockName = name + str(blockNum) + '_inst_nostuff'
                outFilePath = os.path.join(outPutFolder, blockName + '.pth')
                if block[:, 2].max(0) - block[:, 2].min(0) < zThreshold:
                    block = _padShallowBlock(block, zThreshold)
                    print('range z is smaller than threshold ')
                    print(blockName)

                center = block[:, :3].mean(0)
                if split != 'test':
                    coordShift[blockName] = center.tolist()
                coords = np.float32(np.ascontiguousarray(block[:, :3] - center))
                # Colours are mapped from [0, 255] to [-1, 1].
                colors = np.float32(np.ascontiguousarray(block[:, 3:6]) / 127.5 - 1)

                if split == 'test':
                    torch.save((coords, colors), outFilePath)
                    continue

                sem_labels, instance_labels, uniqueInstances = _remapLabels(
                    block, remapper, instanceEnabled)
                # Drops the smallest semantic label of the block: -100 when an
                # ignored point is present, otherwise the lowest class id.
                uniqueSemantics = np.unique(sem_labels)[1:].astype(np.int32)

                tooSparse = (
                    len(uniqueInstances) < 10
                    or len(uniqueSemantics) >= len(uniqueInstances) - 2)
                if split == 'train' and tooSparse:
                    print('unique insance: %d' % len(uniqueInstances))
                    print('unique semantic: %d' % len(uniqueSemantics))
                    print()
                    counter += 1
                else:
                    torch.save((coords, colors, sem_labels, instance_labels), outFilePath)
    print('Total skipped file :%d' % counter)
    with open(outJsonPath, 'w') as jsonFile:
        json.dump(coordShift, jsonFile)


def main():
    """Build the train (with 6 augmentations) and val splits in the cwd."""
    data_folder = 'Synthetic_v3_InstanceSegmentation'
    filesOri = sorted(glob.glob(data_folder + '/*.txt'))

    trainSplit = [1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23, 24]
    trainFiles = getFiles(filesOri, trainSplit)
    split = 'train'
    trainOutDir = split
    os.makedirs(trainOutDir, exist_ok=True)
    preparePthFiles(trainFiles, split, trainOutDir, AugTimes=6)

    valSplit = [5, 10, 15, 20, 25]
    split = 'val'
    valFiles = getFiles(filesOri, valSplit)
    valOutDir = split
    os.makedirs(valOutDir, exist_ok=True)
    preparePthFiles(valFiles, split, valOutDir)


if __name__ == '__main__':
    main()
"""Print per-class statistics of the prepared STPLS3D training blocks.

The three printed lists are the mean number of points per instance, the mean
instance radius and the semantic class weights.
"""
import glob
import math
import os
import sys

import numpy as np
import torch

numclass = 15


def sceneStatistics(points, numclass=15):
    """Collect per-class instance statistics for one block.

    Args:
        points: 2-D array whose columns 0-2 are xyz, column 6 the semantic
            label and column 7 the instance label.
        numclass: Number of semantic classes; labels outside
            ``0..numclass-1`` (such as the -100 ignore label) are not counted.

    Returns:
        ``(numpoints, radii, pointCount)``: two dicts mapping each class id to
        a list with one entry per instance (its point count / half the
        diagonal of its axis-aligned bounding box) and an integer array with
        the number of points per class.
    """
    numpoints = {semanticID: [] for semanticID in range(numclass)}
    radii = {semanticID: [] for semanticID in range(numclass)}
    for semanticID in range(numclass):
        singleSemantic = points[points[:, 6] == semanticID]
        uniqueInstances, counts = np.unique(singleSemantic[:, 7], return_counts=True)
        numpoints[semanticID].extend(counts.tolist())
        for uniqueInstance in uniqueInstances:
            # Only the xyz columns are needed for the extent.
            xyz = singleSemantic[singleSemantic[:, 7] == uniqueInstance, :3]
            half = (xyz.max(axis=0) - xyz.min(axis=0)) / 2
            radii[semanticID].append(math.sqrt(half[0]**2 + half[1]**2 + half[2]**2))

    semantic = points[:, 6].astype(int)
    semantic = semantic[(semantic >= 0) & (semantic < numclass)]
    pointCount = np.bincount(semantic, minlength=numclass)
    return numpoints, radii, pointCount


def classWeights(pointCount):
    """Turn per-class point counts into weights normalised to class 1.

    The weight of a class is ``max(count) / count``, divided by the same
    quantity for class 1. A class without points gets ``inf`` (or ``nan``
    when class 1 itself is empty) rather than raising.
    """
    counts = np.asarray(pointCount, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = np.max(counts) / counts
        weights = weights / weights[1]
    return weights


def _mean(values):
    """Arithmetic mean, ``nan`` for an empty list instead of ZeroDivisionError."""
    return sum(values) / len(values) if values else float('nan')


def _loadScene(path):
    """Load one ``(coords, colors, sem_labels, instance_labels)`` tuple.

    The files hold pickled numpy arrays, which recent PyTorch refuses to load
    unless ``weights_only=False``. Only use this on files you produced
    yourself: unpickling untrusted data can execute arbitrary code.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the ``weights_only`` argument.
        return torch.load(path)


def _defaultDataFolder():
    """Return the legacy data folder if it exists, else ``./train``.

    ``./train`` is where prepare_data_inst_instance_stpls3d.py writes its
    training blocks when run from the same directory.
    """
    legacy = os.path.join(
        os.path.dirname(os.getcwd()), 'dataset', 'Synthetic_v3_InstanceSegmentation', 'train')
    return legacy if os.path.isdir(legacy) else 'train'


def main(data_folder=None):
    """Aggregate the statistics of every ``.pth`` block and print them.

    Args:
        data_folder: Directory holding the training ``.pth`` files; defaults
            to the folder chosen by ``_defaultDataFolder``.
    """
    if data_folder is None:
        data_folder = _defaultDataFolder()
    files = sorted(glob.glob(data_folder + '/*.pth'))
    if not files:
        raise SystemExit('No .pth files found in %s' % data_folder)

    class_numpoint_mean_dict = {semanticID: [] for semanticID in range(numclass)}
    class_radius_mean = {semanticID: [] for semanticID in range(numclass)}
    num_points_semantic = np.zeros(numclass, dtype=np.int64)

    for file in files:
        coords, colors, sem_labels, instance_labels = _loadScene(file)
        points = np.concatenate(
            [coords, colors, sem_labels[:, None].astype(int), instance_labels[:, None].astype(int)],
            axis=1)
        numpoints, radii, pointCount = sceneStatistics(points, numclass)
        for semanticID in range(numclass):
            class_numpoint_mean_dict[semanticID].extend(numpoints[semanticID])
            class_radius_mean[semanticID].extend(radii[semanticID])
        num_points_semantic += pointCount

    class_numpoint_mean_list = [
        _mean(class_numpoint_mean_dict[semanticID]) for semanticID in range(numclass)]
    class_radius_mean_list = [
        _mean(class_radius_mean[semanticID]) for semanticID in range(numclass)]

    # The first entry of each list is replaced by a fixed 1.0.
    print('Using the printed list in hierarchical_aggregation.cpp for class_numpoint_mean_dict: ')
    print([1.0] + [float('{0:0.0f}'.format(i)) for i in class_numpoint_mean_list][1:])
    print('Using the printed list in hierarchical_aggregation.cu for class_radius_mean: ')
    print([1.0] + [float('{0:0.2f}'.format(i)) for i in class_radius_mean_list][1:])

    # The first two weights are printed as a fixed 1.0.
    weights = classWeights(num_points_semantic)
    print('Using the printed list in hais_run_stpls3d.yaml for class_weight')
    print([1.0, 1.0] + [float('{0:0.2f}'.format(i)) for i in weights][2:])


if __name__ == '__main__':
    main(sys.argv[1] if len(sys.argv) > 1 else None)
S3DISDataset(**_data_cfg) elif data_type == 'scannetv2': return ScanNetDataset(**_data_cfg) + elif data_type == 'stpls3d': + return STPLS3DDataset(**_data_cfg) else: raise ValueError(f'Unknown {data_type}') diff --git a/softgroup/data/custom.py b/softgroup/data/custom.py index 3652bb1..b17e0a6 100644 --- a/softgroup/data/custom.py +++ b/softgroup/data/custom.py @@ -132,9 +132,8 @@ class CustomDataset(Dataset): xyz_middle = self.dataAugment(xyz, True, True, True, aug_prob) xyz = xyz_middle * self.voxel_cfg.scale if np.random.rand() < aug_prob: - xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50) - xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50, - 160 * self.voxel_cfg.scale / 50) + xyz = self.elastic(xyz, 6, 40.) + xyz = self.elastic(xyz, 20, 160.) # xyz_middle = xyz / self.voxel_cfg.scale xyz = xyz - xyz.min(0) max_tries = 5 diff --git a/softgroup/data/stpls3d.py b/softgroup/data/stpls3d.py new file mode 100644 index 0000000..7835f9c --- /dev/null +++ b/softgroup/data/stpls3d.py @@ -0,0 +1,15 @@ +from .custom import CustomDataset + + +class STPLS3DDataset(CustomDataset): + + CLASSES = ('building', 'low vegetation', 'med. 
vegetation', 'high vegetation', 'vehicle', + 'truck', 'aircraft', 'militaryVehicle', 'bike', 'motorcycle', 'light pole', + 'street sign', 'clutter', 'fence') + + def getInstanceInfo(self, xyz, instance_label, semantic_label): + ret = super().getInstanceInfo(xyz, instance_label, semantic_label) + instance_num, instance_pointnum, instance_cls, pt_offset_label = ret + # ignore instance of class 0 and reorder class id + instance_cls = [x - 1 if x != -100 else x for x in instance_cls] + return instance_num, instance_pointnum, instance_cls, pt_offset_label diff --git a/softgroup/evaluation/instance_eval.py b/softgroup/evaluation/instance_eval.py index 4d6c3bf..4b3a7d6 100644 --- a/softgroup/evaluation/instance_eval.py +++ b/softgroup/evaluation/instance_eval.py @@ -12,7 +12,7 @@ from .instance_eval_util import get_instances class ScanNetEval(object): - def __init__(self, class_labels, iou_type=None, use_label=True): + def __init__(self, class_labels, min_npoint=None, iou_type=None, use_label=True): self.valid_class_labels = class_labels self.valid_class_ids = np.arange(len(class_labels)) + 1 self.id2label = {} @@ -22,7 +22,10 @@ class ScanNetEval(object): self.id2label[self.valid_class_ids[i]] = self.valid_class_labels[i] self.ious = np.append(np.arange(0.5, 0.95, 0.05), 0.25) - self.min_region_sizes = np.array([100]) + if min_npoint: + self.min_region_sizes = np.array([min_npoint]) + else: + self.min_region_sizes = np.array([100]) self.distance_threshes = np.array([float('inf')]) self.distance_confs = np.array([-float('inf')]) diff --git a/softgroup/model/softgroup.py b/softgroup/model/softgroup.py index 45cc590..aa3713a 100644 --- a/softgroup/model/softgroup.py +++ b/softgroup/model/softgroup.py @@ -21,8 +21,10 @@ class SoftGroup(nn.Module): semantic_only=False, semantic_classes=20, instance_classes=18, + semantic_weight=None, sem2ins_classes=[], ignore_label=-100, + with_coords=True, grouping_cfg=None, instance_voxel_cfg=None, train_cfg=None, @@ -34,8 +36,10 @@ class 
SoftGroup(nn.Module): self.semantic_only = semantic_only self.semantic_classes = semantic_classes self.instance_classes = instance_classes + self.semantic_weight = semantic_weight self.sem2ins_classes = sem2ins_classes self.ignore_label = ignore_label + self.with_coords = with_coords self.grouping_cfg = grouping_cfg self.instance_voxel_cfg = instance_voxel_cfg self.train_cfg = train_cfg @@ -46,9 +50,10 @@ class SoftGroup(nn.Module): norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1) # backbone + in_channels = 6 if with_coords else 3 self.input_conv = spconv.SparseSequential( spconv.SubMConv3d( - 6, channels, kernel_size=3, padding=1, bias=False, indice_key='subm1')) + in_channels, channels, kernel_size=3, padding=1, bias=False, indice_key='subm1')) block_channels = [channels * (i + 1) for i in range(num_blocks)] self.unet = UBlock(block_channels, norm_fn, 2, block, indice_key_id=1) self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU()) @@ -103,7 +108,8 @@ class SoftGroup(nn.Module): semantic_labels, instance_labels, instance_pointnum, instance_cls, pt_offset_labels, spatial_shape, batch_size, **kwargs): losses = {} - feats = torch.cat((feats, coords_float), 1) + if self.with_coords: + feats = torch.cat((feats, coords_float), 1) voxel_feats = voxelization(feats, p2v_map) input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size) semantic_scores, pt_offsets, output_feats = self.forward_backbone(input, v2p_map) @@ -140,8 +146,12 @@ class SoftGroup(nn.Module): def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels, pt_offset_labels): losses = {} + if self.semantic_weight: + weight = torch.tensor(self.semantic_weight, dtype=torch.float, device='cuda') + else: + weight = None semantic_loss = F.cross_entropy( - semantic_scores, semantic_labels, ignore_index=self.ignore_label) + semantic_scores, semantic_labels, weight=weight, ignore_index=self.ignore_label) 
losses['semantic_loss'] = semantic_loss pos_inds = instance_labels != self.ignore_label @@ -169,14 +179,30 @@ class SoftGroup(nn.Module): fg_instance_cls = instance_cls[fg_inds] fg_ious_on_cluster = ious_on_cluster[:, fg_inds] + # assign proposal to gt idx. -1: negative, 0 -> num_gts - 1: positive + num_proposals = fg_ious_on_cluster.size(0) + num_gts = fg_ious_on_cluster.size(1) + assigned_gt_inds = fg_ious_on_cluster.new_full((num_proposals, ), -1, dtype=torch.long) + # overlap > thr on fg instances are positive samples - max_iou, gt_inds = fg_ious_on_cluster.max(1) + max_iou, argmax_iou = fg_ious_on_cluster.max(1) pos_inds = max_iou >= self.train_cfg.pos_iou_thr - pos_gt_inds = gt_inds[pos_inds] + assigned_gt_inds[pos_inds] = argmax_iou[pos_inds] + + # allow low-quality proposals with best iou to be as positive sample + # in case pos_iou_thr is too high to achieve + match_low_quality = getattr(self.train_cfg, 'match_low_quality', False) + min_pos_thr = getattr(self.train_cfg, 'min_pos_thr', 0) + if match_low_quality: + gt_max_iou, gt_argmax_iou = fg_ious_on_cluster.max(0) + for i in range(num_gts): + if gt_max_iou[i] >= min_pos_thr: + assigned_gt_inds[gt_argmax_iou[i]] = i # compute cls loss. 
follow detection convention: 0 -> K - 1 are fg, K is bg - labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), self.instance_classes) - labels[pos_inds] = fg_instance_cls[pos_gt_inds] + labels = fg_instance_cls.new_full((num_proposals, ), self.instance_classes) + pos_inds = assigned_gt_inds >= 0 + labels[pos_inds] = fg_instance_cls[assigned_gt_inds[pos_inds]] cls_loss = F.cross_entropy(cls_scores, labels) losses['cls_loss'] = cls_loss @@ -221,7 +247,8 @@ class SoftGroup(nn.Module): def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats, semantic_labels, instance_labels, pt_offset_labels, spatial_shape, batch_size, scan_ids, **kwargs): - feats = torch.cat((feats, coords_float), 1) + if self.with_coords: + feats = torch.cat((feats, coords_float), 1) voxel_feats = voxelization(feats, p2v_map) input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size) semantic_scores, pt_offsets, output_feats = self.forward_backbone( diff --git a/tools/test.py b/tools/test.py index 49b5d6d..65b3196 100644 --- a/tools/test.py +++ b/tools/test.py @@ -116,7 +116,8 @@ def main(): gt_insts.append(res['gt_instances']) if not cfg.model.semantic_only: logger.info('Evaluate instance segmentation') - scannet_eval = ScanNetEval(dataset.CLASSES) + eval_min_npoint = getattr(cfg, 'eval_min_npoint', None) + scannet_eval = ScanNetEval(dataset.CLASSES, eval_min_npoint) scannet_eval.evaluate(pred_insts, gt_insts) logger.info('Evaluate semantic segmentation and offset MAE') ignore_label = cfg.model.ignore_label diff --git a/tools/train.py b/tools/train.py index c271c0d..d25ff4c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -108,7 +108,8 @@ def validate(epoch, model, val_loader, cfg, logger, writer): all_gt_insts.append(res['gt_instances']) if not cfg.model.semantic_only: logger.info('Evaluate instance segmentation') - scannet_eval = ScanNetEval(val_set.CLASSES) + eval_min_npoint = getattr(cfg, 'eval_min_npoint', None) + 
scannet_eval = ScanNetEval(val_set.CLASSES, eval_min_npoint) eval_res = scannet_eval.evaluate(all_pred_insts, all_gt_insts) writer.add_scalar('val/AP', eval_res['all_ap'], epoch) writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)