Mirror of https://github.com/botastic/SoftGroup.git, synced 2025-10-16 11:45:42 +00:00
update variable names

commit f7a31c531e (parent 70f2b7454f)
@@ -29,6 +29,7 @@ data:
     prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
     suffix: '_inst_nostuff.pth'
     repeat: 5
     training: True
     voxel_cfg:
       scale: 50
       spatial_shape: [128, 512]
@@ -39,6 +40,7 @@ data:
     data_root: 'dataset/s3dis/preprocess'
     prefix: 'Area_5'
     suffix: '_inst_nostuff.pth'
     training: False
     voxel_cfg:
       scale: 50
       spatial_shape: [128, 512]
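A minimal sketch of how these two config blocks get consumed, assuming the build_dataset helper below and the Munch-based config loading used in test.py (the yaml path here is hypothetical):

    import yaml
    from munch import Munch
    from data import build_dataset
    from util import get_root_logger

    cfg = Munch.fromDict(yaml.safe_load(open('configs/softgroup_s3dis.yaml', 'r')))
    logger = get_root_logger()
    train_set = build_dataset(cfg.data.train, logger)  # Areas 1-4 and 6, each scan repeated 5x
    test_set = build_dataset(cfg.data.test, logger)    # Area_5 held out for evaluation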
@@ -17,6 +17,7 @@ def build_dataset(data_cfg, logger):
     else:
         raise ValueError(f'Unknown {data_type}')


 def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
     if training:
         return DataLoader(
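The tail of build_dataloader is cut off above; a sketch of how it plausibly continues, with the DataLoader arguments assumed rather than read from the repo (the two training branches folded into shuffle/drop_last flags):

    from torch.utils.data import DataLoader

    def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,  # each dataset defines its own collation
            shuffle=training,               # shuffle only while training
            drop_last=training,
            pin_memory=True)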
@@ -17,7 +17,14 @@ class CustomDataset(Dataset):

     CLASSES = None

-    def __init__(self, data_root, prefix, suffix, voxel_cfg=None, training=True, repeat=1, logger=None):
+    def __init__(self,
+                 data_root,
+                 prefix,
+                 suffix,
+                 voxel_cfg=None,
+                 training=True,
+                 repeat=1,
+                 logger=None):
         self.data_root = data_root
         self.prefix = prefix
         self.suffix = suffix
@@ -25,13 +32,15 @@ class CustomDataset(Dataset):
         self.training = training
         self.repeat = repeat
         self.logger = logger
+        self.mode = 'train' if training else 'test'
+        self.filenames = self.get_filenames()
+        self.logger.info(f'Load {self.mode} dataset: {len(self.filenames)} scans')

     def get_filenames(self):
         filenames = glob(osp.join(self.data_root, self.prefix, '*' + self.suffix))
         assert len(filenames) > 0, 'Empty dataset.'
         filenames = sorted(filenames * self.repeat)
-        self.logger.info(f'Load dataset: {len(filenames)} scans')
         return filenames

     def load(self, filename):
         return torch.load(filename)
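The repeat option is plain list duplication; a tiny illustration with invented file names:

    filenames = ['Area_1_office_1.pth', 'Area_1_office_2.pth']
    print(sorted(filenames * 2))
    # ['Area_1_office_1.pth', 'Area_1_office_1.pth', 'Area_1_office_2.pth', 'Area_1_office_2.pth']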
@@ -64,43 +73,19 @@ class CustomDataset(Dataset):
         return x + g(x) * mag

     def getInstanceInfo(self, xyz, instance_label, label):
-        '''
-        :param xyz: (n, 3)
-        :param instance_label: (n), int, (0~nInst-1, -100)
-        :return: instance_num, dict
-        '''
-        instance_info = np.ones(
-            (xyz.shape[0], 9), dtype=np.float32
-        ) * -100.0  # (n, 9), float, (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
-        instance_pointnum = []  # (nInst), int
+        pt_mean = np.ones((xyz.shape[0], 3), dtype=np.float32) * -100.0
+        instance_pointnum = []
         instance_cls = []
         instance_num = int(instance_label.max()) + 1
         for i_ in range(instance_num):
             inst_idx_i = np.where(instance_label == i_)
-
-            # instance_info
             xyz_i = xyz[inst_idx_i]
-            min_xyz_i = xyz_i.min(0)
-            max_xyz_i = xyz_i.max(0)
-            mean_xyz_i = xyz_i.mean(0)
-            instance_info_i = instance_info[inst_idx_i]
-            instance_info_i[:, 0:3] = mean_xyz_i
-            instance_info_i[:, 3:6] = min_xyz_i
-            instance_info_i[:, 6:9] = max_xyz_i
-            instance_info[inst_idx_i] = instance_info_i
-
-            # instance_pointnum
+            pt_mean[inst_idx_i] = xyz_i.mean(0)
             instance_pointnum.append(inst_idx_i[0].size)
             cls_loc = inst_idx_i[0][0]
-
-            # ignore 2 first classes (floor, ceil)
-            cls = label[cls_loc] - 2 if label[cls_loc] != -100 else label[cls_loc]
-            instance_cls.append(cls)
-        return instance_num, {
-            "instance_info": instance_info,
-            "instance_pointnum": instance_pointnum,
-            "instance_cls": instance_cls
-        }
+            instance_cls.append(label[cls_loc])
+        pt_offset_label = pt_mean - xyz
+        return instance_num, instance_pointnum, instance_cls, pt_offset_label

     def dataAugment(self, xyz, jitter=False, flip=False, rot=False):
         m = np.eye(3)
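A toy check of the new return contract: pt_offset_label points from every point to its instance centroid, computed exactly as in the loop above (three invented points, one ignored):

    import numpy as np

    xyz = np.array([[0., 0., 0.], [2., 0., 0.], [5., 5., 5.]], dtype=np.float32)
    instance_label = np.array([0, 0, -100], dtype=np.int32)

    pt_mean = np.ones((3, 3), dtype=np.float32) * -100.0
    for i_ in range(int(instance_label.max()) + 1):
        inst_idx_i = np.where(instance_label == i_)
        pt_mean[inst_idx_i] = xyz[inst_idx_i].mean(0)
    pt_offset_label = pt_mean - xyz
    # rows 0-1: offsets to the centroid (1, 0, 0); row 2 keeps the -100 fill
    # minus xyz, i.e. a junk value that the offset loss later masks out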
@@ -114,22 +99,20 @@ class CustomDataset(Dataset):
                       [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])  # rotation
         return np.matmul(xyz, m)

-    def crop(self, xyz):
-        '''
-        :param xyz: (n, 3) >= 0
-        '''
+    def crop(self, xyz, step=32):
         xyz_offset = xyz.copy()
-        valid_idxs = (xyz_offset.min(1) >= 0)
+        valid_idxs = xyz_offset.min(1) >= 0
         assert valid_idxs.sum() == xyz.shape[0]

         spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
         room_range = xyz.max(0) - xyz.min(0)
         while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
+            step_temp = step
+            if valid_idxs.sum() > 1e6:
+                step_temp = step * 2
             offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
             xyz_offset = xyz + offset
             valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
-            spatial_shape[:2] -= 32
-
+            spatial_shape[:2] -= step_temp
         return xyz_offset, valid_idxs

     def getCroppedInstLabel(self, instance_label, valid_idxs):
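A standalone toy run of the loop above, showing how the xy window shrinks by step until few enough points survive (room size, density and max_npoint all invented):

    import numpy as np

    np.random.seed(0)
    xyz = np.random.rand(100000, 3) * [60., 60., 10.]  # a fake, already-scaled room
    spatial_shape = np.array([512, 512, 512])
    max_npoint, step = 50000, 32
    xyz_offset, valid_idxs = xyz.copy(), np.ones(len(xyz), dtype=bool)
    room_range = xyz.max(0) - xyz.min(0)
    while valid_idxs.sum() > max_npoint:
        offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
        xyz_offset = xyz + offset
        valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
        spatial_shape[:2] -= step
    print(valid_idxs.sum())  # <= max_npoint once the window is small enough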
@@ -179,11 +162,8 @@ class CustomDataset(Dataset):
         if data is None:
             return None
         xyz, xyz_middle, rgb, label, instance_label = data
-        inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32),
-                                                    label)
-        inst_info = inst_infos["instance_info"]
-        inst_pointnum = inst_infos["instance_pointnum"]
-        inst_cls = inst_infos["instance_cls"]
+        info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), label)
+        inst_num, inst_pointnum, inst_cls, pt_offset_label = info
         loc = torch.from_numpy(xyz).long()
         loc_float = torch.from_numpy(xyz_middle)
         feat = torch.from_numpy(rgb).float()
@@ -191,9 +171,9 @@ class CustomDataset(Dataset):
             feat += torch.randn(3) * 0.1
         label = torch.from_numpy(label)
         instance_label = torch.from_numpy(instance_label)
-        inst_info = torch.from_numpy(inst_info)
-        return (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info,
-                inst_pointnum, inst_cls)
+        pt_offset_label = torch.from_numpy(pt_offset_label)
+        return (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum,
+                inst_cls, pt_offset_label)

     def collate_fn(self, batch):
         scan_ids = []
@@ -203,55 +183,51 @@ class CustomDataset(Dataset):
         labels = []
         instance_labels = []

-        instance_infos = []  # (N, 9)
         instance_pointnum = []  # (total_nInst), int
         instance_cls = []  # (total_nInst), long

         batch_offsets = [0]
+        pt_offset_labels = []

         total_inst_num = 0
         batch_id = 0
         for data in batch:
             if data is None:
                 continue
-            (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info,
-             inst_pointnum, inst_cls) = data
+            (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum,
+             inst_cls, pt_offset_label) = data
             instance_label[np.where(instance_label != -100)] += total_inst_num
             total_inst_num += inst_num
             batch_offsets.append(batch_offsets[-1] + loc.size(0))
             scan_ids.append(scan_id)
             locs.append(torch.cat([loc.new_full((loc.size(0), 1), batch_id), loc], 1))
             locs_float.append(loc_float)
             feats.append(feat)
             labels.append(label)
             instance_labels.append(instance_label)
-            instance_infos.append(inst_info)
             instance_pointnum.extend(inst_pointnum)
             instance_cls.extend(inst_cls)
+            pt_offset_labels.append(pt_offset_label)
             batch_id += 1
         assert batch_id > 0, 'empty batch'
         if batch_id < len(batch):
             self.logger.info(f'batch is truncated from size {len(batch)} to {batch_id}')

         # merge all the scenes in the batch
         batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int (B+1)

         locs = torch.cat(locs, 0)  # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
         batch_idxs = locs[:, 0].int()
         locs_float = torch.cat(locs_float, 0).to(torch.float32)  # float (N, 3)
         feats = torch.cat(feats, 0)  # float (N, C)
         labels = torch.cat(labels, 0).long()  # long (N)
         instance_labels = torch.cat(instance_labels, 0).long()  # long (N)
-        instance_infos = torch.cat(instance_infos,
-                                   0).to(torch.float32)  # float (N, 9) (meanxyz, minxyz, maxxyz)
         instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)
         instance_cls = torch.tensor(instance_cls, dtype=torch.long)  # long (total_nInst)
+        pt_offset_labels = torch.cat(pt_offset_labels).float()

         spatial_shape = np.clip(
             locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
-        voxel_locs, p2v_map, v2p_map = softgroup_ops.voxelization_idx(locs, 1)
+        voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, 1)
         return {
             'scan_ids': scan_ids,
             'locs': locs,
             'batch_idxs': batch_idxs,
             'voxel_locs': voxel_locs,
             'p2v_map': p2v_map,
             'v2p_map': v2p_map,
@@ -259,10 +235,9 @@ class CustomDataset(Dataset):
             'feats': feats,
             'labels': labels,
             'instance_labels': instance_labels,
-            'instance_info': instance_infos,
             'instance_pointnum': instance_pointnum,
             'instance_cls': instance_cls,
-            'offsets': batch_offsets,
+            'pt_offset_labels': pt_offset_labels,
             'spatial_shape': spatial_shape,
             'batch_size': batch_id,
         }
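A sketch of what a consumer sees in the collated dict; the keys are the ones returned above, the two-sample batch is invented:

    batch = dataset.collate_fn([dataset[0], dataset[1]])
    locs = batch['locs']                          # (N, 1 + 3), column 0 = scan index in batch
    pt_offset_labels = batch['pt_offset_labels']  # (N, 3); replaces the old (N, 9) instance_info
    assert batch['batch_size'] == 2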
@@ -24,7 +24,6 @@ class S3DISDataset(CustomDataset):
             assert len(filenames) > 0, f'Empty {p}'
             filenames_all.extend(filenames)
         filenames_all = sorted(filenames_all * self.repeat)
-        self.logger.info(f'Load dataset: {len(filenames_all)} scans')
         return filenames_all

     def load(self, filename):
@@ -41,53 +40,7 @@ class S3DISDataset(CustomDataset):
         return xyz, rgb, label, instance_label

     def crop(self, xyz, step=64):
-        xyz_offset = xyz.copy()
-        valid_idxs = (xyz_offset.min(1) >= 0) * (
-            (xyz < self.voxel_cfg.spatial_shape[1]).sum(1) == 3)
-
-        spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
-        room_range = xyz.max(0) - xyz.min(0)
-        while (valid_idxs.sum() > self.voxel_cfg.max_npoint):
-            step_temp = step
-            if valid_idxs.sum() > 1e6:
-                step_temp = step * 2
-            offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
-            xyz_offset = xyz + offset
-            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
-            spatial_shape[:2] -= step_temp
-
-        return xyz_offset, valid_idxs
-
-    def getInstanceInfo(self, xyz, instance_label, label):
-        instance_info = np.ones((xyz.shape[0], 9), dtype=np.float32) * -100.0
-        instance_pointnum = []  # (nInst), int
-        instance_cls = []
-        instance_num = int(instance_label.max()) + 1
-        for i_ in range(instance_num):
-            inst_idx_i = np.where(instance_label == i_)
-
-            # instance_info
-            xyz_i = xyz[inst_idx_i]
-            min_xyz_i = xyz_i.min(0)
-            max_xyz_i = xyz_i.max(0)
-            mean_xyz_i = xyz_i.mean(0)
-            instance_info_i = instance_info[inst_idx_i]
-            instance_info_i[:, 0:3] = mean_xyz_i
-            instance_info_i[:, 3:6] = min_xyz_i
-            instance_info_i[:, 6:9] = max_xyz_i
-            instance_info[inst_idx_i] = instance_info_i
-
-            # instance_pointnum
-            instance_pointnum.append(inst_idx_i[0].size)
-            cls_loc = inst_idx_i[0][0]
-            instance_cls.append(label[cls_loc])
-            # assert (0 not in instance_cls) and (1 not in instance_cls)  # sanity check stuff cls
-
-        return instance_num, {
-            "instance_info": instance_info,
-            "instance_pointnum": instance_pointnum,
-            "instance_cls": instance_cls
-        }
+        return super().crop(xyz, step=step)

     def transform_test(self, xyz, rgb, label, instance_label):
         # divide into 4 pieces
@@ -120,23 +73,24 @@ class S3DISDataset(CustomDataset):
             return super().collate_fn(batch)

         # assume 1 scan only
-        (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_info, inst_pointnum,
-         inst_cls) = batch[0]
+        (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum, inst_cls,
+         pt_offset_label) = batch[0]
         scan_ids = [scan_id]
         locs = loc.long()
         batch_idxs = torch.zeros_like(loc[:, 0].int())
         locs_float = loc_float.float()
         feats = feat.float()
         labels = label.long()
         instance_labels = instance_label.long()
-        instance_infos = inst_info.float()
         instance_pointnum = torch.tensor([inst_pointnum], dtype=torch.int)
         instance_cls = torch.tensor([inst_cls], dtype=torch.long)
+        pt_offset_labels = pt_offset_label.float()
         spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
                                 None)
-        voxel_locs, p2v_map, v2p_map = softgroup_ops.voxelization_idx(locs, 4)
+        voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, 4)
         return {
             'scan_ids': scan_ids,
             'locs': locs,
             'batch_idxs': batch_idxs,
             'voxel_locs': voxel_locs,
             'p2v_map': p2v_map,
             'v2p_map': v2p_map,
@@ -144,8 +98,9 @@ class S3DISDataset(CustomDataset):
             'feats': feats,
             'labels': labels,
             'instance_labels': instance_labels,
-            'instance_info': instance_infos,
             'instance_pointnum': instance_pointnum,
             'instance_cls': instance_cls,
-            'spatial_shape': spatial_shape
+            'pt_offset_labels': pt_offset_labels,
+            'spatial_shape': spatial_shape,
+            'batch_size': 4
         }
@@ -6,3 +6,9 @@ class ScanNetDataset(CustomDataset):
     CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture',
                'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
                'bathtub', 'otherfurniture')
+
+    def getInstanceInfo(self, xyz, instance_label, label):
+        ret = super().getInstanceInfo(xyz, instance_label, label)
+        instance_num, instance_pointnum, instance_cls, pt_offset_label = ret
+        instance_cls = [x - 2 if x != -100 else x for x in instance_cls]
+        return instance_num, instance_pointnum, instance_cls, pt_offset_label
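A toy check of that remap: ids 0 and 1 (the wall/floor stuff classes that precede the CLASSES tuple) never appear as instances, ids >= 2 shift down by 2, and the -100 ignore value passes through:

    instance_cls = [2, 5, -100, 19]
    print([x - 2 if x != -100 else x for x in instance_cls])  # [0, 3, -100, 17]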
@@ -409,7 +409,7 @@ class ScanNetEval(object):
         """
         print('evaluating', len(pred_list), 'scans...')
         matches = {}
-        for i, (preds, gts) in enumerate(tqdm(zip(pred_list, gt_list))):
+        for i, (preds, gts) in enumerate(tqdm(zip(pred_list, gt_list), total=len(pred_list))):
             gt2pred, pred2gt = self.assign_instances_for_scan(preds, gts)
             # assign gt to predictions
             matches_key = f'gt_{i}'
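The added total= matters because zip() has no __len__, so tqdm cannot size the progress bar on its own; a minimal demonstration:

    from tqdm import tqdm

    for _ in tqdm(zip(range(3), range(3)), total=3):  # without total=, no percentage or ETA
        pass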
@@ -103,7 +103,7 @@ class SoftGroup(nn.Module):
             return self.forward_test(batch)

     def forward_train(self, batch):
-        coords = batch['locs'].cuda()
+        batch_idxs = batch['batch_idxs'].cuda()
         voxel_coords = batch['voxel_locs'].cuda()
         p2v_map = batch['p2v_map'].cuda()
         v2p_map = batch['v2p_map'].cuda()
@@ -111,26 +111,24 @@ class SoftGroup(nn.Module):
         feats = batch['feats'].cuda()
         semantic_labels = batch['labels'].cuda()
         instance_labels = batch['instance_labels'].cuda()
-        instance_info = batch['instance_info'].cuda()
         # instance_pointnum = batch['instance_pointnum'].cuda()
         # instance_cls = batch['instance_cls'].cuda()
         # batch_offsets = batch['offsets'].cuda()
+        pt_offset_labels = batch['pt_offset_labels'].cuda()
         spatial_shape = batch['spatial_shape']
         batch_size = batch['batch_size']

+        feats = torch.cat((feats, coords_float), 1)
+        voxel_feats = softgroup_ops.voxelization(feats, v2p_map)

         losses = {}
-        pt_offset_labels = instance_info[:, :3] - coords_float
-        feats = torch.cat((feats, coords_float), 1)
-        voxel_feats = softgroup_ops.voxelization(feats, p2v_map)
         input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
         semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
-            input, p2v_map, coords_float)  # TODO check name for map
+            input, v2p_map, coords_float)
         point_wise_loss = self.point_wise_loss(semantic_scores, pt_offsets, semantic_labels,
                                                instance_labels, pt_offset_labels)
         losses.update(point_wise_loss)
         loss = sum(v[0] for v in losses.values())
-        losses['loss'] = (loss, coords.size(0))
+        losses['loss'] = (loss, batch_idxs.size(0))
         return loss, losses

     def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
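Net effect of this hunk: the offset regression target is no longer rebuilt on the GPU from the (N, 9) instance_info; the (N, 3) pt_offset_labels produced by collate_fn feeds the loss directly. A minimal sketch of that loss term, assuming an L1 offset loss:

    import torch
    import torch.nn.functional as F

    pt_offsets = torch.zeros(4, 3, requires_grad=True)   # fake network output
    pt_offset_labels = torch.tensor([[1., 0., 0.]] * 4)  # as produced by collate_fn
    pos_inds = torch.tensor([True, True, False, False])  # foreground points only
    offset_loss = F.l1_loss(pt_offsets[pos_inds], pt_offset_labels[pos_inds],
                            reduction='sum') / pos_inds.sum()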
@@ -141,7 +139,7 @@ class SoftGroup(nn.Module):

         pos_inds = instance_labels != self.ignore_label
         if pos_inds.sum() == 0:
-            offset_loss = 0 * pt_offset.sum()
+            offset_loss = 0 * pt_offsets.sum()
         else:
             offset_loss = self.offset_loss(pt_offsets[pos_inds],
                                            pt_offset_labels[pos_inds]) / pos_inds.sum()
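Worth noting why the empty case reads 0 * pt_offsets.sum() rather than a bare 0: the product is still a tensor attached to the autograd graph, so backward() keeps working (with zero gradients) on batches that contain no foreground point. A quick check:

    import torch

    pt_offsets = torch.zeros(5, 3, requires_grad=True)
    loss = 0 * pt_offsets.sum()
    loss.backward()                     # fine: loss is a graph-connected tensor
    print(pt_offsets.grad.abs().sum())  # tensor(0.)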
@@ -149,7 +147,7 @@ class SoftGroup(nn.Module):
         return losses

     def forward_test(self, batch):
-        coords = batch['locs'].cuda()
+        batch_idxs = batch['batch_idxs'].cuda()
         voxel_coords = batch['voxel_locs'].cuda()
         p2v_map = batch['p2v_map'].cuda()
         v2p_map = batch['v2p_map'].cuda()
@@ -157,23 +155,17 @@ class SoftGroup(nn.Module):
         feats = batch['feats'].cuda()
         labels = batch['labels'].cuda()
         instance_labels = batch['instance_labels'].cuda()
         # instance_info = batch['instance_info'].cuda()
         # instance_pointnum = batch['instance_pointnum'].cuda()
         # instance_cls = batch['instance_cls'].cuda()
         # batch_offsets = batch['offsets'].cuda()
         spatial_shape = batch['spatial_shape']
+        batch_size = batch['batch_size']

         feats = torch.cat((feats, coords_float), 1)
+        voxel_feats = softgroup_ops.voxelization(feats, v2p_map)
-        if self.test_cfg.x4_split:
-            input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 4)
-            batch_idxs = torch.zeros_like(coords[:, 0].int())
-        else:
-            input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, 1)
-            batch_idxs = coords[:, 0].int()
-        voxel_feats = softgroup_ops.voxelization(feats, p2v_map)
+        input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
         semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
-            input, p2v_map, coords_float,
-            x4_split=self.test_cfg.x4_split)  # TODO check name for map
+            input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
         proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets,
                                                                 batch_idxs, coords_float,
                                                                 self.grouping_cfg)
test.py (24 changes)
@@ -12,6 +12,8 @@ from model.softgroup import SoftGroup

 from data.scannetv2 import ScanNetDataset
 from torch.utils.data import DataLoader
+from util import get_root_logger
+from data import build_dataset, build_dataloader


 def get_args():
@@ -62,6 +64,7 @@ def evaluate_semantic_segmantation_miou(matches):


 if __name__ == '__main__':
+    torch.backends.cudnn.enabled = False  # TODO remove this
     test_seed = 567
     random.seed(test_seed)
     np.random.seed(test_seed)
@@ -69,24 +72,17 @@ if __name__ == '__main__':
     torch.cuda.manual_seed_all(test_seed)

     args = get_args()
-    cfg = Munch.fromDict(yaml.safe_load(open(args.config, 'r')))
-    torch.backends.cudnn.enabled = False
+    cfg_txt = open(args.config, 'r').read()
+    cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
+    logger = get_root_logger()

     model = SoftGroup(**cfg.model)
-    print(f'Load state dict from {args.checkpoint}')
-    model = utils.load_checkpoint(model, args.checkpoint)
+    logger.info(f'Load state dict from {args.checkpoint}')
+    utils.load_checkpoint(args.checkpoint, logger, model)
     model.cuda()

-
-    dataset = ScanNetDataset(training=False, **cfg.data.test)
-    dataloader = DataLoader(
-        dataset,
-        batch_size=1,
-        collate_fn=dataset.collate_fn,
-        num_workers=16,
-        shuffle=False,
-        drop_last=False,
-        pin_memory=True)
+    dataset = build_dataset(cfg.data.test, logger)
+    dataloader = build_dataloader(dataset, training=False)
     all_preds, all_gts = [], []
     with torch.no_grad():
         model = model.eval()
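For completeness, a sketch of the evaluation loop these two builders feed; the per-batch handling is cut off in this diff, so the body below is assumed rather than the repo's actual code:

    with torch.no_grad():
        model = model.eval()
        for batch in dataloader:
            result = model(batch)     # dispatches to forward_test
            all_preds.append(result)  # hypothetical accumulation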