update variable names

Thang Vu 2022-04-08 03:19:49 +00:00
parent 70f2b7454f
commit f7a31c531e
9 changed files with 80 additions and 152 deletions

View File

@ -29,6 +29,7 @@ data:
prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
suffix: '_inst_nostuff.pth'
repeat: 5
training: True
voxel_cfg:
scale: 50
spatial_shape: [128, 512]
@ -39,6 +40,7 @@ data:
data_root: 'dataset/s3dis/preprocess'
prefix: 'Area_5'
suffix: '_inst_nostuff.pth'
training: False
voxel_cfg:
scale: 50
spatial_shape: [128, 512]
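The new `training` key lets one YAML file describe both splits, with the flag forwarded straight into the dataset constructor. A minimal sketch of that flow, assuming a Munch-style config as in test.py below and that each split dict carries exactly the constructor's keyword arguments; the CustomDataset import path is an assumption:

import yaml
from munch import Munch
from data.custom import CustomDataset  # module path assumed for illustration
from util import get_root_logger

cfg = Munch.fromDict(yaml.safe_load(open('config.yaml')))
logger = get_root_logger()
train_set = CustomDataset(**cfg.data.train, logger=logger)  # training: True
test_set = CustomDataset(**cfg.data.test, logger=logger)    # training: False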

View File

@ -17,6 +17,7 @@ def build_dataset(data_cfg, logger):
else:
raise ValueError(f'Unknown {data_type}')
def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
if training:
return DataLoader(
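The hunk is truncated at the DataLoader call. A plausible completion, assuming the builder mirrors the keyword arguments test.py used to hard-code (see that file's diff below), with shuffle and drop_last being the only train/test differences; the exact body is an assumption:

from torch.utils.data import DataLoader

def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
    # shuffle and drop incomplete batches only while training;
    # keep every scan, in order, at test time
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,
        shuffle=training,
        drop_last=training,
        pin_memory=True)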

View File

@ -17,7 +17,14 @@ class CustomDataset(Dataset):
CLASSES = None
def __init__(self, data_root, prefix, suffix, voxel_cfg=None, training=True, repeat=1, logger=None):
def __init__(self,
data_root,
prefix,
suffix,
voxel_cfg=None,
training=True,
repeat=1,
logger=None):
self.data_root = data_root
self.prefix = prefix
self.suffix = suffix
@ -25,13 +32,15 @@ class CustomDataset(Dataset):
self.training = training
self.repeat = repeat
self.logger = logger
self.mode = 'train' if training else 'test'
self.filenames = self.get_filenames()
self.logger.info(f'Load {self.mode} dataset: {len(self.filenames)} scans')
def get_filenames(self):
filenames = glob(osp.join(self.data_root, self.prefix, '*' + self.suffix))
assert len(filenames) > 0, 'Empty dataset.'
filenames = sorted(filenames * self.repeat)
self.logger.info(f'Load dataset: {len(filenames)} scans')
return filenames
def load(self, filename):
return torch.load(filename)
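get_filenames implements the `repeat` option (`repeat: 5` in the S3DIS config above) with plain list multiplication: each scan path appears `repeat` times per epoch, which lengthens epochs on small datasets without touching the sampler. For example:

files = ['Area_1/office_1.pth', 'Area_1/office_2.pth']
print(sorted(files * 3))
# ['Area_1/office_1.pth', 'Area_1/office_1.pth', 'Area_1/office_1.pth',
#  'Area_1/office_2.pth', 'Area_1/office_2.pth', 'Area_1/office_2.pth']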
@ -64,43 +73,19 @@ class CustomDataset(Dataset):
return x + g(x) * mag
def getInstanceInfo(self, xyz, instance_label, label):
'''
:param xyz: (n, 3)
:param instance_label: (n), int, (0~nInst-1, -100)
:return: instance_num, dict
'''
instance_info = np.ones(
(xyz.shape[0], 9), dtype=np.float32
) * -100.0 # (n, 9), float, (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
instance_pointnum = [] # (nInst), int
pt_mean = np.ones((xyz.shape[0], 3), dtype=np.float32) * -100.0
instance_pointnum = []
instance_cls = []
instance_num = int(instance_label.max()) + 1
for i_ in range(instance_num):
inst_idx_i = np.where(instance_label == i_)
# instance_info
xyz_i = xyz[inst_idx_i]
min_xyz_i = xyz_i.min(0)
max_xyz_i = xyz_i.max(0)
mean_xyz_i = xyz_i.mean(0)
instance_info_i = instance_info[inst_idx_i]
instance_info_i[:, 0:3] = mean_xyz_i
instance_info_i[:, 3:6] = min_xyz_i
instance_info_i[:, 6:9] = max_xyz_i
instance_info[inst_idx_i] = instance_info_i
# instance_pointnum
pt_mean[inst_idx_i] = xyz_i.mean(0)
instance_pointnum.append(inst_idx_i[0].size)
cls_loc = inst_idx_i[0][0]
# ignore the first 2 classes (floor, ceil)
cls = label[cls_loc] - 2 if label[cls_loc] != -100 else label[cls_loc]
instance_cls.append(cls)
return instance_num, {
"instance_info": instance_info,
"instance_pointnum": instance_pointnum,
"instance_cls": instance_cls
}
instance_cls.append(label[cls_loc])
pt_offset_label = pt_mean - xyz
return instance_num, instance_pointnum, instance_cls, pt_offset_label
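The rewrite drops the 9-channel per-point instance_info array (centroid, min, max) in favor of a single regression target: the offset from each point to its instance centroid. A self-contained sketch of the new contract, where ignored points (label -100) keep the -100 fill value exactly as in the loop above:

import numpy as np

def instance_offsets(xyz, instance_label, ignore=-100):
    # xyz: (n, 3) float; instance_label: (n,) int in {0..nInst-1, -100}
    pt_mean = np.full((xyz.shape[0], 3), float(ignore), dtype=np.float32)
    for i in range(int(instance_label.max()) + 1):
        idx = np.where(instance_label == i)
        pt_mean[idx] = xyz[idx].mean(0)  # centroid broadcast to member points
    return pt_mean - xyz                 # (n, 3) offset-to-centroid labels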
def dataAugment(self, xyz, jitter=False, flip=False, rot=False):
m = np.eye(3)
@ -114,22 +99,20 @@ class CustomDataset(Dataset):
[-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation
return np.matmul(xyz, m)
def crop(self, xyz):
'''
:param xyz: (n, 3) >= 0
'''
def crop(self, xyz, step=32):
xyz_offset = xyz.copy()
valid_idxs = (xyz_offset.min(1) >= 0)
valid_idxs = xyz_offset.min(1) >= 0
assert valid_idxs.sum() == xyz.shape[0]
spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
room_range = xyz.max(0) - xyz.min(0)
while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
step_temp = step
if valid_idxs.sum() > 1e6:
step_temp = step * 2
offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
xyz_offset = xyz + offset
valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
spatial_shape[:2] -= 32
spatial_shape[:2] -= step_temp
return xyz_offset, valid_idxs
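crop now takes the shrink step as a parameter so S3DIS can reuse it with step=64 (see its file below). The loop jitters the cloud by a non-positive random offset and shrinks the valid xy window until at most max_npoint points survive, doubling the step on very large scans. A standalone sketch of that idea:

import numpy as np

def crop_to_budget(xyz, spatial, max_npoint, step=32):
    # xyz: (n, 3), all coordinates >= 0; spatial: per-axis upper bound
    shape = np.array([spatial] * 3)
    room_range = xyz.max(0) - xyz.min(0)
    xyz_offset, valid = xyz.copy(), xyz.min(1) >= 0
    while valid.sum() > max_npoint:
        step_t = step * 2 if valid.sum() > 1e6 else step  # coarser on huge scans
        offset = np.clip(shape - room_range + 0.001, None, 0) * np.random.rand(3)
        xyz_offset = xyz + offset
        valid = (xyz_offset.min(1) >= 0) & ((xyz_offset < shape).sum(1) == 3)
        shape[:2] -= step_t  # shrink the xy window, keep full height
    return xyz_offset, valid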
def getCroppedInstLabel(self, instance_label, valid_idxs):
@ -179,11 +162,8 @@ class CustomDataset(Dataset):
if data is None:
return None
xyz, xyz_middle, rgb, label, instance_label = data
inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32),
label)
inst_info = inst_infos["instance_info"]
inst_pointnum = inst_infos["instance_pointnum"]
inst_cls = inst_infos["instance_cls"]
info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), label)
inst_num, inst_pointnum, inst_cls, pt_offset_label = info
loc = torch.from_numpy(xyz).long()
loc_float = torch.from_numpy(xyz_middle)
feat = torch.from_numpy(rgb).float()
@ -191,9 +171,9 @@ class CustomDataset(Dataset):
feat += torch.randn(3) * 0.1
label = torch.from_numpy(label)
instance_label = torch.from_numpy(instance_label)
inst_info = torch.from_numpy(inst_info)
return (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info,
inst_pointnum, inst_cls)
pt_offset_label = torch.from_numpy(pt_offset_label)
return (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum,
inst_cls, pt_offset_label)
def collate_fn(self, batch):
scan_ids = []
@ -203,55 +183,51 @@ class CustomDataset(Dataset):
labels = []
instance_labels = []
instance_infos = [] # (N, 9)
instance_pointnum = [] # (total_nInst), int
instance_cls = [] # (total_nInst), long
batch_offsets = [0]
pt_offset_labels = []
total_inst_num = 0
batch_id = 0
for data in batch:
if data is None:
continue
(scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info,
inst_pointnum, inst_cls) = data
(scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum,
inst_cls, pt_offset_label) = data
instance_label[np.where(instance_label != -100)] += total_inst_num
total_inst_num += inst_num
batch_offsets.append(batch_offsets[-1] + loc.size(0))
scan_ids.append(scan_id)
locs.append(torch.cat([loc.new_full((loc.size(0), 1), batch_id), loc], 1))
locs_float.append(loc_float)
feats.append(feat)
labels.append(label)
instance_labels.append(instance_label)
instance_infos.append(inst_info)
instance_pointnum.extend(inst_pointnum)
instance_cls.extend(inst_cls)
pt_offset_labels.append(pt_offset_label)
batch_id += 1
assert batch_id > 0, 'empty batch'
if batch_id < len(batch):
self.logger.info(f'batch is truncated from size {len(batch)} to {batch_id}')
# merge all the scenes in the batch
batch_offsets = torch.tensor(batch_offsets, dtype=torch.int) # int (B+1)
locs = torch.cat(locs, 0) # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
batch_idxs = locs[:, 0].int()
locs_float = torch.cat(locs_float, 0).to(torch.float32) # float (N, 3)
feats = torch.cat(feats, 0) # float (N, C)
labels = torch.cat(labels, 0).long() # long (N)
instance_labels = torch.cat(instance_labels, 0).long() # long (N)
instance_infos = torch.cat(instance_infos,
0).to(torch.float32) # float (N, 9) (meanxyz, minxyz, maxxyz)
instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int) # int (total_nInst)
instance_cls = torch.tensor(instance_cls, dtype=torch.long) # long (total_nInst)
pt_offset_labels = torch.cat(pt_offset_labels).float()
spatial_shape = np.clip(
locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
voxel_locs, p2v_map, v2p_map = softgroup_ops.voxelization_idx(locs, 1)
voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, 1)
return {
'scan_ids': scan_ids,
'locs': locs,
'batch_idxs': batch_idxs,
'voxel_locs': voxel_locs,
'p2v_map': p2v_map,
'v2p_map': v2p_map,
@ -259,10 +235,9 @@ class CustomDataset(Dataset):
'feats': feats,
'labels': labels,
'instance_labels': instance_labels,
'instance_info': instance_infos,
'instance_pointnum': instance_pointnum,
'instance_cls': instance_cls,
'offsets': batch_offsets,
'pt_offset_labels': pt_offset_labels,
'spatial_shape': spatial_shape,
'batch_size': batch_id,
}
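The unpack order of voxelization_idx is unchanged; only the names are swapped, so the map consumed by voxelization (and by the backbone, per the model hunks below) is consistently called v2p_map, resolving the old `# TODO check name for map`. A naming sketch of the renamed contract (not runnable standalone; softgroup_ops is the repo's CUDA extension):

# second output: pools point features into voxels; third is kept in the
# batch dict for the reverse mapping
voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, batch_size)
voxel_feats = softgroup_ops.voxelization(feats, v2p_map)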

View File

@ -24,7 +24,6 @@ class S3DISDataset(CustomDataset):
assert len(filenames) > 0, f'Empty {p}'
filenames_all.extend(filenames)
filenames_all = sorted(filenames_all * self.repeat)
self.logger.info(f'Load dataset: {len(filenames_all)} scans')
return filenames_all
def load(self, filename):
@ -41,53 +40,7 @@ class S3DISDataset(CustomDataset):
return xyz, rgb, label, instance_label
def crop(self, xyz, step=64):
xyz_offset = xyz.copy()
valid_idxs = (xyz_offset.min(1) >= 0) * (
(xyz < self.voxel_cfg.spatial_shape[1]).sum(1) == 3)
spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
room_range = xyz.max(0) - xyz.min(0)
while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
step_temp = step
if valid_idxs.sum() > 1e6:
step_temp = step * 2
offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
xyz_offset = xyz + offset
valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
spatial_shape[:2] -= step_temp
return xyz_offset, valid_idxs
def getInstanceInfo(self, xyz, instance_label, label):
instance_info = np.ones((xyz.shape[0], 9), dtype=np.float32) * -100.0
instance_pointnum = [] # (nInst), int
instance_cls = []
instance_num = int(instance_label.max()) + 1
for i_ in range(instance_num):
inst_idx_i = np.where(instance_label == i_)
# instance_info
xyz_i = xyz[inst_idx_i]
min_xyz_i = xyz_i.min(0)
max_xyz_i = xyz_i.max(0)
mean_xyz_i = xyz_i.mean(0)
instance_info_i = instance_info[inst_idx_i]
instance_info_i[:, 0:3] = mean_xyz_i
instance_info_i[:, 3:6] = min_xyz_i
instance_info_i[:, 6:9] = max_xyz_i
instance_info[inst_idx_i] = instance_info_i
# instance_pointnum
instance_pointnum.append(inst_idx_i[0].size)
cls_loc = inst_idx_i[0][0]
instance_cls.append(label[cls_loc])
# assert (0 not in instance_cls) and (1 not in instance_cls) # sanity check stuff cls
return instance_num, {
"instance_info": instance_info,
"instance_pointnum": instance_pointnum,
"instance_cls": instance_cls
}
return super().crop(xyz, step=step)
def transform_test(self, xyz, rgb, label, instance_label):
# divide into 4 pieces
@ -120,23 +73,24 @@ class S3DISDataset(CustomDataset):
return super().collate_fn(batch)
# assume 1 scan only
(scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info, inst_pointnum,
inst_cls) = batch[0]
(scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum, inst_cls,
pt_offset_label) = batch[0]
scan_ids = [scan_id]
locs = loc.long()
batch_idxs = torch.zeros_like(loc[:, 0].int())
locs_float = loc_float.float()
feats = feat.float()
labels = label.long()
instance_labels = instance_label.long()
instance_infos = inst_info.float()
instance_pointnum = torch.tensor([inst_pointnum], dtype=torch.int)
instance_cls = torch.tensor([inst_cls], dtype=torch.long)
pt_offset_labels = pt_offset_label.float()
spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
None)
voxel_locs, p2v_map, v2p_map = softgroup_ops.voxelization_idx(locs, 4)
voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, 4)
return {
'scan_ids': scan_ids,
'locs': locs,
'batch_idxs': batch_idxs,
'voxel_locs': voxel_locs,
'p2v_map': p2v_map,
'v2p_map': v2p_map,
@ -144,8 +98,9 @@ class S3DISDataset(CustomDataset):
'feats': feats,
'labels': labels,
'instance_labels': instance_labels,
'instance_info': instance_infos,
'instance_pointnum': instance_pointnum,
'instance_cls': instance_cls,
'spatial_shape': spatial_shape
'pt_offset_labels': pt_offset_labels,
'spatial_shape': spatial_shape,
'batch_size': 4
}

View File

@ -6,3 +6,9 @@ class ScanNetDataset(CustomDataset):
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture',
'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
'bathtub', 'otherfurniture')
def getInstanceInfo(self, xyz, instance_label, label):
ret = super().getInstanceInfo(xyz, instance_label, label)
instance_num, instance_pointnum, instance_cls, pt_offset_label = ret
instance_cls = [x - 2 if x != -100 else x for x in instance_cls]
return instance_num, instance_pointnum, instance_cls, pt_offset_label
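This override moves the class-id remap out of CustomDataset, so S3DIS keeps its raw labels while ScanNet shifts instance semantics down by 2: the first two semantic ids correspond to stuff classes that never form instances, so the 18 thing classes above become 0-17. For instance:

labels = [2, 5, -100, 19]
print([x - 2 if x != -100 else x for x in labels])  # [0, 3, -100, 17]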

View File

@ -409,7 +409,7 @@ class ScanNetEval(object):
"""
print('evaluating', len(pred_list), 'scans...')
matches = {}
for i, (preds, gts) in enumerate(tqdm(zip(pred_list, gt_list))):
for i, (preds, gts) in enumerate(tqdm(zip(pred_list, gt_list), total=len(pred_list))):
gt2pred, pred2gt = self.assign_instances_for_scan(preds, gts)
# assign gt to predictions
matches_key = f'gt_{i}'
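zip objects have no len(), so without an explicit total tqdm degrades to a plain iteration counter with no bar or ETA; passing total=len(pred_list) restores both. A minimal repro:

from tqdm import tqdm

preds, gts = [1, 2, 3], [4, 5, 6]
for p, g in tqdm(zip(preds, gts), total=len(preds)):  # bar and ETA render
    pass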

View File

@ -103,7 +103,7 @@ class SoftGroup(nn.Module):
return self.forward_test(batch)
def forward_train(self, batch):
coords = batch['locs'].cuda()
batch_idxs = batch['batch_idxs'].cuda()
voxel_coords = batch['voxel_locs'].cuda()
p2v_map = batch['p2v_map'].cuda()
v2p_map = batch['v2p_map'].cuda()
@ -111,26 +111,24 @@ class SoftGroup(nn.Module):
feats = batch['feats'].cuda()
semantic_labels = batch['labels'].cuda()
instance_labels = batch['instance_labels'].cuda()
instance_info = batch['instance_info'].cuda()
# instance_pointnum = batch['instance_pointnum'].cuda()
# instance_cls = batch['instance_cls'].cuda()
# batch_offsets = batch['offsets'].cuda()
pt_offset_labels = batch['pt_offset_labels'].cuda()
spatial_shape = batch['spatial_shape']
batch_size = batch['batch_size']
feats = torch.cat((feats, coords_float), 1)
voxel_feats = softgroup_ops.voxelization(feats, v2p_map)
losses = {}
pt_offset_labels = instance_info[:, :3] - coords_float
feats = torch.cat((feats, coords_float), 1)
voxel_feats = softgroup_ops.voxelization(feats, p2v_map)
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, p2v_map, coords_float) # TODO check name for map
input, v2p_map, coords_float)
point_wise_loss = self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
instance_labels, pt_offset_labels)
losses.update(point_wise_loss)
loss = sum(v[0] for v in losses.values())
losses['loss'] = (loss, coords.size(0))
losses['loss'] = (loss, batch_idxs.size(0))
return loss, losses
def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
@ -141,7 +139,7 @@ class SoftGroup(nn.Module):
pos_inds = instance_labels != self.ignore_label
if pos_inds.sum() == 0:
offset_loss = 0 * pt_offset.sum()
offset_loss = 0 * pt_offsets.sum()
else:
offset_loss = self.offset_loss(pt_offsets[pos_inds],
pt_offset_labels[pos_inds]) / pos_inds.sum()
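The guarded branch (now with the pt_offsets name typo fixed) uses a standard trick: when a batch has no foreground points, `0 * pt_offsets.sum()` yields a zero loss that is still attached to the autograd graph, so backward produces zero gradients for the offset head instead of skipping it. A minimal illustration:

import torch

pt_offsets = torch.randn(5, 3, requires_grad=True)
loss = 0 * pt_offsets.sum()          # zero-valued, but graph-connected
loss.backward()
print(pt_offsets.grad.abs().sum())   # tensor(0.) rather than grad=None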
@ -149,7 +147,7 @@ class SoftGroup(nn.Module):
return losses
def forward_test(self, batch):
coords = batch['locs'].cuda()
batch_idxs = batch['batch_idxs'].cuda()
voxel_coords = batch['voxel_locs'].cuda()
p2v_map = batch['p2v_map'].cuda()
v2p_map = batch['v2p_map'].cuda()
@ -157,23 +155,17 @@ class SoftGroup(nn.Module):
feats = batch['feats'].cuda()
labels = batch['labels'].cuda()
instance_labels = batch['instance_labels'].cuda()
# instance_info = batch['instance_info'].cuda()
# instance_pointnum = batch['instance_pointnum'].cuda()
# instance_cls = batch['instance_cls'].cuda()
# batch_offsets = batch['offsets'].cuda()
spatial_shape = batch['spatial_shape']
batch_size = batch['batch_size']
feats = torch.cat((feats, coords_float), 1)
voxel_feats = softgroup_ops.voxelization(feats, v2p_map)
if self.test_cfg.x4_split:
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 4)
batch_idxs = torch.zeros_like(coords[:, 0].int())
else:
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 1)
batch_idxs = coords[:, 0].int()
voxel_feats = softgroup_ops.voxelization(feats, p2v_map)
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, p2v_map, coords_float,
x4_split=self.test_cfg.x4_split) # TODO check name for map
input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
batch_idxs, coords_float,
self.grouping_cfg)
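With batch_idxs and batch_size now produced by each dataset's collate_fn (the S3DIS collate emits zero batch indices and batch_size: 4 for its x4 split), forward_test no longer branches on x4_split to build the sparse tensor; the flag only flows on to forward_backbone. A sketch, under those assumptions, of the minimal keys a custom collate_fn must now supply:

import torch

n_points = 100000  # illustrative
batch = {
    'batch_idxs': torch.zeros(n_points, dtype=torch.int32),  # single scan
    'batch_size': 1,
    # ... plus locs, voxel_locs, v2p_map, p2v_map, feats, spatial_shape, ...
}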

test.py
View File

@ -12,6 +12,8 @@ from model.softgroup import SoftGroup
from data.scannetv2 import ScanNetDataset
from torch.utils.data import DataLoader
from util import get_root_logger
from data import build_dataset, build_dataloader
def get_args():
@ -62,6 +64,7 @@ def evaluate_semantic_segmantation_miou(matches):
if __name__ == '__main__':
torch.backends.cudnn.enabled = False # TODO remove this
test_seed = 567
random.seed(test_seed)
np.random.seed(test_seed)
@ -69,24 +72,17 @@ if __name__ == '__main__':
torch.cuda.manual_seed_all(test_seed)
args = get_args()
cfg = Munch.fromDict(yaml.safe_load(open(args.config, 'r')))
torch.backends.cudnn.enabled = False
cfg_txt = open(args.config, 'r').read()
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
logger = get_root_logger()
model = SoftGroup(**cfg.model)
print(f'Load state dict from {args.checkpoint}')
model = utils.load_checkpoint(model, args.checkpoint)
logger.info(f'Load state dict from {args.checkpoint}')
utils.load_checkpoint(args.checkpoint, logger, model)
model.cuda()
dataset = ScanNetDataset(training=False, **cfg.data.test)
dataloader = DataLoader(
dataset,
batch_size=1,
collate_fn=dataset.collate_fn,
num_workers=16,
shuffle=False,
drop_last=False,
pin_memory=True)
dataset = build_dataset(cfg.data.test, logger)
dataloader = build_dataloader(dataset, training=False)
all_preds, all_gts = [], []
with torch.no_grad():
model = model.eval()
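The hard-coded ScanNetDataset plus hand-rolled DataLoader give way to the config-driven builders, so the same script evaluates whichever dataset the config names, and load_checkpoint switches to loading in place. One behaviour change worth noting, sketched under the builder defaults shown earlier:

dataset = build_dataset(cfg.data.test, logger)
# The old loader pinned num_workers=16; the builder defaults to 1
# (see data/__init__.py above), so a heavy eval run would pass it explicitly:
dataloader = build_dataloader(dataset, batch_size=1, num_workers=16,
                              training=False)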

View File

@ -103,6 +103,7 @@ if __name__ == '__main__':
utils.load_checkpoint(cfg.pretrain, logger, model)
# train and val
logger.info('Training')
for epoch in range(start_epoch, cfg.epochs + 1):
model.train()
iter_time = utils.AverageMeter()