mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
merge scannet & s3dis training
This commit is contained in:
parent
d52fb85fdd
commit
b6d155fc56
@ -30,6 +30,7 @@ STRUCTURE:
|
||||
block_residual: True
|
||||
block_reps: 2
|
||||
use_coords: True
|
||||
semantic_only: False
|
||||
|
||||
TRAIN:
|
||||
epochs: 500
|
||||
@ -40,20 +41,20 @@ TRAIN:
|
||||
multiplier: 0.5
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001
|
||||
save_freq: 4 # also eval_freq
|
||||
save_freq: 16 # also eval_freq
|
||||
loss_weight: [1.0, 1.0, 1.0, 1.0, 1.0] # semantic_loss, offset_norm_loss, cls_loss, mask_loss, score_loss
|
||||
fg_thresh: 1.
|
||||
bg_thresh: 0.
|
||||
score_scale: 50 # the minimal voxel size is 2cm
|
||||
score_fullscale: 20
|
||||
score_mode: 4 # mean
|
||||
pretrain_path:
|
||||
pretrain_module: []
|
||||
fix_module: []
|
||||
pretrain_path: 'hais_ckpt.pth'
|
||||
pretrain_module: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear', 'intra_ins_unet', 'intra_ins_outputlayer']
|
||||
fix_module: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
|
||||
|
||||
point_aggr_radius: 0.04
|
||||
cluster_shift_meanActive: 300
|
||||
prepare_epochs: 100
|
||||
prepare_epochs: -1
|
||||
using_set_aggr_in_training: False
|
||||
using_set_aggr_in_testing: False
|
||||
max_proposal_num: 200
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
GENERAL:
|
||||
task: train # train, test
|
||||
manual_seed: 456
|
||||
model_dir: model/hais/hais.py
|
||||
model_dir: model/softgroup/softgroup.py
|
||||
dataset_dir: data/scannetv2_inst.py
|
||||
|
||||
DATA:
|
||||
@ -26,7 +26,7 @@ DATA:
|
||||
mode: 4 # 4=mean
|
||||
|
||||
STRUCTURE:
|
||||
model_name: hais
|
||||
model_name: softgroup
|
||||
width: 32
|
||||
block_residual: True
|
||||
block_reps: 2
|
||||
@ -49,7 +49,7 @@ TRAIN:
|
||||
score_scale: 50 # the minimal voxel size is 2cm
|
||||
score_fullscale: 20
|
||||
score_mode: 4 # mean
|
||||
pretrain_path: # './exp/s3dis/hais/hais_fold5_s3dis/hais_fold5_s3dis-000000030.pth'
|
||||
pretrain_path: './exp/s3dis/softgroup/softgroup_fold5_s3dis/softgroup_fold5_s3dis-000000030.pth'
|
||||
pretrain_module: ['input_conv', 'unet', 'output_layer',
|
||||
'semantic_linear', 'offset_linear',
|
||||
'intra_ins_unet', 'intra_ins_outputlayer',
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
GENERAL:
|
||||
task: train # train, test
|
||||
manual_seed: 123
|
||||
model_dir: model/hais/hais.py
|
||||
model_dir: model/softgroup/softgroup.py
|
||||
dataset_dir: data/scannetv2_inst.py
|
||||
|
||||
DATA:
|
||||
@ -12,6 +12,7 @@ DATA:
|
||||
test_area: 'Area_5'
|
||||
train_repeats: 5
|
||||
|
||||
semantic_classes: 13
|
||||
classes: 13
|
||||
ignore_label: -100
|
||||
|
||||
@ -23,7 +24,7 @@ DATA:
|
||||
mode: 4 # 4=mean
|
||||
|
||||
STRUCTURE:
|
||||
model_name: hais
|
||||
model_name: softgroup
|
||||
width: 32
|
||||
block_residual: True
|
||||
block_reps: 2
|
||||
@ -46,7 +47,7 @@ TRAIN:
|
||||
score_scale: 50 # the minimal voxel size is 2cm
|
||||
score_fullscale: 20
|
||||
score_mode: 4 # mean
|
||||
pretrain_path: 'hais_ckpt.pth'
|
||||
pretrain_path: 'softgroup_ckpt.pth'
|
||||
pretrain_module: ['input_conv', 'unet', 'output_layer']
|
||||
fix_module: []
|
||||
|
||||
@ -69,6 +70,7 @@ TRAIN:
|
||||
max_clusters: 100
|
||||
|
||||
iou_thr: 0.5
|
||||
score_thr: 0.2
|
||||
|
||||
TEST:
|
||||
split: val
|
||||
|
||||
@ -258,7 +258,6 @@ class Dataset:
|
||||
# crop
|
||||
xyz, valid_idxs = self.crop(xyz)
|
||||
if valid_idxs.sum() == 0: # handle some corner cases
|
||||
print('bad epoch')
|
||||
continue
|
||||
|
||||
xyz_middle = xyz_middle[valid_idxs]
|
||||
|
||||
@ -146,9 +146,10 @@ class Dataset:
|
||||
# instance_pointnum
|
||||
instance_pointnum.append(inst_idx_i[0].size)
|
||||
cls_loc = inst_idx_i[0][0]
|
||||
instance_cls.append(label[cls_loc])
|
||||
assert (0 not in instance_cls) and (1 not in instance_cls) # sanity check stuff cls
|
||||
|
||||
# ignore 2 first classes (floor, ceil)
|
||||
cls = label[cls_loc] - 2 if label[cls_loc] != -100 else label[cls_loc]
|
||||
instance_cls.append(cls)
|
||||
return instance_num, {"instance_info": instance_info, "instance_pointnum": instance_pointnum,
|
||||
"instance_cls": instance_cls}
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ class UBlock(nn.Module):
|
||||
return output
|
||||
|
||||
class SoftGroup(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, pretrained=True):
|
||||
super().__init__()
|
||||
|
||||
input_c = cfg.input_channel
|
||||
@ -183,14 +183,15 @@ class SoftGroup(nn.Module):
|
||||
module_map = {'input_conv': self.input_conv, 'unet': self.unet, 'output_layer': self.output_layer,
|
||||
'semantic_linear': self.semantic_linear, 'offset_linear': self.offset_linear,
|
||||
'intra_ins_unet': self.intra_ins_unet, 'intra_ins_outputlayer': self.intra_ins_outputlayer,
|
||||
'score_linear': self.score_linear, 'mask_linear': self.mask_linear}
|
||||
'score_linear': self.score_linear, 'mask_linear': self.mask_linear,
|
||||
'cls_linear': self.cls_linear}
|
||||
for m in self.fix_module:
|
||||
mod = module_map[m]
|
||||
for param in mod.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# load pretrain weights
|
||||
if self.pretrain_path is not None:
|
||||
if pretrained and self.pretrain_path is not None:
|
||||
pretrain_dict = torch.load(self.pretrain_path)
|
||||
for m in self.pretrain_module:
|
||||
print("Load pretrained " + m + ": %d/%d" % utils.load_model_param(module_map[m], pretrain_dict['net'], prefix=m))
|
||||
@ -298,21 +299,7 @@ class SoftGroup(nn.Module):
|
||||
x_new[p] = x_split[i]
|
||||
return x_new
|
||||
|
||||
def forward_point_wise_network(self, input, input_map):
|
||||
if self.cfg.dataset == 'scannetv2':
|
||||
output = self.input_conv(input)
|
||||
output = self.unet(output)
|
||||
output = self.output_layer(output)
|
||||
output_feats = output.features[input_map.long()]
|
||||
elif self.cfg.dataset == 's3dis':
|
||||
output_feats = self.forward_chop(input, input_map)
|
||||
output_feats = self.rearange(output_feats)
|
||||
|
||||
semantic_scores = self.semantic_linear(output_feats) # (N, nClass), float
|
||||
semantic_preds = semantic_scores.max(1)[1] # (N), long
|
||||
pt_offsets = self.offset_linear(output_feats) # (N, 3)
|
||||
|
||||
def forward(self, input, input_map, coords, batch_idxs, batch_offsets, epoch, training_mode, gt_instances=None, split=False):
|
||||
def forward(self, input, input_map, coords, batch_idxs, batch_offsets, epoch, training_mode, gt_instances=None, split=False, semantic_only=False):
|
||||
'''
|
||||
:param input_map: (N), int, cuda
|
||||
:param coords: (N, 3), float, cuda
|
||||
@ -338,7 +325,7 @@ class SoftGroup(nn.Module):
|
||||
ret['semantic_scores'] = semantic_scores
|
||||
ret['pt_offsets'] = pt_offsets
|
||||
|
||||
if(epoch > self.prepare_epochs):
|
||||
if(epoch > self.prepare_epochs) and not semantic_only:
|
||||
thr = self.cfg.score_thr
|
||||
semantic_scores = semantic_scores.softmax(dim=-1)
|
||||
proposals_idx_list = []
|
||||
@ -467,7 +454,7 @@ def model_fn_decorator(test=False):
|
||||
gt_instances = torch.cat(gt_instances)
|
||||
return gt_cls, gt_instances
|
||||
|
||||
def test_model_fn(batch, model, epoch):
|
||||
def test_model_fn(batch, model, epoch, semantic_only=False):
|
||||
coords = batch['locs'].cuda() # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
|
||||
voxel_coords = batch['voxel_locs'].cuda() # (M, 1 + 3), long, cuda
|
||||
p2v_map = batch['p2v_map'].cuda() # (N), int, cuda
|
||||
@ -486,15 +473,15 @@ def model_fn_decorator(test=False):
|
||||
if cfg.dataset == 'scannetv2':
|
||||
input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 1)
|
||||
|
||||
ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch, 'test')
|
||||
ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch, 'test', semantic_only=semantic_only)
|
||||
elif cfg.dataset == 's3dis':
|
||||
input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 4)
|
||||
batch_idxs = torch.zeros_like(coords[:, 0].int())
|
||||
ret = model(input_, p2v_map, coords_float, batch_idxs, batch_offsets, epoch, 'test', split=True)
|
||||
ret = model(input_, p2v_map, coords_float, batch_idxs, batch_offsets, epoch, 'test', split=True, semantic_only=semantic_only)
|
||||
semantic_scores = ret['semantic_scores'] # (N, nClass) float32, cuda
|
||||
pt_offsets = ret['pt_offsets'] # (N, 3), float32, cuda
|
||||
|
||||
if (epoch > cfg.prepare_epochs):
|
||||
if (epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
scores_batch_idxs, cls_scores, scores, proposals_idx, proposals_offset, mask_scores = ret['proposal_scores']
|
||||
|
||||
# preds
|
||||
@ -502,14 +489,14 @@ def model_fn_decorator(test=False):
|
||||
preds = {}
|
||||
preds['semantic'] = semantic_scores
|
||||
preds['pt_offsets'] = pt_offsets
|
||||
if (epoch > cfg.prepare_epochs):
|
||||
if (epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
preds['score'] = scores
|
||||
preds['cls_score'] = cls_scores
|
||||
preds['proposals'] = (scores_batch_idxs, proposals_idx, proposals_offset, mask_scores)
|
||||
|
||||
return preds
|
||||
|
||||
def model_fn(batch, model, epoch):
|
||||
def model_fn(batch, model, epoch, semantic_only=False):
|
||||
# batch {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
|
||||
# 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
|
||||
# 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
|
||||
@ -537,13 +524,11 @@ def model_fn_decorator(test=False):
|
||||
|
||||
input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, cfg.batch_size)
|
||||
|
||||
# gt_cls, gt_instances = get_gt_instances(labels, instance_labels)
|
||||
gt_instances = None
|
||||
ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch, 'train', gt_instances=gt_instances)
|
||||
ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch, 'train', semantic_only=semantic_only)
|
||||
semantic_scores = ret['semantic_scores'] # (N, nClass) float32, cuda
|
||||
pt_offsets = ret['pt_offsets'] # (N, 3), float32, cuda
|
||||
|
||||
if(epoch > cfg.prepare_epochs):
|
||||
if(epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
scores_batch_idxs, cls_scores, scores, proposals_idx, proposals_offset, mask_scores = ret['proposal_scores']
|
||||
# scores: (nProposal, 1) float, cuda
|
||||
# proposals_idx: (sumNPoint, 2), int, cpu, [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N
|
||||
@ -555,17 +540,17 @@ def model_fn_decorator(test=False):
|
||||
loss_inp['semantic_scores'] = (semantic_scores, labels)
|
||||
loss_inp['pt_offsets'] = (pt_offsets, coords_float, instance_info, instance_labels)
|
||||
|
||||
if(epoch > cfg.prepare_epochs):
|
||||
if(epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
loss_inp['proposal_scores'] = (scores_batch_idxs, cls_scores, scores, proposals_idx, proposals_offset, instance_pointnum, instance_cls, mask_scores)
|
||||
|
||||
loss, loss_out = loss_fn(loss_inp, epoch)
|
||||
loss, loss_out = loss_fn(loss_inp, epoch, semantic_only=semantic_only)
|
||||
|
||||
# accuracy / visual_dict / meter_dict
|
||||
with torch.no_grad():
|
||||
preds = {}
|
||||
preds['semantic'] = semantic_scores
|
||||
preds['pt_offsets'] = pt_offsets
|
||||
if(epoch > cfg.prepare_epochs):
|
||||
if(epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
preds['score'] = scores
|
||||
preds['proposals'] = (proposals_idx, proposals_offset)
|
||||
|
||||
@ -582,7 +567,7 @@ def model_fn_decorator(test=False):
|
||||
return loss, preds, visual_dict, meter_dict
|
||||
|
||||
|
||||
def loss_fn(loss_inp, epoch):
|
||||
def loss_fn(loss_inp, epoch, semantic_only=False):
|
||||
|
||||
loss_out = {}
|
||||
|
||||
@ -612,7 +597,7 @@ def model_fn_decorator(test=False):
|
||||
offset_norm_loss = torch.sum(pt_dist * valid) / (torch.sum(valid) + 1e-6)
|
||||
loss_out['offset_norm_loss'] = (offset_norm_loss, valid.sum())
|
||||
|
||||
if (epoch > cfg.prepare_epochs):
|
||||
if (epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
'''score and mask loss'''
|
||||
|
||||
scores_batch_idxs, cls_scores, scores, proposals_idx, proposals_offset, instance_pointnum, instance_cls, mask_scores = loss_inp['proposal_scores']
|
||||
@ -652,8 +637,8 @@ def model_fn_decorator(test=False):
|
||||
pos_gt_inds = gt_inds[pos_inds]
|
||||
|
||||
# compute cls loss. follow detection convention: 0 -> K - 1 are fg, K is bg
|
||||
labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), cfg.classes - 2) # ignore 2 first classes
|
||||
labels[pos_inds] = fg_instance_cls[pos_gt_inds] - 2
|
||||
labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), cfg.classes)
|
||||
labels[pos_inds] = fg_instance_cls[pos_gt_inds]
|
||||
cls_loss = F.cross_entropy(cls_scores, labels)
|
||||
loss_out['cls_loss'] = (cls_loss, labels.size(0))
|
||||
|
||||
@ -688,7 +673,7 @@ def model_fn_decorator(test=False):
|
||||
# gt_scores = get_segmented_scores(gt_ious, cfg.fg_thresh, cfg.bg_thresh)
|
||||
|
||||
slice_inds = torch.arange(0, labels.size(0), dtype=torch.long, device=labels.device)
|
||||
score_weight = (labels < cfg.classes - 2).float()
|
||||
score_weight = (labels < cfg.classes).float()
|
||||
score_slice = scores[slice_inds, labels]
|
||||
score_loss = F.mse_loss(score_slice, gt_ious, reduction='none')
|
||||
score_loss = (score_loss * score_weight).sum() / (score_weight.sum() + 1)
|
||||
@ -698,30 +683,13 @@ def model_fn_decorator(test=False):
|
||||
|
||||
'''total loss'''
|
||||
loss = cfg.loss_weight[0] * semantic_loss + cfg.loss_weight[1] * offset_norm_loss
|
||||
if(epoch > cfg.prepare_epochs):
|
||||
if(epoch > cfg.prepare_epochs) and not semantic_only:
|
||||
loss += (cfg.loss_weight[2] * cls_loss)
|
||||
loss += (cfg.loss_weight[3] * mask_loss)
|
||||
loss += (cfg.loss_weight[4] * score_loss)
|
||||
|
||||
return loss, loss_out
|
||||
|
||||
|
||||
def get_segmented_scores(scores, fg_thresh=1.0, bg_thresh=0.0):
|
||||
'''
|
||||
:param scores: (N), float, 0~1
|
||||
:return: segmented_scores: (N), float 0~1, >fg_thresh: 1, <bg_thresh: 0, mid: linear
|
||||
'''
|
||||
fg_mask = scores > fg_thresh
|
||||
bg_mask = scores < bg_thresh
|
||||
interval_mask = (fg_mask == 0) & (bg_mask == 0)
|
||||
|
||||
segmented_scores = (fg_mask > 0).float()
|
||||
k = 1 / (fg_thresh - bg_thresh + 1e-5)
|
||||
b = bg_thresh / (bg_thresh - fg_thresh + 1e-5)
|
||||
segmented_scores[interval_mask] = scores[interval_mask] * k + b
|
||||
|
||||
return segmented_scores
|
||||
|
||||
if test:
|
||||
fn = test_model_fn
|
||||
else:
|
||||
|
||||
2
test.py
2
test.py
@ -296,7 +296,7 @@ if __name__ == '__main__':
|
||||
else:
|
||||
print("Error: no model version " + model_name)
|
||||
exit(0)
|
||||
model = Network(cfg)
|
||||
model = Network(cfg, pretrained=False)
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
logger.info('cuda available: {}'.format(use_cuda))
|
||||
|
||||
@ -321,7 +321,7 @@ if __name__ == '__main__':
|
||||
else:
|
||||
print("Error: no model version " + model_name)
|
||||
exit(0)
|
||||
model = Network(cfg)
|
||||
model = Network(cfg, pretrained=False)
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
logger.info('cuda available: {}'.format(use_cuda))
|
||||
|
||||
9
train.py
9
train.py
@ -58,7 +58,7 @@ def train_epoch(train_loader, model, model_fn, optimizer, epoch):
|
||||
|
||||
|
||||
# prepare input and forward
|
||||
loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch)
|
||||
loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
|
||||
|
||||
# meter_dict
|
||||
for k, v in meter_dict.items():
|
||||
@ -114,7 +114,7 @@ def eval_epoch(val_loader, model, model_fn, epoch):
|
||||
for i, batch in enumerate(val_loader):
|
||||
|
||||
# prepare input and forward
|
||||
loss, preds, visual_dict, meter_dict = model_fn(batch, model, epoch)
|
||||
loss, preds, visual_dict, meter_dict = model_fn(batch, model, epoch, semantic_only=cfg.semantic_only)
|
||||
|
||||
for k, v in meter_dict.items():
|
||||
if k not in am_dict.keys():
|
||||
@ -189,6 +189,11 @@ if __name__ == '__main__':
|
||||
else:
|
||||
print("Error: no data loader - " + data_name)
|
||||
exit(0)
|
||||
elif cfg.dataset == 's3dis' and data_name == 's3dis':
|
||||
import data.s3dis_inst
|
||||
dataset = data.s3dis_inst.Dataset()
|
||||
dataset.trainLoader()
|
||||
dataset.valLoader()
|
||||
else:
|
||||
raise NotImplementedError("Not yet supported")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user