SoftGroup/test.py
2022-04-08 03:19:49 +00:00

95 lines
3.1 KiB
Python

import argparse
import numpy as np
import random
import torch
import yaml
from munch import Munch
from tqdm import tqdm
import util.utils as utils
from evaluation import ScanNetEval
from model.softgroup import SoftGroup
from data.scannetv2 import ScanNetDataset
from torch.utils.data import DataLoader
from util import get_root_logger
from data import build_dataset, build_dataloader
def get_args():
    """Parse the command-line arguments of this script.

    Two positional string arguments are required, in this order:
    ``config`` (path to config file) and ``checkpoint`` (path to checkpoint).

    Returns:
        argparse.Namespace: Namespace with ``config`` and ``checkpoint``
        attributes taken from ``sys.argv``.
    """
    cli = argparse.ArgumentParser('SoftGroup')
    positionals = (
        ('config', 'path to config file'),
        ('checkpoint', 'path to checkpoint'),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, type=str, help=arg_help)
    return cli.parse_args()
def evaluate_semantic_segmantation_accuracy(matches, ignore_label=-100):
    """Compute point-wise semantic segmentation accuracy over all records.

    The ``'seg_gt'`` and ``'seg_pred'`` tensors of every record are
    concatenated along dim 0. Positions whose ground truth equals
    ``ignore_label`` are excluded; accuracy is the fraction of the remaining
    positions where prediction and ground truth agree.

    Args:
        matches: Mapping whose values are dict-like records containing the
            tensors ``'seg_gt'`` and ``'seg_pred'``. The keys are not used.
        ignore_label: Ground-truth value marking positions to exclude from
            the metric. Defaults to ``-100``.

    Returns:
        torch.Tensor: Scalar float tensor ``correct / valid``. It is NaN when
        every position is ignored (0 / 0).

    Raises:
        AssertionError: If the concatenated ground truth and prediction
            tensors have different shapes.
    """
    records = matches.values()
    seg_gt_all = torch.cat([rec['seg_gt'] for rec in records], dim=0)
    seg_pred_all = torch.cat([rec['seg_pred'] for rec in records], dim=0)

    # Run the reduction on the GPU when one exists, but do not crash on
    # CPU-only hosts: the metric itself does not need CUDA.
    if torch.cuda.is_available():
        seg_gt_all = seg_gt_all.cuda()
        seg_pred_all = seg_pred_all.cuda()

    assert seg_gt_all.shape == seg_pred_all.shape, \
        'seg_gt and seg_pred must have the same shape'

    # Build the validity mask once and reuse it for both tensors.
    valid = seg_gt_all != ignore_label
    correct = (seg_gt_all[valid] == seg_pred_all[valid]).sum()
    whole = valid.sum()
    return correct.float() / whole.float()
def evaluate_semantic_segmantation_miou(matches, ignore_label=-100):
    """Compute the mean IoU of semantic predictions over all records.

    The ``'seg_gt'`` and ``'seg_pred'`` tensors of every record are
    concatenated along dim 0 and positions whose ground truth equals
    ``ignore_label`` are dropped. An IoU is then computed for each label
    value that occurs in the remaining ground truth, and the IoUs are
    averaged. Labels that appear only in the predictions get no IoU term of
    their own; they only enlarge the union of the classes they are confused
    with.

    Args:
        matches: Mapping whose values are dict-like records containing the
            tensors ``'seg_gt'`` and ``'seg_pred'``. The keys are not used.
        ignore_label: Ground-truth value marking positions to exclude from
            the metric. Defaults to ``-100``.

    Returns:
        torch.Tensor: Scalar tensor holding the mean of the per-class IoUs.
        It is NaN when no valid position remains (mean of an empty tensor).

    Raises:
        AssertionError: If the concatenated ground truth and prediction
            tensors have different shapes.
    """
    records = matches.values()
    seg_gt_all = torch.cat([rec['seg_gt'] for rec in records], dim=0)
    seg_pred_all = torch.cat([rec['seg_pred'] for rec in records], dim=0)

    # Run on the GPU when one exists, but do not crash on CPU-only hosts.
    if torch.cuda.is_available():
        seg_gt_all = seg_gt_all.cuda()
        seg_pred_all = seg_pred_all.cuda()

    # Check shapes BEFORE masking: after both tensors are indexed with the
    # same mask their shapes are equal by construction, so a later check
    # could never fail.
    assert seg_gt_all.shape == seg_pred_all.shape, \
        'seg_gt and seg_pred must have the same shape'

    pos_inds = seg_gt_all != ignore_label
    seg_gt_all = seg_gt_all[pos_inds]
    seg_pred_all = seg_pred_all[pos_inds]

    iou_list = []
    # Ignored labels were filtered out above, so every value returned by
    # unique() is a real class and the union below is never empty.
    for class_id in seg_gt_all.unique():
        gt_mask = seg_gt_all == class_id
        pred_mask = seg_pred_all == class_id
        intersection = (gt_mask & pred_mask).sum()
        union = (gt_mask | pred_mask).sum()
        iou_list.append(intersection.float() / union)

    # torch.tensor() gathers the per-class scalars into one new tensor.
    return torch.tensor(iou_list).mean()
if __name__ == '__main__':
    # Script entry point: load a SoftGroup checkpoint, run it over the test
    # split described by the config and hand the collected instance
    # predictions and ground truths to ScanNetEval.
    torch.backends.cudnn.enabled = False  # TODO remove this

    # Seed every random number generator touched here with a fixed value.
    test_seed = 567
    random.seed(test_seed)
    np.random.seed(test_seed)
    torch.manual_seed(test_seed)
    torch.cuda.manual_seed_all(test_seed)

    args = get_args()
    # Read the config inside a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(args.config, 'r') as cfg_file:
        cfg_txt = cfg_file.read()
    cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
    logger = get_root_logger()

    model = SoftGroup(**cfg.model)
    logger.info(f'Load state dict from {args.checkpoint}')
    utils.load_checkpoint(args.checkpoint, logger, model)
    model.cuda()

    dataset = build_dataset(cfg.data.test, logger)
    dataloader = build_dataloader(dataset, training=False)

    all_preds, all_gts = [], []
    with torch.no_grad():
        model = model.eval()
        # The batch index is not needed, so iterate the loader directly.
        for batch in tqdm(dataloader, total=len(dataloader)):
            ret = model(batch)
            all_preds.append(ret['det_ins'])
            all_gts.append(ret['gt_ins'])

    scannet_eval = ScanNetEval(dataset.CLASSES)
    scannet_eval.evaluate(all_preds, all_gts)