mirror of https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00

support evaluate during training

This commit is contained in:
parent f811c2f815
commit 068d67b1d4
configs/softgroup_scannet_backbone.yaml (new file, 72 lines)
@@ -0,0 +1,72 @@
+model:
+  channels: 32
+  num_blocks: 7
+  semantic_classes: 20
+  instance_classes: 18
+  sem2ins_classes: []
+  semantic_only: True
+  ignore_label: -100
+  grouping_cfg:
+    score_thr: 0.2
+    radius: 0.04
+    mean_active: 300
+    class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
+                          8331., 3948., 3166., 5629., 11719.,
+                          1003., 3317., 4912., 10221., 3889.,
+                          4136., 2120., 945., 3967., 2589.]
+    npoint_thr: 0.05  # absolute if class_numpoint == -1, relative if class_numpoint != -1
+    ignore_classes: [0, 1]
+  instance_voxel_cfg:
+    scale: 50
+    spatial_shape: 20
+  train_cfg:
+    max_proposal_num: 200
+    pos_iou_thr: 0.5
+  test_cfg:
+    x4_split: False
+    cls_score_thr: 0.001
+    mask_score_thr: -0.5
+    min_npoint: 100
+  fixed_modules: []
+
+data:
+  train:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'train'
+    suffix: '_inst_nostuff.pth'
+    training: True
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+  test:
+    type: 'scannetv2'
+    data_root: 'dataset/scannetv2'
+    prefix: 'val'
+    suffix: '_inst_nostuff.pth'
+    training: False
+    voxel_cfg:
+      scale: 50
+      spatial_shape: [128, 512]
+      max_npoint: 250000
+      min_npoint: 5000
+
+dataloader:
+  train:
+    batch_size: 4
+    num_workers: 4
+  test:
+    batch_size: 1
+    num_workers: 1
+
+optimizer:
+  type: 'Adam'
+  lr: 0.001
+
+epochs: 512
+step_epoch: 200
+save_freq: 16
+pretrain: ''
+work_dir: 'work_dirs/softgroup_scannet_backbone'
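Both test.py and train.py (diffs further down) read this file with yaml and wrap it in a Munch for attribute access; a minimal loading sketch, assuming the path above:

    # Sketch of how the scripts consume this config (yaml and Munch are
    # both imported in the test.py / train.py diffs below).
    import yaml
    from munch import Munch

    with open('configs/softgroup_scannet_backbone.yaml') as f:
        cfg = Munch.fromDict(yaml.safe_load(f))

    print(cfg.model.semantic_only)          # True -> backbone-only training
    print(cfg.dataloader.train.batch_size)  # 4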
softgroup/evaluation/__init__.py

@@ -1,3 +1,4 @@
-from .evaluate_semantic_instance import ScanNetEval
+from .instance_eval import ScanNetEval
+from .semantic_eval import evaluate_semantic_acc, evaluate_semantic_miou
 
-__all__ = ['ScanNetEval']
+__all__ = ['ScanNetEval', 'evaluate_semantic_acc', 'evaluate_semantic_miou']
softgroup/evaluation/instance_eval.py

@@ -6,7 +6,7 @@ from copy import deepcopy
 import numpy as np
 from tqdm import tqdm
 
-from .util_3d import get_instances
+from .instance_eval_util import get_instances
 
 
 class ScanNetEval(object):
softgroup/evaluation/semantic_eval.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import numpy as np
+
+
+def evaluate_semantic_acc(pred_list, gt_list, ignore_label=-100, logger=None):
+    gt = np.concatenate(gt_list, axis=0)
+    pred = np.concatenate(pred_list, axis=0)
+    assert gt.shape == pred.shape
+    correct = (gt[gt != ignore_label] == pred[gt != ignore_label]).sum()
+    whole = (gt != ignore_label).sum()
+    acc = correct.astype(float) / whole * 100
+    logger.info(f'Acc: {acc:.1f}')
+    return acc
+
+
+def evaluate_semantic_miou(pred_list, gt_list, ignore_label=-100, logger=None):
+    gt = np.concatenate(gt_list, axis=0)
+    pred = np.concatenate(pred_list, axis=0)
+    pos_inds = gt != ignore_label
+    gt = gt[pos_inds]
+    pred = pred[pos_inds]
+    assert gt.shape == pred.shape
+    iou_list = []
+    for _index in np.unique(gt):
+        if _index != ignore_label:
+            intersection = ((gt == _index) & (pred == _index)).sum()
+            union = ((gt == _index) | (pred == _index)).sum()
+            iou = intersection.astype(float) / union * 100
+            iou_list.append(iou)
+    miou = np.mean(iou_list)
+    logger.info('Class-wise mIoU: ' + ' '.join(f'{x:.1f}' for x in iou_list))
+    logger.info(f'mIoU: {miou:.1f}')
+    return miou
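A hedged usage sketch of the two new helpers (toy arrays; get_root_logger is the same helper test.py imports from softgroup.util). Note both functions call logger.info unconditionally, so despite the logger=None default a real logger must be passed:

    # Toy data only; shapes and ignore_label follow the functions above.
    import numpy as np
    from softgroup.evaluation import evaluate_semantic_acc, evaluate_semantic_miou
    from softgroup.util import get_root_logger

    logger = get_root_logger()
    preds = [np.array([0, 1, 2, 2]), np.array([1, 1])]
    labels = [np.array([0, 1, 2, -100]), np.array([1, 2])]
    evaluate_semantic_acc(preds, labels, ignore_label=-100, logger=logger)   # logs Acc
    evaluate_semantic_miou(preds, labels, ignore_label=-100, logger=logger)  # logs mIoU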
softgroup/model/blocks.py

@@ -6,6 +6,26 @@ from spconv.modules import SparseModule
 from torch import nn
 
 
+class MLP(nn.Sequential):
+
+    def __init__(self, in_channels, out_channels, norm_fn, num_layers=2):
+        modules = []
+        for _ in range(num_layers - 1):
+            modules.extend(
+                [nn.Linear(in_channels, in_channels, bias=False),
+                 norm_fn(in_channels),
+                 nn.ReLU()])
+        modules.append(nn.Linear(in_channels, out_channels))
+        return super().__init__(*modules)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+        nn.init.normal_(self[-1].weight, 0, 0.01)
+        nn.init.constant_(self[-1].bias, 0)
+
+
 class ResidualBlock(SparseModule):
 
     def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
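For intuition, MLP(channels, semantic_classes, norm_fn, num_layers=2) with channels=32 and semantic_classes=20 unrolls to the stack below; treating norm_fn as a plain BatchNorm1d factory is an assumption, since the actual partial is defined in softgroup.py, outside this diff:

    import torch
    from torch import nn

    head = nn.Sequential(
        nn.Linear(32, 32, bias=False),  # hidden layer; bias folded into the norm
        nn.BatchNorm1d(32),             # stand-in for norm_fn(32) (assumption)
        nn.ReLU(),
        nn.Linear(32, 20))              # final projection keeps its bias
    print(head(torch.randn(8, 32)).shape)  # torch.Size([8, 20])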
softgroup/model/softgroup.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from ..lib.softgroup_ops import (ballquery_batch_p, bfs_cluster, get_mask_iou_on_cluster,
                                  get_mask_iou_on_pred, get_mask_label, global_avg_pool, sec_max,
                                  sec_mean, sec_min, voxelization, voxelization_idx)
-from .blocks import ResidualBlock, UBlock
+from .blocks import MLP, ResidualBlock, UBlock
 
 
 class SoftGroup(nn.Module):
@@ -50,34 +50,19 @@ class SoftGroup(nn.Module):
         self.unet = UBlock(block_channels, norm_fn, 2, block, indice_key_id=1)
         self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
 
-        # semantic segmentation branch
-        self.semantic_linear = nn.Sequential(
-            nn.Linear(channels, channels, bias=True), norm_fn(channels), nn.ReLU(),
-            nn.Linear(channels, semantic_classes))
-
-        # center shift vector branch
-        self.offset_linear = nn.Sequential(
-            nn.Linear(channels, channels, bias=True), norm_fn(channels), nn.ReLU(),
-            nn.Linear(channels, 3, bias=True))
+        # point-wise prediction
+        self.semantic_linear = MLP(channels, semantic_classes, norm_fn, num_layers=2)
+        self.offset_linear = MLP(channels, 3, norm_fn, num_layers=2)
 
         # topdown refinement path
         if not semantic_only:
-            self.intra_ins_unet = UBlock([channels, 2 * channels],
-                                         norm_fn,
-                                         2,
-                                         block,
-                                         indice_key_id=11)
-            self.intra_ins_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
-            self.cls_linear = nn.Linear(channels, instance_classes + 1)
-            self.mask_linear = nn.Sequential(
-                nn.Linear(channels, channels), nn.ReLU(), nn.Linear(channels, instance_classes + 1))
-            # TODO renamve score_linear to iou_score_linear
-            self.score_linear = nn.Linear(channels, instance_classes + 1)
+            self.tiny_unet = UBlock([channels, 2 * channels], norm_fn, 2, block, indice_key_id=11)
+            self.tiny_unet_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
+            self.cls_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
+            self.mask_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
+            self.iou_score_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
 
-            nn.init.normal_(self.score_linear.weight, 0, 0.01)
-            nn.init.constant_(self.score_linear.bias, 0)
-
-        self.apply(self.set_bn_init)
+        self.init_weights()
 
         for mod in fixed_modules:
             mod = getattr(self, mod)
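A rough size check on the switch from single Linear heads to two-layer MLP heads, using channels=32 and instance_classes=18 from the config above (the hidden norm is again assumed to be BatchNorm1d):

    from torch import nn

    old_cls = nn.Linear(32, 19)  # instance_classes + 1
    print(sum(p.numel() for p in old_cls.parameters()))  # 627
    # New head: Linear(32,32,bias=False) + BatchNorm1d(32) + Linear(32,19)
    print(32 * 32 + 2 * 32 + (32 * 19 + 19))             # 1715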
@@ -85,15 +70,13 @@ class SoftGroup(nn.Module):
             for param in mod.parameters():
                 param.requires_grad = False
 
-    @staticmethod
-    def set_bn_init(m):
-        classname = m.__class__.__name__
-        if classname.find('BatchNorm') != -1:
-            m.weight.data.fill_(1.0)
-            m.bias.data.fill_(0.0)
-
     def init_weights(self):
-        pass
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, MLP):
+                m.init_weights()
 
     def forward(self, batch, return_loss=False):
         if return_loss:
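The freeze loop kept as context above works as in this standalone sketch; 'semantic_linear' is a hypothetical fixed_modules entry, not a value taken from this diff:

    # Each name in fixed_modules is looked up as an attribute of the model
    # and its parameters stop receiving gradients.
    def freeze(model, fixed_modules=('semantic_linear',)):  # hypothetical value
        for name in fixed_modules:
            mod = getattr(model, name)
            for param in mod.parameters():
                param.requires_grad = False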
@@ -235,17 +218,20 @@ class SoftGroup(nn.Module):
         input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
         semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
             input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
-        proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
-                                                                batch_idxs, coords_float,
-                                                                self.grouping_cfg)
-        scores_batch_idxs, cls_scores, iou_scores, mask_scores = self.forward_instance(
-            proposals_idx, proposals_offset, output_feats, coords_float)
-        pred_instances = self.get_instances(batch['scan_ids'][0], proposals_idx, semantic_scores,
-                                            cls_scores, iou_scores, mask_scores)
-        gt_instances = self.get_gt_instances(labels, instance_labels)
-        ret = {}
-        ret['det_ins'] = pred_instances
-        ret['gt_ins'] = gt_instances
+        semantic_preds = semantic_scores.max(1)[1]
+        ret = dict(
+            semantic_preds=semantic_preds.cpu().numpy(), semantic_labels=labels.cpu().numpy())
+        if not self.semantic_only:
+            proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
+                                                                    batch_idxs, coords_float,
+                                                                    self.grouping_cfg)
+            scores_batch_idxs, cls_scores, iou_scores, mask_scores = self.forward_instance(
+                proposals_idx, proposals_offset, output_feats, coords_float)
+            pred_instances = self.get_instances(batch['scan_ids'][0], proposals_idx,
+                                                semantic_scores, cls_scores, iou_scores,
+                                                mask_scores)
+            gt_instances = self.get_gt_instances(labels, instance_labels)
+            ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances))
         return ret
 
     def forward_backbone(self, input, input_map, coords, x4_split=False):
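Aside on the new semantic_preds line: scores.max(1)[1] takes the per-point argmax over class scores and is equivalent to .argmax(1). Toy check:

    import torch

    scores = torch.tensor([[0.1, 2.0, 0.3],
                           [1.5, 0.2, 0.9]])
    print(scores.max(1)[1])                                # tensor([1, 0])
    assert torch.equal(scores.max(1)[1], scores.argmax(1))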
@@ -344,8 +330,8 @@ class SoftGroup(nn.Module):
         input_feats, inp_map = self.clusters_voxelization(proposals_idx, proposals_offset,
                                                           output_feats, coords_float,
                                                           **self.instance_voxel_cfg)
-        feats = self.intra_ins_unet(input_feats)
-        feats = self.intra_ins_outputlayer(feats)
+        feats = self.tiny_unet(input_feats)
+        feats = self.tiny_unet_outputlayer(feats)
 
         # predict mask scores
         mask_scores = self.mask_linear(feats.features)
@@ -355,7 +341,7 @@ class SoftGroup(nn.Module):
         # predict instance cls and iou scores
         feats = self.global_pool(feats)
         cls_scores = self.cls_linear(feats)
-        iou_scores = self.score_linear(feats)
+        iou_scores = self.iou_score_linear(feats)
 
         return instance_batch_idxs, cls_scores, iou_scores, mask_scores
test.py (59 lines changed)
@@ -6,7 +6,7 @@ import torch
 import yaml
 from munch import Munch
 from softgroup.data import build_dataloader, build_dataset
-from softgroup.evaluation import ScanNetEval
+from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
 from softgroup.model import SoftGroup
 from softgroup.util import get_root_logger, load_checkpoint
 from tqdm import tqdm
@@ -20,45 +20,6 @@ def get_args():
     return args
 
 
-def evaluate_semantic_segmantation_accuracy(matches):
-    seg_gt_list = []
-    seg_pred_list = []
-    for k, v in matches.items():
-        seg_gt_list.append(v['seg_gt'])
-        seg_pred_list.append(v['seg_pred'])
-    seg_gt_all = torch.cat(seg_gt_list, dim=0).cuda()
-    seg_pred_all = torch.cat(seg_pred_list, dim=0).cuda()
-    assert seg_gt_all.shape == seg_pred_all.shape
-    correct = (seg_gt_all[seg_gt_all != -100] == seg_pred_all[seg_gt_all != -100]).sum()
-    whole = (seg_gt_all != -100).sum()
-    seg_accuracy = correct.float() / whole.float()
-    return seg_accuracy
-
-
-def evaluate_semantic_segmantation_miou(matches):
-    seg_gt_list = []
-    seg_pred_list = []
-    for k, v in matches.items():
-        seg_gt_list.append(v['seg_gt'])
-        seg_pred_list.append(v['seg_pred'])
-    seg_gt_all = torch.cat(seg_gt_list, dim=0).cuda()
-    seg_pred_all = torch.cat(seg_pred_list, dim=0).cuda()
-    pos_inds = seg_gt_all != -100
-    seg_gt_all = seg_gt_all[pos_inds]
-    seg_pred_all = seg_pred_all[pos_inds]
-    assert seg_gt_all.shape == seg_pred_all.shape
-    iou_list = []
-    for _index in seg_gt_all.unique():
-        if _index != -100:
-            intersection = ((seg_gt_all == _index) & (seg_pred_all == _index)).sum()
-            union = ((seg_gt_all == _index) | (seg_pred_all == _index)).sum()
-            iou = intersection.float() / union
-            iou_list.append(iou)
-    iou_tensor = torch.tensor(iou_list)
-    miou = iou_tensor.mean()
-    return miou
-
-
 if __name__ == '__main__':
     torch.backends.cudnn.enabled = False  # TODO remove this
     test_seed = 567
@@ -79,12 +40,20 @@ if __name__ == '__main__':
 
     dataset = build_dataset(cfg.data.test, logger)
     dataloader = build_dataloader(dataset, training=False, **cfg.dataloader.test)
-    all_preds, all_gts = [], []
+    all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
     with torch.no_grad():
         model = model.eval()
         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
             ret = model(batch)
-            all_preds.append(ret['det_ins'])
-            all_gts.append(ret['gt_ins'])
-    scannet_eval = ScanNetEval(dataset.CLASSES)
-    scannet_eval.evaluate(all_preds, all_gts)
+            all_sem_preds.append(ret['semantic_preds'])
+            all_sem_labels.append(ret['semantic_labels'])
+            if not cfg.model.semantic_only:
+                all_pred_insts.append(ret['pred_instances'])
+                all_gt_insts.append(ret['gt_instances'])
+    if not cfg.model.semantic_only:
+        logger.info('Evaluate instance segmentation')
+        scannet_eval = ScanNetEval(dataset.CLASSES)
+        scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+    logger.info('Evaluate semantic segmentation')
+    evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
+    evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
train.py (31 lines changed)
@@ -12,10 +12,13 @@ import torch
 import yaml
 from munch import Munch
 from softgroup.data import build_dataloader, build_dataset
+from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
 from softgroup.model import SoftGroup
 from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
-                            get_max_memory, get_root_logger, load_checkpoint)
+                            get_max_memory, get_root_logger, is_multiple, is_power2,
+                            load_checkpoint)
 from tensorboardX import SummaryWriter
 from tqdm import tqdm
 
 
 def eval_epoch(val_loader, model, model_fn, epoch):
@@ -149,3 +152,29 @@ if __name__ == '__main__':
             log_str += f', {k}: {v.val:.4f}'
         logger.info(log_str)
         checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
+
+        # validation
+        if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
+            continue
+        all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
+        logger.info('Validation')
+        with torch.no_grad():
+            model = model.eval()
+            for batch in tqdm(val_loader, total=len(val_loader)):
+                ret = model(batch)
+                all_sem_preds.append(ret['semantic_preds'])
+                all_sem_labels.append(ret['semantic_labels'])
+                if not cfg.model.semantic_only:
+                    all_pred_insts.append(ret['pred_instances'])
+                    all_gt_insts.append(ret['gt_instances'])
+        if not cfg.model.semantic_only:
+            logger.info('Evaluate instance segmentation')
+            scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
+            scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+        logger.info('Evaluate semantic segmentation')
+        miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
+                                      logger)
+        acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
+                                    logger)
+        writer.add_scalar('mIoU', miou, epoch)
+        writer.add_scalar('Acc', acc, epoch)
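The new validation gate runs evaluation at save_freq multiples and at power-of-two epochs. Plausible implementations of the two helpers imported from softgroup.util (assumptions; their actual bodies are not part of this commit):

    def is_multiple(num, multiple):
        return num != 0 and num % multiple == 0

    def is_power2(num):
        return num != 0 and (num & (num - 1)) == 0

    # With save_freq=16 from the config above, validation would run at
    # epochs 1, 2, 4, 8, 16, 32, 48, ...
    print([e for e in range(1, 49) if is_multiple(e, 16) or is_power2(e)])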