support multiprocess test

This commit is contained in:
Thang Vu 2022-04-13 14:49:34 +00:00
parent 11130f850d
commit f0802b75bb

View File

@ -1,10 +1,10 @@
# Adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py # noqa E501
# Modified by Thang Vu
import multiprocessing as mp
from copy import deepcopy
import numpy as np
from tqdm import tqdm
from ..util import rle_decode
from .instance_eval_util import get_instances
@ -381,16 +381,17 @@ class ScanNetEval(object):
for each point:
gt_id = class_id * 1000 + instance_id
"""
print('evaluating', len(pred_list), 'scans...')
pool = mp.Pool()
results = pool.starmap(self.assign_instances_for_scan, zip(pred_list, gt_list))
pool.close()
pool.join()
matches = {}
for i, (preds, gts) in enumerate(tqdm(zip(pred_list, gt_list), total=len(pred_list))):
gt2pred, pred2gt = self.assign_instances_for_scan(preds, gts)
# assign gt to predictions
for i, (gt2pred, pred2gt) in enumerate(results):
matches_key = f'gt_{i}'
matches[matches_key] = {}
matches[matches_key]['gt'] = gt2pred
matches[matches_key]['pred'] = pred2gt
print()
ap_scores, rc_scores = self.evaluate_matches(matches)
avgs = self.compute_averages(ap_scores, rc_scores)