mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
support save results
This commit is contained in:
parent
590b96b8aa
commit
096c5e8749
@ -1,72 +0,0 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 20
|
||||
instance_classes: 18
|
||||
sem2ins_classes: []
|
||||
semantic_only: False
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
|
||||
8331., 3948., 3166., 5629., 11719.,
|
||||
1003., 3317., 4912., 10221., 3889.,
|
||||
4136., 2120., 945., 3967., 2589.]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'val'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 16
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.001
|
||||
|
||||
epochs: 512
|
||||
step_epoch: 200
|
||||
save_freq: 8
|
||||
pretrain: 'hais_ckpt.pth'
|
||||
work_dir: 'work_dirs/softgroup_scannet'
|
||||
@ -1,71 +0,0 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 13
|
||||
instance_classes: 13
|
||||
sem2ins_classes: [0, 1]
|
||||
semantic_only: True
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [1823, 7457, 6189, 7424, 34229, 1724, 5439,
|
||||
6016, 39796, 5279, 5092, 12210, 10225]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: True
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: []
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 's3dis'
|
||||
data_root: 'dataset/s3dis/preprocess'
|
||||
prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
|
||||
suffix: '_inst_nostuff.pth'
|
||||
repeat: 5
|
||||
training: True
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 's3dis'
|
||||
data_root: 'dataset/s3dis/preprocess'
|
||||
prefix: 'Area_5'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.001
|
||||
|
||||
epochs: 30 # actual epochs = 30 * repeat
|
||||
step_epoch: 0
|
||||
save_freq: 2
|
||||
pretrain: 'hais_ckpt.pth'
|
||||
work_dir: 'work_dirs/softgroup_s3dis_backbone'
|
||||
@ -1,71 +0,0 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 13
|
||||
instance_classes: 13
|
||||
sem2ins_classes: [0, 1]
|
||||
semantic_only: False
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [1823, 7457, 6189, 7424, 34229, 1724, 5439,
|
||||
6016, 39796, 5279, 5092, 12210, 10225]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: True
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 's3dis'
|
||||
data_root: 'dataset/s3dis/preprocess'
|
||||
prefix: ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']
|
||||
suffix: '_inst_nostuff.pth'
|
||||
repeat: 5
|
||||
training: True
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 's3dis'
|
||||
data_root: 'dataset/s3dis/preprocess'
|
||||
prefix: 'Area_5'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.001
|
||||
|
||||
epochs: 30 # actual epochs = 30 * repeat
|
||||
step_epoch: 0
|
||||
save_freq: 2
|
||||
pretrain: 'exp/s3dis/softgroup/softgroup_fold5_s3dis/softgroup_fold5_s3dis-000000030.pth'
|
||||
work_dir: 'work_dirs/softgroup_s3dis'
|
||||
@ -64,9 +64,14 @@ optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
save_cfg:
|
||||
semantic: True
|
||||
offset: True
|
||||
instance: False
|
||||
|
||||
fp16: False
|
||||
epochs: 20
|
||||
step_epoch: 0
|
||||
save_freq: 2
|
||||
pretrain: 'work_dirs/softgroup_scannet_backbone/epoch_120.pth'
|
||||
pretrain: './hais_ckpt_spconv2.pth'
|
||||
work_dir: ''
|
||||
|
||||
@ -64,6 +64,11 @@ optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
save_cfg:
|
||||
semantic: True
|
||||
offset: True
|
||||
instance: True
|
||||
|
||||
fp16: False
|
||||
epochs: 20
|
||||
step_epoch: 0
|
||||
|
||||
@ -66,9 +66,14 @@ optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
save_cfg:
|
||||
semantic: True
|
||||
offset: True
|
||||
instance: True
|
||||
|
||||
fp16: False
|
||||
epochs: 128
|
||||
step_epoch: 50
|
||||
save_freq: 4
|
||||
pretrain: 'work_dirs/softgroup_scannet_backbone/epoch_120.pth'
|
||||
pretrain: './hais_ckpt_spconv2.pth'
|
||||
work_dir: ''
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 20
|
||||
instance_classes: 18
|
||||
sem2ins_classes: []
|
||||
semantic_only: True
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
|
||||
8331., 3948., 3166., 5629., 11719.,
|
||||
1003., 3317., 4912., 10221., 3889.,
|
||||
4136., 2120., 945., 3967., 2589.]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: []
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
repeat: 4
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'val'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
fp16: False
|
||||
epochs: 128
|
||||
step_epoch: 50
|
||||
save_freq: 4
|
||||
pretrain: ''
|
||||
work_dir: ''
|
||||
@ -1,74 +0,0 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 20
|
||||
instance_classes: 18
|
||||
sem2ins_classes: []
|
||||
semantic_only: True
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
|
||||
8331., 3948., 3166., 5629., 11719.,
|
||||
1003., 3317., 4912., 10221., 3889.,
|
||||
4136., 2120., 945., 3967., 2589.]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: []
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
repeat: 4
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'val'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
fp16: True
|
||||
epochs: 128
|
||||
step_epoch: 50
|
||||
save_freq: 4
|
||||
pretrain: ''
|
||||
work_dir: ''
|
||||
@ -1,74 +0,0 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 20
|
||||
instance_classes: 18
|
||||
sem2ins_classes: []
|
||||
semantic_only: False
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
|
||||
8331., 3948., 3166., 5629., 11719.,
|
||||
1003., 3317., 4912., 10221., 3889.,
|
||||
4136., 2120., 945., 3967., 2589.]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
repeat: 4
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'val'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
fp16: True
|
||||
epochs: 128
|
||||
step_epoch: 50
|
||||
save_freq: 4
|
||||
pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth'
|
||||
work_dir: ''
|
||||
@ -79,9 +79,10 @@ class SoftGroup(nn.Module):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, MLP):
|
||||
m.init_weights()
|
||||
for m in [self.cls_linear, self.iou_score_linear]:
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if not self.semantic_only:
|
||||
for m in [self.cls_linear, self.iou_score_linear]:
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
@ -232,6 +233,8 @@ class SoftGroup(nn.Module):
|
||||
pt_offset_labels = self.merge_4_parts(pt_offset_labels)
|
||||
semantic_preds = semantic_scores.max(1)[1]
|
||||
ret = dict(
|
||||
scan_id=scan_ids[0],
|
||||
coords_float=coords_float.cpu().numpy(),
|
||||
semantic_preds=semantic_preds.cpu().numpy(),
|
||||
semantic_labels=semantic_labels.cpu().numpy(),
|
||||
offset_preds=pt_offsets.cpu().numpy(),
|
||||
|
||||
104
tools/test.py
104
tools/test.py
@ -1,5 +1,10 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import os.path as osp
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from munch import Munch
|
||||
@ -8,7 +13,7 @@ from softgroup.evaluation import (ScanNetEval, evaluate_offset_mae, evaluate_sem
|
||||
evaluate_semantic_miou)
|
||||
from softgroup.model import SoftGroup
|
||||
from softgroup.util import (collect_results_gpu, get_dist_info, get_root_logger, init_dist,
|
||||
is_main_process, load_checkpoint)
|
||||
is_main_process, load_checkpoint, rle_decode)
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -18,11 +23,57 @@ def get_args():
|
||||
parser.add_argument('config', type=str, help='path to config file')
|
||||
parser.add_argument('checkpoint', type=str, help='path to checkpoint')
|
||||
parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
|
||||
parser.add_argument('--out', type=str, help='directory for output results')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def save_npy(root, name, scan_ids, arrs):
|
||||
root = osp.join(root, name)
|
||||
os.makedirs(root, exist_ok=True)
|
||||
paths = [osp.join(root, f'{i}.npy') for i in scan_ids]
|
||||
pool = mp.Pool()
|
||||
pool.starmap(np.save, zip(paths, arrs))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
def save_single_instance(root, scan_id, insts):
|
||||
f = open(osp.join(root, f'{scan_id}.txt'), 'w')
|
||||
os.makedirs(osp.join(root, 'predicted_masks'), exist_ok=True)
|
||||
for i, inst in enumerate(insts):
|
||||
assert scan_id == inst['scan_id']
|
||||
label_id = inst['label_id']
|
||||
conf = inst['conf']
|
||||
f.write(f'predicted_masks/{scan_id}_{i:03d}.txt {label_id} {conf:.4f}\n')
|
||||
mask_path = osp.join(root, 'predicted_masks', f'{scan_id}_{i:03d}.txt')
|
||||
mask = rle_decode(inst['pred_mask'])
|
||||
np.savetxt(mask_path, mask, fmt='%d')
|
||||
f.close()
|
||||
|
||||
|
||||
def save_pred_instances(root, name, scan_ids, pred_insts):
|
||||
root = osp.join(root, name)
|
||||
os.makedirs(root, exist_ok=True)
|
||||
roots = [root] * len(scan_ids)
|
||||
pool = mp.Pool()
|
||||
pool.starmap(save_single_instance, zip(roots, scan_ids, pred_insts))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
def save_gt_instances(root, name, scan_ids, gt_insts):
|
||||
root = osp.join(root, name)
|
||||
os.makedirs(root, exist_ok=True)
|
||||
paths = [osp.join(root, f'{i}.txt') for i in scan_ids]
|
||||
pool = mp.Pool()
|
||||
map_func = partial(np.savetxt, fmt='%d')
|
||||
pool.starmap(map_func, zip(paths, gt_insts))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
cfg_txt = open(args.config, 'r').read()
|
||||
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
|
||||
@ -39,8 +90,8 @@ if __name__ == '__main__':
|
||||
dataset = build_dataset(cfg.data.test, logger)
|
||||
dataloader = build_dataloader(dataset, training=False, dist=args.dist, **cfg.dataloader.test)
|
||||
results = []
|
||||
all_sem_preds, all_sem_labels, all_offset_preds, all_offset_labels = [], [], [], []
|
||||
all_inst_labels, all_pred_insts, all_gt_insts = [], [], []
|
||||
scan_ids, coords, sem_preds, sem_labels, offset_preds, offset_labels = [], [], [], [], [], []
|
||||
inst_labels, pred_insts, gt_insts = [], [], []
|
||||
_, world_size = get_dist_info()
|
||||
progress_bar = tqdm(total=len(dataloader) * world_size, disable=not is_main_process())
|
||||
with torch.no_grad():
|
||||
@ -53,20 +104,41 @@ if __name__ == '__main__':
|
||||
results = collect_results_gpu(results, len(dataset))
|
||||
if is_main_process():
|
||||
for res in results:
|
||||
all_sem_preds.append(res['semantic_preds'])
|
||||
all_sem_labels.append(res['semantic_labels'])
|
||||
all_offset_preds.append(res['offset_preds'])
|
||||
all_offset_labels.append(res['offset_labels'])
|
||||
all_inst_labels.append(res['instance_labels'])
|
||||
scan_ids.append(res['scan_id'])
|
||||
coords.append(res['coords_float'])
|
||||
sem_preds.append(res['semantic_preds'])
|
||||
sem_labels.append(res['semantic_labels'])
|
||||
offset_preds.append(res['offset_preds'])
|
||||
offset_labels.append(res['offset_labels'])
|
||||
inst_labels.append(res['instance_labels'])
|
||||
if not cfg.model.semantic_only:
|
||||
all_pred_insts.append(res['pred_instances'])
|
||||
all_gt_insts.append(res['gt_instances'])
|
||||
pred_insts.append(res['pred_instances'])
|
||||
gt_insts.append(res['gt_instances'])
|
||||
if not cfg.model.semantic_only:
|
||||
logger.info('Evaluate instance segmentation')
|
||||
scannet_eval = ScanNetEval(dataset.CLASSES)
|
||||
scannet_eval.evaluate(all_pred_insts, all_gt_insts)
|
||||
scannet_eval.evaluate(pred_insts, gt_insts)
|
||||
logger.info('Evaluate semantic segmentation and offset MAE')
|
||||
evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
|
||||
evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
|
||||
evaluate_offset_mae(all_offset_preds, all_offset_labels, all_inst_labels,
|
||||
cfg.model.ignore_label, logger)
|
||||
ignore_label = cfg.model.ignore_label
|
||||
evaluate_semantic_miou(sem_preds, sem_labels, ignore_label, logger)
|
||||
evaluate_semantic_acc(sem_preds, sem_labels, ignore_label, logger)
|
||||
evaluate_offset_mae(offset_preds, offset_labels, inst_labels, ignore_label, logger)
|
||||
|
||||
# save output
|
||||
if not args.out:
|
||||
return
|
||||
logger.info('Save results')
|
||||
save_npy(args.out, 'coords', scan_ids, coords)
|
||||
if cfg.save_cfg.semantic:
|
||||
save_npy(args.out, 'semantic_pred', scan_ids, sem_preds)
|
||||
save_npy(args.out, 'semantic_label', scan_ids, sem_labels)
|
||||
if cfg.save_cfg.offset:
|
||||
save_npy(args.out, 'offset_pred', scan_ids, offset_preds)
|
||||
save_npy(args.out, 'offset_label', scan_ids, offset_labels)
|
||||
if cfg.save_cfg.instance:
|
||||
save_pred_instances(args.out, 'pred_instance', scan_ids, pred_insts)
|
||||
save_gt_instances(args.out, 'gt_instance', scan_ids, gt_insts)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user