mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
import argparse
|
|
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
import yaml
|
|
from munch import Munch
|
|
from softgroup.data import build_dataloader, build_dataset
|
|
from softgroup.evaluation import ScanNetEval
|
|
from softgroup.model import SoftGroup
|
|
from softgroup.util import get_root_logger, load_checkpoint
|
|
from tqdm import tqdm
|
|
|
|
|
|
def get_args():
    """Parse the command-line arguments of this script.

    Two required positional arguments are registered, in this order:
    ``config`` and ``checkpoint``, both parsed as strings.

    Returns:
        argparse.Namespace: Namespace exposing ``config`` and ``checkpoint``.
    """
    cli = argparse.ArgumentParser('SoftGroup')
    positionals = (
        ('config', 'path to config file'),
        ('checkpoint', 'path to checkpoint'),
    )
    for arg_name, description in positionals:
        cli.add_argument(arg_name, type=str, help=description)
    return cli.parse_args()
|
|
|
|
|
|
def evaluate_semantic_segmantation_accuracy(matches):
    """Compute point-wise semantic segmentation accuracy over all entries.

    Args:
        matches (dict): Maps an arbitrary key to a dict that holds the
            tensors ``'seg_gt'`` and ``'seg_pred'``. The tensors of all
            entries are concatenated along dim 0. Positions whose ground
            truth label equals ``-100`` are excluded from the statistic.

    Returns:
        torch.Tensor: 0-dim float tensor, number of correctly predicted
        valid positions divided by the number of valid positions. The
        division is not guarded, so the result is NaN when no position
        is valid.

    Raises:
        AssertionError: If the concatenated ground truth and prediction
            tensors differ in shape.
    """
    seg_gt_all = torch.cat([v['seg_gt'] for v in matches.values()], dim=0)
    seg_pred_all = torch.cat([v['seg_pred'] for v in matches.values()], dim=0)
    # Use the GPU when there is one, but do not crash on CPU-only hosts:
    # the reduction below gives the same result on either device.
    if torch.cuda.is_available():
        seg_gt_all = seg_gt_all.cuda()
        seg_pred_all = seg_pred_all.cuda()
    assert seg_gt_all.shape == seg_pred_all.shape

    # -100 marks ground-truth positions that must be ignored.
    valid = seg_gt_all != -100
    correct = (seg_gt_all[valid] == seg_pred_all[valid]).sum()
    whole = valid.sum()
    seg_accuracy = correct.float() / whole.float()
    return seg_accuracy
|
|
|
|
|
|
def evaluate_semantic_segmantation_miou(matches):
    """Compute the mean IoU of the semantic predictions over all entries.

    Args:
        matches (dict): Maps an arbitrary key to a dict that holds the
            tensors ``'seg_gt'`` and ``'seg_pred'``. The tensors of all
            entries are concatenated along dim 0. Positions whose ground
            truth label equals ``-100`` are dropped from both tensors
            before any IoU is computed.

    Returns:
        torch.Tensor: 0-dim tensor holding the mean of the per-label IoUs,
        taken over the labels that occur in the remaining ground truth.

    Raises:
        AssertionError: If the concatenated ground truth and prediction
            tensors differ in shape.
    """
    seg_gt_all = torch.cat([v['seg_gt'] for v in matches.values()], dim=0)
    seg_pred_all = torch.cat([v['seg_pred'] for v in matches.values()], dim=0)
    # Use the GPU when there is one, but do not crash on CPU-only hosts.
    if torch.cuda.is_available():
        seg_gt_all = seg_gt_all.cuda()
        seg_pred_all = seg_pred_all.cuda()
    # Check shapes before masking: after both tensors are indexed with the
    # same mask their shapes always agree, so a later check cannot fail.
    assert seg_gt_all.shape == seg_pred_all.shape

    # -100 marks ground-truth positions that must be ignored.
    pos_inds = seg_gt_all != -100
    seg_gt_all = seg_gt_all[pos_inds]
    seg_pred_all = seg_pred_all[pos_inds]

    iou_list = []
    # Only labels present in the ground truth contribute; -100 cannot show
    # up here because it was filtered out above.
    for label in seg_gt_all.unique():
        gt_hit = seg_gt_all == label
        pred_hit = seg_pred_all == label
        intersection = (gt_hit & pred_hit).sum()
        union = (gt_hit | pred_hit).sum()
        iou_list.append(intersection.float() / union)
    iou_tensor = torch.tensor(iou_list)
    miou = iou_tensor.mean()
    return miou
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load a config and a checkpoint, run the model over
    # the test dataloader and hand the collected results to ScanNetEval.
    torch.backends.cudnn.enabled = False  # TODO remove this

    # Seed every random number generator used here with the same fixed value.
    test_seed = 567
    random.seed(test_seed)
    np.random.seed(test_seed)
    torch.manual_seed(test_seed)
    torch.cuda.manual_seed_all(test_seed)

    args = get_args()
    # Read the config inside a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(args.config, 'r') as cfg_file:
        cfg_txt = cfg_file.read()
    cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
    logger = get_root_logger()

    model = SoftGroup(**cfg.model)
    logger.info(f'Load state dict from {args.checkpoint}')
    load_checkpoint(args.checkpoint, logger, model)
    model.cuda()

    dataset = build_dataset(cfg.data.test, logger)
    dataloader = build_dataloader(dataset, training=False, **cfg.dataloader.test)
    all_preds, all_gts = [], []
    with torch.no_grad():
        model = model.eval()
        # The batch index was never used, so iterate the loader directly;
        # the explicit total keeps the progress bar length unchanged.
        for batch in tqdm(dataloader, total=len(dataloader)):
            ret = model(batch)
            all_preds.append(ret['det_ins'])
            all_gts.append(ret['gt_ins'])

    scannet_eval = ScanNetEval(dataset.CLASSES)
    scannet_eval.evaluate(all_preds, all_gts)
|