support multiprocess test

This commit is contained in:
Thang Vu 2022-04-13 14:49:34 +00:00
parent 11130f850d
commit f0802b75bb

View File

@ -1,10 +1,10 @@
# Adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py # noqa E501 # Adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py # noqa E501
# Modified by Thang Vu # Modified by Thang Vu
import multiprocessing as mp
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from tqdm import tqdm
from ..util import rle_decode from ..util import rle_decode
from .instance_eval_util import get_instances from .instance_eval_util import get_instances
@ -381,16 +381,17 @@ class ScanNetEval(object):
for each point: for each point:
gt_id = class_id * 1000 + instance_id gt_id = class_id * 1000 + instance_id
""" """
print('evaluating', len(pred_list), 'scans...') pool = mp.Pool()
results = pool.starmap(self.assign_instances_for_scan, zip(pred_list, gt_list))
pool.close()
pool.join()
matches = {} matches = {}
for i, (preds, gts) in enumerate(tqdm(zip(pred_list, gt_list), total=len(pred_list))): for i, (gt2pred, pred2gt) in enumerate(results):
gt2pred, pred2gt = self.assign_instances_for_scan(preds, gts)
# assign gt to predictions
matches_key = f'gt_{i}' matches_key = f'gt_{i}'
matches[matches_key] = {} matches[matches_key] = {}
matches[matches_key]['gt'] = gt2pred matches[matches_key]['gt'] = gt2pred
matches[matches_key]['pred'] = pred2gt matches[matches_key]['pred'] = pred2gt
print()
ap_scores, rc_scores = self.evaluate_matches(matches) ap_scores, rc_scores = self.evaluate_matches(matches)
avgs = self.compute_averages(ap_scores, rc_scores) avgs = self.compute_averages(ap_scores, rc_scores)