# SoftGroup/softgroup/model/softgroup.py
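"""SoftGroup for 3D point cloud instance segmentation.

A sparse-convolutional U-Net backbone predicts point-wise semantic scores
and center offsets; a soft-grouping stage clusters shifted points into
proposals, which a small top-down U-Net then classifies, masks, and scores.
"""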
import functools
import spconv.pytorch as spconv
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..lib.softgroup_ops import (ballquery_batch_p, bfs_cluster, get_mask_iou_on_cluster,
get_mask_iou_on_pred, get_mask_label, global_avg_pool, sec_max,
sec_min, voxelization, voxelization_idx)
from ..util import force_fp32
from .blocks import MLP, ResidualBlock, UBlock
class SoftGroup(nn.Module):
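    """SoftGroup instance segmentation network.

    Args:
        channels (int): base channel width of the U-Net backbone.
        num_blocks (int): number of U-Net levels.
        semantic_only (bool): if True, skip the top-down instance branch.
        semantic_classes (int): number of semantic classes.
        instance_classes (int): number of foreground instance classes.
        sem2ins_classes (list[int]): classes whose semantic masks are used
            directly as instance predictions.
        ignore_label (int): label value ignored by the losses.
        grouping_cfg, instance_voxel_cfg, train_cfg, test_cfg: attribute-style
            config objects with the fields accessed in the methods below.
        fixed_modules (list[str]): names of submodules to freeze.
    """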
def __init__(self,
channels=32,
num_blocks=7,
semantic_only=False,
semantic_classes=20,
instance_classes=18,
sem2ins_classes=[],
ignore_label=-100,
grouping_cfg=None,
instance_voxel_cfg=None,
train_cfg=None,
test_cfg=None,
fixed_modules=[]):
super().__init__()
self.channels = channels
self.num_blocks = num_blocks
self.semantic_only = semantic_only
self.semantic_classes = semantic_classes
self.instance_classes = instance_classes
self.sem2ins_classes = sem2ins_classes
self.ignore_label = ignore_label
self.grouping_cfg = grouping_cfg
self.instance_voxel_cfg = instance_voxel_cfg
self.train_cfg = train_cfg
self.test_cfg = test_cfg
block = ResidualBlock
norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
        # backbone: sparse-conv U-Net over voxelized input; 6 input channels =
        # per-point features concatenated with xyz coords (see forward_train)
self.input_conv = spconv.SparseSequential(
spconv.SubMConv3d(
6, channels, kernel_size=3, padding=1, bias=False, indice_key='subm1'))
        block_channels = [channels * (i + 1) for i in range(num_blocks)]  # widths: channels, 2*channels, ...
self.unet = UBlock(block_channels, norm_fn, 2, block, indice_key_id=1)
self.output_layer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
# point-wise prediction
self.semantic_linear = MLP(channels, semantic_classes, norm_fn, num_layers=2)
self.offset_linear = MLP(channels, 3, norm_fn, num_layers=2)
        # top-down refinement path
if not semantic_only:
self.tiny_unet = UBlock([channels, 2 * channels], norm_fn, 2, block, indice_key_id=11)
self.tiny_unet_outputlayer = spconv.SparseSequential(norm_fn(channels), nn.ReLU())
self.cls_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
self.mask_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
self.iou_score_linear = MLP(channels, instance_classes + 1, norm_fn, num_layers=2)
self.init_weights()
for mod in fixed_modules:
mod = getattr(self, mod)
mod.eval()
for param in mod.parameters():
param.requires_grad = False
def init_weights(self):
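        """Set BatchNorm weights to 1, biases to 0, and initialize MLP heads."""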
for m in self.modules():
if isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, MLP):
m.init_weights()
def forward(self, batch, return_loss=False):
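        """Dispatch to forward_train or forward_test based on return_loss."""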
if return_loss:
return self.forward_train(**batch)
else:
return self.forward_test(**batch)
def forward_train(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
semantic_labels, instance_labels, instance_pointnum, instance_cls,
pt_offset_labels, spatial_shape, batch_size, **kwargs):
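        """Training pass: backbone plus point-wise losses, then (unless
        semantic_only) soft grouping and instance losses.

        Returns:
            tuple: total loss and a dict of (loss value, element count) pairs.
        """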
losses = {}
feats = torch.cat((feats, coords_float), 1)
voxel_feats = voxelization(feats, p2v_map)
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, v2p_map, coords_float)
        # point-wise losses
point_wise_loss = self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
instance_labels, pt_offset_labels)
losses.update(point_wise_loss)
# instance losses
if not self.semantic_only:
proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
batch_idxs, coords_float,
self.grouping_cfg)
if proposals_offset.shape[0] > self.train_cfg.max_proposal_num:
proposals_offset = proposals_offset[:self.train_cfg.max_proposal_num + 1]
proposals_idx = proposals_idx[:proposals_offset[-1]]
assert proposals_idx.shape[0] == proposals_offset[-1]
inst_feats, inst_map = self.clusters_voxelization(
proposals_idx,
proposals_offset,
output_feats,
coords_float,
rand_quantize=True,
**self.instance_voxel_cfg)
instance_batch_idxs, cls_scores, iou_scores, mask_scores = self.forward_instance(
inst_feats, inst_map)
instance_loss = self.instance_loss(cls_scores, mask_scores, iou_scores, proposals_idx,
proposals_offset, instance_labels, instance_pointnum,
instance_cls, instance_batch_idxs)
losses.update(instance_loss)
        # sum individual losses into the total
loss = sum(v[0] for v in losses.values())
losses['loss'] = (loss, batch_idxs.size(0))
return loss, losses
def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
pt_offset_labels):
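        """Semantic cross-entropy plus L1 offset loss on points with valid
        instance labels."""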
losses = {}
semantic_loss = F.cross_entropy(
semantic_scores, semantic_labels, ignore_index=self.ignore_label)
losses['semantic_loss'] = (semantic_loss, semantic_scores.size(0))
pos_inds = instance_labels != self.ignore_label
if pos_inds.sum() == 0:
            offset_loss = 0 * pt_offsets.sum()  # zero-valued, but keeps pt_offsets in the graph
else:
offset_loss = F.l1_loss(
pt_offsets[pos_inds], pt_offset_labels[pos_inds], reduction='sum') / pos_inds.sum()
losses['offset_loss'] = (offset_loss, pos_inds.sum())
return losses
@force_fp32(apply_to=('cls_scores', 'mask_scores', 'iou_scores'))
def instance_loss(self, cls_scores, mask_scores, iou_scores, proposals_idx, proposals_offset,
instance_labels, instance_pointnum, instance_cls, instance_batch_idxs):
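        """Proposal losses: classification (CE), per-point mask (weighted BCE),
        and IoU-score regression (MSE).

        Proposals are matched to GT instances by IoU on the clustered points;
        matches with IoU >= train_cfg.pos_iou_thr are positives, the rest are
        assigned the background class (index instance_classes).
        """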
losses = {}
proposals_idx = proposals_idx[:, 1].cuda()
proposals_offset = proposals_offset.cuda()
        # compute IoU between every clustered proposal and every GT instance
ious_on_cluster = get_mask_iou_on_cluster(proposals_idx, proposals_offset, instance_labels,
instance_pointnum)
# filter out background instances
fg_inds = (instance_cls != self.ignore_label)
fg_instance_cls = instance_cls[fg_inds]
fg_ious_on_cluster = ious_on_cluster[:, fg_inds]
        # proposals whose best IoU with a fg instance reaches pos_iou_thr are positives
max_iou, gt_inds = fg_ious_on_cluster.max(1)
pos_inds = max_iou >= self.train_cfg.pos_iou_thr
pos_gt_inds = gt_inds[pos_inds]
        # classification loss; detection convention: classes 0..K-1 are foreground, K is background
labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), self.instance_classes)
labels[pos_inds] = fg_instance_cls[pos_gt_inds]
cls_loss = F.cross_entropy(cls_scores, labels)
losses['cls_loss'] = (cls_loss, labels.size(0))
# compute mask loss
mask_cls_label = labels[instance_batch_idxs.long()]
slice_inds = torch.arange(
0, mask_cls_label.size(0), dtype=torch.long, device=mask_cls_label.device)
mask_scores_sigmoid_slice = mask_scores.sigmoid()[slice_inds, mask_cls_label]
mask_label = get_mask_label(proposals_idx, proposals_offset, instance_labels, instance_cls,
instance_pointnum, ious_on_cluster, self.train_cfg.pos_iou_thr)
mask_label_weight = (mask_label != -1).float()
        mask_label[mask_label == -1.] = 0.5  # placeholder; these entries get zero weight in the loss
mask_loss = F.binary_cross_entropy(
mask_scores_sigmoid_slice, mask_label, weight=mask_label_weight, reduction='sum')
mask_loss /= (mask_label_weight.sum() + 1)
losses['mask_loss'] = (mask_loss, mask_label_weight.sum())
# compute iou score loss
ious = get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels,
instance_pointnum, mask_scores_sigmoid_slice.detach())
fg_ious = ious[:, fg_inds]
gt_ious, _ = fg_ious.max(1)
slice_inds = torch.arange(0, labels.size(0), dtype=torch.long, device=labels.device)
iou_score_weight = (labels < self.instance_classes).float()
iou_score_slice = iou_scores[slice_inds, labels]
iou_score_loss = F.mse_loss(iou_score_slice, gt_ious, reduction='none')
iou_score_loss = (iou_score_loss * iou_score_weight).sum() / (iou_score_weight.sum() + 1)
losses['iou_score_loss'] = (iou_score_loss, iou_score_weight.sum())
return losses
def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
semantic_labels, instance_labels, spatial_shape, batch_size, scan_ids,
**kwargs):
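        """Inference pass; returns numpy semantic predictions and, unless
        semantic_only, predicted and GT instances for evaluation."""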
feats = torch.cat((feats, coords_float), 1)
voxel_feats = voxelization(feats, p2v_map)
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
semantic_preds = semantic_scores.max(1)[1]
ret = dict(
semantic_preds=semantic_preds.cpu().numpy(),
semantic_labels=semantic_labels.cpu().numpy())
if not self.semantic_only:
proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
batch_idxs, coords_float,
self.grouping_cfg)
inst_feats, inst_map = self.clusters_voxelization(proposals_idx, proposals_offset,
output_feats, coords_float,
**self.instance_voxel_cfg)
_, cls_scores, iou_scores, mask_scores = self.forward_instance(inst_feats, inst_map)
pred_instances = self.get_instances(scan_ids[0], proposals_idx, semantic_scores,
cls_scores, iou_scores, mask_scores)
gt_instances = self.get_gt_instances(semantic_labels, instance_labels)
ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances))
return ret
def forward_backbone(self, input, input_map, coords, x4_split=False):
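        """Run the backbone and map voxel features back to points via
        input_map; with x4_split the scene is forwarded in 4 parts (S3DIS)."""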
if x4_split:
output_feats = self.forward_4_parts(input, input_map)
output_feats = self.merge_4_parts(output_feats)
coords = self.merge_4_parts(coords)
else:
output = self.input_conv(input)
output = self.unet(output)
output = self.output_layer(output)
output_feats = output.features[input_map.long()]
semantic_scores = self.semantic_linear(output_feats)
pt_offsets = self.offset_linear(output_feats)
return semantic_scores, pt_offsets, output_feats, coords
def forward_4_parts(self, x, input_map):
"""Helper function for s3dis: devide and forward 4 parts of a scene."""
outs = []
for i in range(4):
inds = x.indices[:, 0] == i
feats = x.features[inds]
coords = x.indices[inds]
coords[:, 0] = 0
x_new = spconv.SparseConvTensor(
indices=coords, features=feats, spatial_shape=x.spatial_shape, batch_size=1)
out = self.input_conv(x_new)
out = self.unet(out)
out = self.output_layer(out)
outs.append(out.features)
outs = torch.cat(outs, dim=0)
return outs[input_map.long()]
def merge_4_parts(self, x):
"""Helper function for s3dis: take output of 4 parts and merge them."""
inds = torch.arange(x.size(0), device=x.device)
p1 = inds[::4]
p2 = inds[1::4]
p3 = inds[2::4]
p4 = inds[3::4]
ps = [p1, p2, p3, p4]
x_split = torch.split(x, [p.size(0) for p in ps])
x_new = torch.zeros_like(x)
for i, p in enumerate(ps):
x_new[p] = x_split[i]
return x_new
    @force_fp32(apply_to=('semantic_scores', 'pt_offsets'))
def forward_grouping(self,
semantic_scores,
pt_offsets,
batch_idxs,
coords_float,
grouping_cfg=None):
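        """Soft grouping: for each non-ignored class, shift points whose class
        score exceeds score_thr by their predicted offsets, then cluster them
        with ball query + BFS; per-class proposals are concatenated."""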
proposals_idx_list = []
proposals_offset_list = []
batch_size = batch_idxs.max() + 1
semantic_scores = semantic_scores.softmax(dim=-1)
        grouping_cfg = grouping_cfg or self.grouping_cfg
        radius = grouping_cfg.radius
        mean_active = grouping_cfg.mean_active
        npoint_thr = grouping_cfg.npoint_thr
        class_numpoint_mean = torch.tensor(
            grouping_cfg.class_numpoint_mean, dtype=torch.float32)
        for class_id in range(self.semantic_classes):
            if class_id in grouping_cfg.ignore_classes:
                continue
            scores = semantic_scores[:, class_id].contiguous()
            object_idxs = (scores > grouping_cfg.score_thr).nonzero().view(-1)
if object_idxs.size(0) < self.test_cfg.min_npoint:
continue
batch_idxs_ = batch_idxs[object_idxs]
batch_offsets_ = self.get_batch_offsets(batch_idxs_, batch_size)
coords_ = coords_float[object_idxs]
pt_offsets_ = pt_offsets[object_idxs]
idx, start_len = ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
radius, mean_active)
proposals_idx, proposals_offset = bfs_cluster(class_numpoint_mean, idx.cpu(),
start_len.cpu(), npoint_thr, class_id)
proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
# merge proposals
if len(proposals_offset_list) > 0:
proposals_idx[:, 0] += sum([x.size(0) for x in proposals_offset_list]) - 1
proposals_offset += proposals_offset_list[-1][-1]
proposals_offset = proposals_offset[1:]
if proposals_idx.size(0) > 0:
proposals_idx_list.append(proposals_idx)
proposals_offset_list.append(proposals_offset)
proposals_idx = torch.cat(proposals_idx_list, dim=0)
proposals_offset = torch.cat(proposals_offset_list)
return proposals_idx, proposals_offset
def forward_instance(self, inst_feats, inst_map):
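        """Top-down refinement: a tiny U-Net over proposal voxels yields
        per-point mask scores and, after global pooling, per-proposal
        classification and IoU scores."""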
feats = self.tiny_unet(inst_feats)
feats = self.tiny_unet_outputlayer(feats)
# predict mask scores
mask_scores = self.mask_linear(feats.features)
mask_scores = mask_scores[inst_map.long()]
instance_batch_idxs = feats.indices[:, 0][inst_map.long()]
# predict instance cls and iou scores
feats = self.global_pool(feats)
cls_scores = self.cls_linear(feats)
iou_scores = self.iou_score_linear(feats)
return instance_batch_idxs, cls_scores, iou_scores, mask_scores
@force_fp32(apply_to=('semantic_scores', 'cls_scores', 'iou_scores', 'mask_scores'))
def get_instances(self, scan_id, proposals_idx, semantic_scores, cls_scores, iou_scores,
mask_scores):
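        """Convert scores to per-scan instance predictions. For sem2ins_classes
        the semantic mask is used directly; otherwise masks are thresholded by
        mask_score_thr and filtered by cls_score_thr and min_npoint."""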
num_instances = cls_scores.size(0)
num_points = semantic_scores.size(0)
cls_scores = cls_scores.softmax(1)
semantic_pred = semantic_scores.max(1)[1]
cls_pred_list, score_pred_list, mask_pred_list = [], [], []
for i in range(self.instance_classes):
if i in self.sem2ins_classes:
cls_pred = cls_scores.new_tensor([i + 1], dtype=torch.long)
score_pred = cls_scores.new_tensor([1.], dtype=torch.float32)
mask_pred = (semantic_pred == i)[None, :].int()
else:
cls_pred = cls_scores.new_full((num_instances, ), i + 1, dtype=torch.long)
cur_cls_scores = cls_scores[:, i]
cur_iou_scores = iou_scores[:, i]
cur_mask_scores = mask_scores[:, i]
score_pred = cur_cls_scores * cur_iou_scores.clamp(0, 1)
mask_pred = torch.zeros((num_instances, num_points), dtype=torch.int, device='cuda')
mask_inds = cur_mask_scores > self.test_cfg.mask_score_thr
cur_proposals_idx = proposals_idx[mask_inds].long()
mask_pred[cur_proposals_idx[:, 0], cur_proposals_idx[:, 1]] = 1
# filter low score instance
inds = cur_cls_scores > self.test_cfg.cls_score_thr
cls_pred = cls_pred[inds]
score_pred = score_pred[inds]
mask_pred = mask_pred[inds]
# filter too small instances
npoint = mask_pred.sum(1)
inds = npoint >= self.test_cfg.min_npoint
cls_pred = cls_pred[inds]
score_pred = score_pred[inds]
mask_pred = mask_pred[inds]
cls_pred_list.append(cls_pred)
score_pred_list.append(score_pred)
mask_pred_list.append(mask_pred)
cls_pred = torch.cat(cls_pred_list).cpu().numpy()
score_pred = torch.cat(score_pred_list).cpu().numpy()
mask_pred = torch.cat(mask_pred_list).cpu().numpy()
instances = []
for i in range(cls_pred.shape[0]):
pred = {}
pred['scan_id'] = scan_id
pred['label_id'] = cls_pred[i]
pred['conf'] = score_pred[i]
pred['pred_mask'] = mask_pred[i]
instances.append(pred)
return instances
def get_gt_instances(self, semantic_labels, instance_labels):
"""Get gt instances for evaluation."""
# convert to evaluation format 0: ignore, 1->N: valid
label_shift = self.semantic_classes - self.instance_classes
semantic_labels = semantic_labels - label_shift + 1
semantic_labels[semantic_labels < 0] = 0
instance_labels += 1
ignore_inds = instance_labels < 0
# scannet encoding rule
gt_ins = semantic_labels * 1000 + instance_labels
gt_ins[ignore_inds] = 0
gt_ins = gt_ins.cpu().numpy()
return gt_ins
    @force_fp32(apply_to=('feats', ))
def clusters_voxelization(self,
clusters_idx,
clusters_offset,
feats,
coords,
scale,
spatial_shape,
rand_quantize=False):
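        """Voxelize each proposal into its own grid (clusters act as batch
        entries of the returned sparse tensor); rand_quantize jitters the
        origin so quantization adds slight randomness during training."""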
batch_idx = clusters_idx[:, 0].cuda().long()
c_idxs = clusters_idx[:, 1].cuda()
feats = feats[c_idxs.long()]
coords = coords[c_idxs.long()]
coords_min = sec_min(coords, clusters_offset.cuda())
coords_max = sec_max(coords, clusters_offset.cuda())
        # subtract 0.01 so scaled coords stay strictly below spatial_shape
clusters_scale = 1 / ((coords_max - coords_min) / spatial_shape).max(1)[0] - 0.01
clusters_scale = torch.clamp(clusters_scale, min=None, max=scale)
coords_min = coords_min * clusters_scale[:, None]
coords_max = coords_max * clusters_scale[:, None]
clusters_scale = clusters_scale[batch_idx]
coords = coords * clusters_scale[:, None]
if rand_quantize:
# after this, coords.long() will have some randomness
            coords_range = coords_max - coords_min
            coords_min -= torch.clamp(
                spatial_shape - coords_range - 0.001, min=0) * torch.rand(3).cuda()
            coords_min -= torch.clamp(
                spatial_shape - coords_range + 0.001, max=0) * torch.rand(3).cuda()
coords_min = coords_min[batch_idx]
coords -= coords_min
assert coords.shape.numel() == ((coords >= 0) * (coords < spatial_shape)).sum()
coords = coords.long()
coords = torch.cat([clusters_idx[:, 0].view(-1, 1).long(), coords.cpu()], 1)
out_coords, inp_map, out_map = voxelization_idx(coords, int(clusters_idx[-1, 0]) + 1)
out_feats = voxelization(feats, out_map.cuda())
spatial_shape = [spatial_shape] * 3
voxelization_feats = spconv.SparseConvTensor(out_feats,
out_coords.int().cuda(), spatial_shape,
int(clusters_idx[-1, 0]) + 1)
return voxelization_feats, inp_map
def get_batch_offsets(self, batch_idxs, bs):
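        """Return cumulative point counts per batch element (length bs + 1)."""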
batch_offsets = torch.zeros(bs + 1).int().cuda()
for i in range(bs):
batch_offsets[i + 1] = batch_offsets[i] + (batch_idxs == i).sum()
assert batch_offsets[-1] == batch_idxs.shape[0]
return batch_offsets
    @force_fp32(apply_to=('x', ))
def global_pool(self, x, expand=False):
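        """Average-pool features over each batch entry (proposal); with
        expand=True the pooled vector is concatenated back onto every voxel
        feature."""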
indices = x.indices[:, 0]
batch_counts = torch.bincount(indices)
batch_offset = torch.cumsum(batch_counts, dim=0)
pad = batch_offset.new_full((1, ), 0)
batch_offset = torch.cat([pad, batch_offset]).int()
x_pool = global_avg_pool(x.features, batch_offset)
if not expand:
return x_pool
x_pool_expand = x_pool[indices.long()]
x.features = torch.cat((x.features, x_pool_expand), dim=1)
return x
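
# Minimal usage sketch (illustrative only). `batch` is assumed to be a dict
# supplying the keyword arguments consumed by forward_train / forward_test,
# and any *_cfg arguments are attribute-style configs as used above:
#
#   model = SoftGroup(channels=32, num_blocks=7, semantic_only=True).cuda()
#   loss, losses = model(batch, return_loss=True)   # training step
#   ret = model(batch, return_loss=False)           # inference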