"""Run a SoftGroup checkpoint over the test dataset and score it with ScanNetEval.

Usage: ``python <this script> CONFIG CHECKPOINT``.
"""
import argparse
import random

import numpy as np
import torch
import yaml
from munch import Munch
from softgroup.data import build_dataloader, build_dataset
from softgroup.evaluation import ScanNetEval
from softgroup.model import SoftGroup
from softgroup.util import get_root_logger, load_checkpoint
from tqdm import tqdm

# Ground-truth label value that is excluded from the semantic metrics below.
IGNORE_LABEL = -100


def get_args():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace: has ``config`` (path to the config file) and
        ``checkpoint`` (path to the checkpoint), both strings.
    """
    parser = argparse.ArgumentParser('SoftGroup')
    parser.add_argument('config', type=str, help='path to config file')
    parser.add_argument('checkpoint', type=str, help='path to checkpoint')
    return parser.parse_args()


def _gather_semantic_labels(matches):
    """Concatenate the 'seg_gt' and 'seg_pred' tensors of every entry in ``matches``.

    Args:
        matches (dict): maps an arbitrary key to a dict holding the tensors
            ``'seg_gt'`` and ``'seg_pred'``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(seg_gt_all, seg_pred_all)``,
        concatenated along dim 0 and moved to CUDA when it is available
        (CPU otherwise).

    Raises:
        AssertionError: if the two concatenated tensors differ in shape.
    """
    entries = list(matches.values())
    seg_gt_all = torch.cat([entry['seg_gt'] for entry in entries], dim=0)
    seg_pred_all = torch.cat([entry['seg_pred'] for entry in entries], dim=0)
    # Use the GPU when present, but do not make the metrics unusable without one.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg_gt_all = seg_gt_all.to(device)
    seg_pred_all = seg_pred_all.to(device)
    assert seg_gt_all.shape == seg_pred_all.shape
    return seg_gt_all, seg_pred_all


def evaluate_semantic_segmantation_accuracy(matches):
    """Compute the fraction of non-ignored points whose prediction equals the label.

    Args:
        matches (dict): maps an arbitrary key to a dict holding the tensors
            ``'seg_gt'`` and ``'seg_pred'``.

    Returns:
        torch.Tensor: scalar float tensor ``correct / total`` over all points
        whose ground-truth label is not ``-100``. The result is NaN when no
        such point exists (0 / 0).
    """
    seg_gt_all, seg_pred_all = _gather_semantic_labels(matches)
    valid = seg_gt_all != IGNORE_LABEL
    correct = (seg_gt_all[valid] == seg_pred_all[valid]).sum()
    whole = valid.sum()
    seg_accuracy = correct.float() / whole.float()
    return seg_accuracy


def evaluate_semantic_segmantation_miou(matches):
    """Compute the mean IoU over the classes that occur in the ground truth.

    Points whose ground-truth label is ``-100`` are dropped first. For every
    remaining distinct ground-truth label the intersection-over-union between
    the ground-truth mask and the prediction mask is computed, and the
    per-class values are averaged.

    Args:
        matches (dict): maps an arbitrary key to a dict holding the tensors
            ``'seg_gt'`` and ``'seg_pred'``.

    Returns:
        torch.Tensor: scalar tensor with the mean of the per-class IoUs. The
        result is NaN when no non-ignored point exists.
    """
    seg_gt_all, seg_pred_all = _gather_semantic_labels(matches)
    valid = seg_gt_all != IGNORE_LABEL
    seg_gt_all = seg_gt_all[valid]
    seg_pred_all = seg_pred_all[valid]

    iou_list = []
    # Class ids come from the filtered ground truth, so IGNORE_LABEL cannot
    # appear here and every union contains at least one point (no 0 division).
    for class_id in seg_gt_all.unique():
        gt_mask = seg_gt_all == class_id
        pred_mask = seg_pred_all == class_id
        intersection = (gt_mask & pred_mask).sum()
        union = (gt_mask | pred_mask).sum()
        iou_list.append(intersection.float() / union)
    iou_tensor = torch.tensor(iou_list)
    miou = iou_tensor.mean()
    return miou


def main():
    """Load config and checkpoint, run the model on the test set, and evaluate."""
    torch.backends.cudnn.enabled = False  # TODO remove this

    # Seed every RNG in use with the same fixed value.
    test_seed = 567
    random.seed(test_seed)
    np.random.seed(test_seed)
    torch.manual_seed(test_seed)
    torch.cuda.manual_seed_all(test_seed)

    args = get_args()
    # Context manager so the config file handle is closed deterministically.
    with open(args.config, 'r') as config_file:
        cfg_txt = config_file.read()
    cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
    logger = get_root_logger()

    model = SoftGroup(**cfg.model)
    logger.info(f'Load state dict from {args.checkpoint}')
    load_checkpoint(args.checkpoint, logger, model)
    model.cuda()

    dataset = build_dataset(cfg.data.test, logger)
    dataloader = build_dataloader(dataset, training=False, **cfg.dataloader.test)

    # Collect the 'det_ins' / 'gt_ins' entries of the model output per batch.
    all_preds, all_gts = [], []
    with torch.no_grad():
        model = model.eval()
        for batch in tqdm(dataloader, total=len(dataloader)):
            ret = model(batch)
            all_preds.append(ret['det_ins'])
            all_gts.append(ret['gt_ins'])

    scannet_eval = ScanNetEval(dataset.CLASSES)
    scannet_eval.evaluate(all_preds, all_gts)


if __name__ == '__main__':
    main()