From 068d67b1d485647e6a39a336a000a38ea5ca7f29 Mon Sep 17 00:00:00 2001
From: Thang Vu
Date: Sat, 9 Apr 2022 03:17:01 +0000
Subject: [PATCH] support evaluation during training

---
 configs/softgroup_scannet_backbone.yaml       | 72 +++++++++++++++++
 softgroup/evaluation/__init__.py              |  5 +-
 ..._semantic_instance.py => instance_eval.py} |  2 +-
 .../{util_3d.py => instance_eval_util.py}     |  0
 softgroup/evaluation/semantic_eval.py         | 32 ++++++++
 softgroup/model/blocks.py                     | 20 +++++
 softgroup/model/softgroup.py                  | 80 ++++++++-----------
 test.py                                       | 59 ++++----------
 train.py                                      | 31 ++++++-
 9 files changed, 205 insertions(+), 96 deletions(-)
 create mode 100644 configs/softgroup_scannet_backbone.yaml
 rename softgroup/evaluation/{evaluate_semantic_instance.py => instance_eval.py} (99%)
 rename softgroup/evaluation/{util_3d.py => instance_eval_util.py} (100%)
 create mode 100644 softgroup/evaluation/semantic_eval.py

diff --git a/configs/softgroup_scannet_backbone.yaml b/configs/softgroup_scannet_backbone.yaml
new file mode 100644
index 0000000..a0fbbf9
--- /dev/null
+++ b/configs/softgroup_scannet_backbone.yaml
@@ -0,0 +1,72 @@
+model:
+  channels: 32
+  num_blocks: 7
+  semantic_classes: 20
+  instance_classes: 18
+  sem2ins_classes: []
+  semantic_only: True
+  ignore_label: -100
+  grouping_cfg:
+    score_thr: 0.2
+    radius: 0.04
+    mean_active: 300
+    class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
+                          8331., 3948., 3166., 5629., 11719.,
+                          1003., 3317., 4912., 10221., 3889.,
+                          4136., 2120., 945., 3967., 2589.]
+    npoint_thr: 0.05  # absolute if class_numpoint == -1, relative if class_numpoint != -1
+    ignore_classes: [0, 1]
+  instance_voxel_cfg:
+    scale: 50
+    spatial_shape: 20
+  train_cfg:
+    max_proposal_num: 200
+    pos_iou_thr: 0.5
+  test_cfg:
+    x4_split: False
+    cls_score_thr: 0.001
+    mask_score_thr: -0.5
+    min_npoint: 100
+  fixed_modules: []
+
+data:
+  train:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'train'
+    suffix: '_inst_nostuff.pth'
+    training: True
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+  test:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'val'
+    suffix: '_inst_nostuff.pth'
+    training: False
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+
+dataloader:
+  train:
+    batch_size: 4
+    num_workers: 4
+  test:
+    batch_size: 1
+    num_workers: 1
+
+optimizer:
+  type: 'Adam'
+  lr: 0.001
+
+epochs: 512
+step_epoch: 200
+save_freq: 16
+pretrain: ''
+work_dir: 'work_dirs/softgroup_scannet_backbone'
diff --git a/softgroup/evaluation/__init__.py b/softgroup/evaluation/__init__.py
index e6a4cc4..48ffa4c 100644
--- a/softgroup/evaluation/__init__.py
+++ b/softgroup/evaluation/__init__.py
@@ -1,3 +1,4 @@
-from .evaluate_semantic_instance import ScanNetEval
+from .instance_eval import ScanNetEval
+from .semantic_eval import evaluate_semantic_acc, evaluate_semantic_miou
 
-__all__ = ['ScanNetEval']
+__all__ = ['ScanNetEval', 'evaluate_semantic_acc', 'evaluate_semantic_miou']
diff --git a/softgroup/evaluation/evaluate_semantic_instance.py b/softgroup/evaluation/instance_eval.py
similarity index 99%
rename from softgroup/evaluation/evaluate_semantic_instance.py
rename to softgroup/evaluation/instance_eval.py
index 291c352..01e730f 100644
--- a/softgroup/evaluation/evaluate_semantic_instance.py
+++ b/softgroup/evaluation/instance_eval.py
@@ -6,7 +6,7 @@ from copy import deepcopy
 import numpy as np
 from tqdm import tqdm
 
-from .util_3d import get_instances
+from .instance_eval_util import get_instances
 
 
 class ScanNetEval(object):
diff --git a/softgroup/evaluation/util_3d.py b/softgroup/evaluation/instance_eval_util.py
similarity index 100%
rename from softgroup/evaluation/util_3d.py
rename to softgroup/evaluation/instance_eval_util.py
diff --git a/softgroup/evaluation/semantic_eval.py b/softgroup/evaluation/semantic_eval.py
new file mode 100644
index 0000000..742859e
--- /dev/null
+++ b/softgroup/evaluation/semantic_eval.py
@@ -0,0 +1,32 @@
+import numpy as np
+
+
+def evaluate_semantic_acc(pred_list, gt_list, ignore_label=-100, logger=None):
+    gt = np.concatenate(gt_list, axis=0)
+    pred = np.concatenate(pred_list, axis=0)
+    assert gt.shape == pred.shape
+    correct = (gt[gt != ignore_label] == pred[gt != ignore_label]).sum()
+    whole = (gt != ignore_label).sum()
+    acc = correct.astype(float) / whole * 100
+    logger.info(f'Acc: {acc:.1f}')
+    return acc
+
+
+def evaluate_semantic_miou(pred_list, gt_list, ignore_label=-100, logger=None):
+    gt = np.concatenate(gt_list, axis=0)
+    pred = np.concatenate(pred_list, axis=0)
+    pos_inds = gt != ignore_label
+    gt = gt[pos_inds]
+    pred = pred[pos_inds]
+    assert gt.shape == pred.shape
+    iou_list = []
+    for _index in np.unique(gt):
+        if _index != ignore_label:
+            intersection = ((gt == _index) & (pred == _index)).sum()
+            union = ((gt == _index) | (pred == _index)).sum()
+            iou = intersection.astype(float) / union * 100
+            iou_list.append(iou)
+    miou = np.mean(iou_list)
+    logger.info('Class-wise mIoU: ' + ' '.join(f'{x:.1f}' for x in iou_list))
+    logger.info(f'mIoU: {miou:.1f}')
+    return miou
diff --git a/softgroup/model/blocks.py b/softgroup/model/blocks.py
index dcf1f84..ef3ed91 100644
--- a/softgroup/model/blocks.py
+++ b/softgroup/model/blocks.py
@@ -6,6 +6,26 @@ from spconv.modules import SparseModule
 from torch import nn
 
 
+class MLP(nn.Sequential):
+
+    def __init__(self, in_channels, out_channels, norm_fn, num_layers=2):
+        modules = []
+        for _ in range(num_layers - 1):
+            modules.extend(
+                [nn.Linear(in_channels, in_channels, bias=False),
+                 norm_fn(in_channels),
+                 nn.ReLU()])
+        modules.append(nn.Linear(in_channels, out_channels))
+        super().__init__(*modules)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+        nn.init.normal_(self[-1].weight, 0, 0.01)
+        nn.init.constant_(self[-1].bias, 0)
+
+
 class ResidualBlock(SparseModule):
 
     def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
diff --git a/softgroup/model/softgroup.py b/softgroup/model/softgroup.py
index 1cbc1cb..0af1cd7 100644
--- a/softgroup/model/softgroup.py
+++ b/softgroup/model/softgroup.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from ..lib.softgroup_ops import (ballquery_batch_p, bfs_cluster, get_mask_iou_on_cluster,
                                  get_mask_iou_on_pred, get_mask_label, global_avg_pool, sec_max,
                                  sec_mean, sec_min, voxelization, voxelization_idx)
-from .blocks import ResidualBlock, UBlock
+from .blocks import MLP, ResidualBlock, UBlock
 
 
 class SoftGroup(nn.Module):
@@ -50,34 +50,19 @@ class SoftGroup(nn.Module):
         self.unet = UBlock(block_channels, norm_fn, 2, block, indice_key_id=1)
         self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
 
-        # semantic segmentation branch
-        self.semantic_linear = nn.Sequential(
-            nn.Linear(channels, channels, bias=True), norm_fn(channels), nn.ReLU(),
-            nn.Linear(channels, semantic_classes))
-
-        # center shift vector branch
-        self.offset_linear = nn.Sequential(
-            nn.Linear(channels, channels, bias=True), norm_fn(channels), nn.ReLU(),
-            nn.Linear(channels, 3, bias=True))
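+        # NOTE: semantic_linear and offset_linear below are instances of the shared
+        # MLP block added in blocks.py: (num_layers - 1) repeats of
+        # Linear -> norm_fn -> ReLU followed by a plain Linear. Unlike the removed
+        # nn.Sequential heads, the hidden Linear drops its bias (bias=False),
+        # which is safe because the norm layer that follows re-centers its input.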
+        # point-wise prediction
+        self.semantic_linear = MLP(channels, semantic_classes, norm_fn, num_layers=2)
+        self.offset_linear = MLP(channels, 3, norm_fn, num_layers=2)
 
         # topdown refinement path
         if not semantic_only:
-            self.intra_ins_unet = UBlock([channels, 2 * channels],
-                                         norm_fn,
-                                         2,
-                                         block,
-                                         indice_key_id=11)
-            self.intra_ins_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
-            self.cls_linear = nn.Linear(channels, instance_classes + 1)
-            self.mask_linear = nn.Sequential(
-                nn.Linear(channels, channels), nn.ReLU(), nn.Linear(channels, instance_classes + 1))
-            # TODO renamve score_linear to iou_score_linear
-            self.score_linear = nn.Linear(channels, instance_classes + 1)
+            self.tiny_unet = UBlock([channels, 2 * channels], norm_fn, 2, block, indice_key_id=11)
+            self.tiny_unet_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
+            self.cls_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
+            self.mask_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
+            self.iou_score_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
 
-            nn.init.normal_(self.score_linear.weight, 0, 0.01)
-            nn.init.constant_(self.score_linear.bias, 0)
-
-        self.apply(self.set_bn_init)
+        self.init_weights()
 
         for mod in fixed_modules:
             mod = getattr(self, mod)
@@ -85,15 +70,13 @@ class SoftGroup(nn.Module):
             for param in mod.parameters():
                 param.requires_grad = False
 
-    @staticmethod
-    def set_bn_init(m):
-        classname = m.__class__.__name__
-        if classname.find('BatchNorm') != -1:
-            m.weight.data.fill_(1.0)
-            m.bias.data.fill_(0.0)
-
     def init_weights(self):
-        pass
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, MLP):
+                m.init_weights()
 
     def forward(self, batch, return_loss=False):
         if return_loss:
@@ -235,17 +218,20 @@
         input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
         semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
             input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
-        proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
-                                                                batch_idxs, coords_float,
-                                                                self.grouping_cfg)
-        scores_batch_idxs, cls_scores, iou_scores, mask_scores = self.forward_instance(
-            proposals_idx, proposals_offset, output_feats, coords_float)
-        pred_instances = self.get_instances(batch['scan_ids'][0], proposals_idx, semantic_scores,
-                                            cls_scores, iou_scores, mask_scores)
-        gt_instances = self.get_gt_instances(labels, instance_labels)
-        ret = {}
-        ret['det_ins'] = pred_instances
-        ret['gt_ins'] = gt_instances
+        semantic_preds = semantic_scores.max(1)[1]
+        ret = dict(
+            semantic_preds=semantic_preds.cpu().numpy(), semantic_labels=labels.cpu().numpy())
+        if not self.semantic_only:
+            proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
+                                                                    batch_idxs, coords_float,
+                                                                    self.grouping_cfg)
+            scores_batch_idxs, cls_scores, iou_scores, mask_scores = self.forward_instance(
+                proposals_idx, proposals_offset, output_feats, coords_float)
+            pred_instances = self.get_instances(batch['scan_ids'][0], proposals_idx,
+                                                semantic_scores, cls_scores, iou_scores,
+                                                mask_scores)
+            gt_instances = self.get_gt_instances(labels, instance_labels)
+            ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances))
         return ret
 
     def forward_backbone(self, input, input_map, coords, x4_split=False):
@@ -344,8 +330,8 @@ class SoftGroup(nn.Module):
         input_feats, inp_map = self.clusters_voxelization(proposals_idx, proposals_offset,
                                                           output_feats, coords_float,
                                                           **self.instance_voxel_cfg)
-        feats = self.intra_ins_unet(input_feats)
-        feats = self.intra_ins_outputlayer(feats)
+        feats = self.tiny_unet(input_feats)
+        feats = self.tiny_unet_outputlayer(feats)
 
         # predict mask scores
         mask_scores = self.mask_linear(feats.features)
@@ -355,7 +341,7 @@ class SoftGroup(nn.Module):
         # predict instance cls and iou scores
         feats = self.global_pool(feats)
         cls_scores = self.cls_linear(feats)
-        iou_scores = self.score_linear(feats)
+        iou_scores = self.iou_score_linear(feats)
 
         return instance_batch_idxs, cls_scores, iou_scores, mask_scores
 
diff --git a/test.py b/test.py
index 5f363ea..50ae660 100644
--- a/test.py
+++ b/test.py
@@ -6,7 +6,7 @@ import torch
 import yaml
 from munch import Munch
 from softgroup.data import build_dataloader, build_dataset
-from softgroup.evaluation import ScanNetEval
+from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
 from softgroup.model import SoftGroup
 from softgroup.util import get_root_logger, load_checkpoint
 from tqdm import tqdm
@@ -20,45 +20,6 @@ def get_args():
     return args
 
 
-def evaluate_semantic_segmantation_accuracy(matches):
-    seg_gt_list = []
-    seg_pred_list = []
-    for k, v in matches.items():
-        seg_gt_list.append(v['seg_gt'])
-        seg_pred_list.append(v['seg_pred'])
-    seg_gt_all = torch.cat(seg_gt_list, dim=0).cuda()
-    seg_pred_all = torch.cat(seg_pred_list, dim=0).cuda()
-    assert seg_gt_all.shape == seg_pred_all.shape
-    correct = (seg_gt_all[seg_gt_all != -100] == seg_pred_all[seg_gt_all != -100]).sum()
-    whole = (seg_gt_all != -100).sum()
-    seg_accuracy = correct.float() / whole.float()
-    return seg_accuracy
-
-
-def evaluate_semantic_segmantation_miou(matches):
-    seg_gt_list = []
-    seg_pred_list = []
-    for k, v in matches.items():
-        seg_gt_list.append(v['seg_gt'])
-        seg_pred_list.append(v['seg_pred'])
-    seg_gt_all = torch.cat(seg_gt_list, dim=0).cuda()
-    seg_pred_all = torch.cat(seg_pred_list, dim=0).cuda()
-    pos_inds = seg_gt_all != -100
-    seg_gt_all = seg_gt_all[pos_inds]
-    seg_pred_all = seg_pred_all[pos_inds]
-    assert seg_gt_all.shape == seg_pred_all.shape
-    iou_list = []
-    for _index in seg_gt_all.unique():
-        if _index != -100:
-            intersection = ((seg_gt_all == _index) & (seg_pred_all == _index)).sum()
-            union = ((seg_gt_all == _index) | (seg_pred_all == _index)).sum()
-            iou = intersection.float() / union
-            iou_list.append(iou)
-    iou_tensor = torch.tensor(iou_list)
-    miou = iou_tensor.mean()
-    return miou
-
-
 if __name__ == '__main__':
     torch.backends.cudnn.enabled = False  # TODO remove this
     test_seed = 567
@@ -79,12 +40,20 @@ if __name__ == '__main__':
     dataset = build_dataset(cfg.data.test, logger)
     dataloader = build_dataloader(dataset, training=False, **cfg.dataloader.test)
 
-    all_preds, all_gts = [], []
+    all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
     with torch.no_grad():
         model = model.eval()
         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            ret = model(batch)
-            all_preds.append(ret['det_ins'])
-            all_gts.append(ret['gt_ins'])
-    scannet_eval = ScanNetEval(dataset.CLASSES)
-    scannet_eval.evaluate(all_preds, all_gts)
+            all_sem_preds.append(ret['semantic_preds'])
+            all_sem_labels.append(ret['semantic_labels'])
+            if not cfg.model.semantic_only:
+                all_pred_insts.append(ret['pred_instances'])
+                all_gt_insts.append(ret['gt_instances'])
+    if not cfg.model.semantic_only:
+        logger.info('Evaluate instance segmentation')
+        scannet_eval = ScanNetEval(dataset.CLASSES)
+        scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+    logger.info('Evaluate semantic segmentation')
+    evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+    evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
diff --git a/train.py b/train.py
index 6931231..73cb83f 100644
--- a/train.py
+++ b/train.py
@@ -12,10 +12,13 @@ import torch
 import yaml
 from munch import Munch
 from softgroup.data import build_dataloader, build_dataset
+from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
 from softgroup.model import SoftGroup
 from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
-                            get_max_memory, get_root_logger, load_checkpoint)
+                            get_max_memory, get_root_logger, is_multiple, is_power2,
+                            load_checkpoint)
 from tensorboardX import SummaryWriter
+from tqdm import tqdm
 
 
 def eval_epoch(val_loader, model, model_fn, epoch):
@@ -149,3 +152,29 @@ if __name__ == '__main__':
             log_str += f', {k}: {v.val:.4f}'
             logger.info(log_str)
         checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
+
+        # validation
+        if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
+            continue
+        all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
+        logger.info('Validation')
+        with torch.no_grad():
+            model = model.eval()
+            for batch in tqdm(val_loader, total=len(val_loader)):
+                ret = model(batch)
+                all_sem_preds.append(ret['semantic_preds'])
+                all_sem_labels.append(ret['semantic_labels'])
+                if not cfg.model.semantic_only:
+                    all_pred_insts.append(ret['pred_instances'])
+                    all_gt_insts.append(ret['gt_instances'])
+            if not cfg.model.semantic_only:
+                logger.info('Evaluate instance segmentation')
+                scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
+                scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+            logger.info('Evaluate semantic segmentation')
+            miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
+                                          logger)
+            acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
+                                        logger)
+            writer.add_scalar('mIoU', miou, epoch)
+            writer.add_scalar('Acc', acc, epoch)
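
Note: the new semantic metrics can be exercised standalone. Below is a minimal
sketch with synthetic inputs; the two dummy scans, the fixed seed, and the plain
`logging` logger are illustrative assumptions (not part of the patch), and the
`softgroup` package from this repo must be importable:

    import logging

    import numpy as np

    from softgroup.evaluation import evaluate_semantic_acc, evaluate_semantic_miou

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('semantic-eval-demo')

    # Two fake scans with 20 semantic classes; -100 marks ignored points,
    # mirroring ignore_label in the config above.
    rng = np.random.default_rng(0)
    all_sem_labels = [rng.integers(0, 20, size=1000) for _ in range(2)]
    all_sem_preds = [labels.copy() for labels in all_sem_labels]
    all_sem_preds[0][:100] = (all_sem_preds[0][:100] + 1) % 20  # corrupt 10% of scan 0
    all_sem_labels[1][:50] = -100  # these points are masked out by both metrics

    # Both helpers concatenate the per-scan arrays before scoring. They call
    # logger.info unconditionally, so a logger must be supplied even though the
    # signatures default to logger=None.
    miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, ignore_label=-100, logger=logger)
    acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, ignore_label=-100, logger=logger)

With this seed only scan 0 contributes errors, so Acc lands just below 95, while
the class-wise IoUs sit around 90: each corrupted point both leaves its true
class (shrinking the intersection) and pollutes a neighboring one (growing that
class's union).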