commit 4090322ae0
parent a93a1c58bc
Author: Thang Vu
Date:   2022-04-09 04:42:43 +00:00

    support spconv2

3 changed files with 46 additions and 70 deletions

@@ -1,8 +1,8 @@
 from collections import OrderedDict
 
-import spconv
+import spconv.pytorch as spconv
 import torch
-from spconv.modules import SparseModule
+from spconv.pytorch.modules import SparseModule
 from torch import nn
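
Note on the import change: spconv 2.x moved its PyTorch-facing classes from the top-level spconv package into spconv.pytorch, so aliasing the new module back to the old name keeps every downstream spconv.* call site unchanged. A minimal before/after sketch (channel sizes and indice_key are illustrative, not from this commit):

    # spconv 1.x
    # import spconv
    # conv = spconv.SubMConv3d(16, 32, kernel_size=3, indice_key='subm1')

    # spconv 2.x: identical call sites, only the import changes
    import spconv.pytorch as spconv
    conv = spconv.SubMConv3d(16, 32, kernel_size=3, indice_key='subm1')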
@@ -26,6 +26,20 @@ class MLP(nn.Sequential):
         nn.init.constant_(self[-1].bias, 0)
 
 
+# current 1x1 conv in spconv2x has a bug. It will be removed after the bug is fixed
+class Custom1x1Subm3d(spconv.SparseConv3d):
+
+    def forward(self, input):
+        features = torch.mm(input.features, self.weight.view(self.in_channels, self.out_channels))
+        if self.bias is not None:
+            features += self.bias
+        out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape,
+                                             input.batch_size)
+        out_tensor.indice_dict = input.indice_dict
+        out_tensor.grid = input.grid
+        return out_tensor
+
+
 class ResidualBlock(SparseModule):
 
     def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
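
Why a plain matrix multiply is a valid stand-in here: a kernel-size-1 submanifold convolution never aggregates neighboring voxels, so each output feature depends only on the feature at the same voxel, and the whole op reduces to one shared linear map applied row-wise to the (N, C) feature matrix, with indices, spatial_shape, and batch_size passed through untouched. A minimal sketch of the equivalent computation with a plain linear layer (shapes are made up for illustration):

    import torch
    import torch.nn as nn

    feats = torch.randn(1000, 32)           # (N, C_in) features of N active voxels
    linear = nn.Linear(32, 64, bias=False)  # behaves like a 1x1 submanifold conv
    out = linear(feats)                     # (N, C_out); voxel coordinates unchanged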
@@ -35,7 +49,7 @@ class ResidualBlock(SparseModule):
             self.i_branch = spconv.SparseSequential(nn.Identity())
         else:
             self.i_branch = spconv.SparseSequential(
-                spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False))
+                Custom1x1Subm3d(in_channels, out_channels, kernel_size=1, bias=False))
 
         self.conv_branch = spconv.SparseSequential(
             norm_fn(in_channels), nn.ReLU(),
@@ -58,7 +72,8 @@ class ResidualBlock(SparseModule):
 
         identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
                                            input.batch_size)
         output = self.conv_branch(input)
-        output.features += self.i_branch(identity).features
+        out_feats = output.features + self.i_branch(identity).features
+        output = output.replace_feature(out_feats)
 
         return output
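
This hunk shows the core spconv 2.x migration pattern: SparseConvTensor.features is no longer assigned in place; instead, replace_feature(new_feats) returns a tensor that keeps the indices and spatial metadata but carries the new feature matrix. The residual add here and the skip-connection concat in the UBlock hunk below both follow it. A minimal sketch with made-up shapes:

    import torch
    import spconv.pytorch as spconv

    feats = torch.randn(100, 16)
    coords = torch.cat([torch.zeros(100, 1), torch.randint(0, 50, (100, 3))], dim=1).int()
    x = spconv.SparseConvTensor(feats, coords, spatial_shape=[50, 50, 50], batch_size=1)

    # feature-only op: compute new features, then rewrap instead of assigning
    new_feats = torch.relu(x.features)
    x = x.replace_feature(new_feats)  # indices / spatial_shape / batch_size preserved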
@@ -121,6 +136,7 @@ class UBlock(nn.Module):
             output_decoder = self.conv(output)
             output_decoder = self.u(output_decoder)
             output_decoder = self.deconv(output_decoder)
-            output.features = torch.cat((identity.features, output_decoder.features), dim=1)
+            out_feats = torch.cat((identity.features, output_decoder.features), dim=1)
+            output = output.replace_feature(out_feats)
             output = self.blocks_tail(output)
         return output

@@ -1,6 +1,6 @@
 import functools
 
-import spconv
+import spconv.pytorch as spconv
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -2,12 +2,9 @@ import argparse
 import datetime
 import os
 import os.path as osp
-import random
 import shutil
-import sys
 import time
 
-import numpy as np
 import torch
 import yaml
 from munch import Munch
@@ -21,35 +18,6 @@ from tensorboardX import SummaryWriter
 from tqdm import tqdm
 
 
-def eval_epoch(val_loader, model, model_fn, epoch):
-    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
-    am_dict = {}
-
-    with torch.no_grad():
-        model.eval()
-        start_epoch = time.time()
-        for i, batch in enumerate(val_loader):
-
-            # prepare input and forward
-            loss, preds, visual_dict, meter_dict = model_fn(
-                batch, model, epoch, semantic_only=cfg.semantic_only)
-
-            for k, v in meter_dict.items():
-                if k not in am_dict.keys():
-                    am_dict[k] = AverageMeter()
-                am_dict[k].update(v[0], v[1])
-            sys.stdout.write('\riter: {}/{} loss: {:.4f}({:.4f})'.format(
-                i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
-
-        logger.info('epoch: {}/{}, val loss: {:.4f}, time: {}s'.format(
-            epoch, cfg.epochs, am_dict['loss'].avg,
-            time.time() - start_epoch))
-
-        for k in am_dict.keys():
-            if k in visual_dict.keys():
-                writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
-
-
 def get_args():
     parser = argparse.ArgumentParser('SoftGroup')
     parser.add_argument('config', type=str, help='path to config file')
@@ -60,14 +28,6 @@ def get_args():
 
 if __name__ == '__main__':
     # TODO remove these setup
     torch.backends.cudnn.enabled = False
-    test_seed = 123
-    random.seed(test_seed)
-    np.random.seed(test_seed)
-    torch.manual_seed(test_seed)
-    torch.cuda.manual_seed_all(test_seed)
-
     args = get_args()
     cfg_txt = open(args.config, 'r').read()
     cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
@@ -144,7 +104,7 @@ if __name__ == '__main__':
                 writer.add_scalar('learning_rate', lr, current_iter)
                 for k, v in meter_dict.items():
                     writer.add_scalar(k, v.val, current_iter)
-            if i % 10 == 0:
+            if is_multiple(i, 10):
                 log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
                 log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
                     f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
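
is_multiple is a small readability helper from the project's utils (also used in the validation guard below); presumably something along the lines of:

    def is_multiple(num, multiple):
        # True every `multiple` steps, but not at step 0
        return num != 0 and num % multiple == 0

If so, it differs from the old `i % 10 == 0` check in one respect: iteration 0 is no longer logged.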
@@ -154,27 +114,27 @@ if __name__ == '__main__':
         checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
 
         # validation
-        if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
-            continue
-        all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
-        logger.info('Validation')
-        with torch.no_grad():
-            model = model.eval()
-            for batch in tqdm(val_loader, total=len(val_loader)):
-                ret = model(batch)
-                all_sem_preds.append(ret['semantic_preds'])
-                all_sem_labels.append(ret['semantic_labels'])
-                if not cfg.model.semantic_only:
-                    all_pred_insts.append(ret['pred_instances'])
-                    all_gt_insts.append(ret['gt_instances'])
-        if not cfg.model.semantic_only:
-            logger.info('Evaluate instance segmentation')
-            scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
-            scannet_eval.evaluate(all_pred_insts, all_gt_insts)
-        logger.info('Evaluate semantic segmentation')
-        miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                      logger)
-        acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
-                                    logger)
-        writer.add_scalar('mIoU', miou, epoch)
-        writer.add_scalar('Acc', acc, epoch)
+        if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
+            all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
+            logger.info('Validation')
+            with torch.no_grad():
+                model = model.eval()
+                for batch in tqdm(val_loader, total=len(val_loader)):
+                    ret = model(batch)
+                    all_sem_preds.append(ret['semantic_preds'])
+                    all_sem_labels.append(ret['semantic_labels'])
+                    if not cfg.model.semantic_only:
+                        all_pred_insts.append(ret['pred_instances'])
+                        all_gt_insts.append(ret['gt_instances'])
+            if not cfg.model.semantic_only:
+                logger.info('Evaluate instance segmentation')
+                scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
+                scannet_eval.evaluate(all_pred_insts, all_gt_insts)
+            logger.info('Evaluate semantic segmentation')
+            miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
+                                          logger)
+            acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
+                                        logger)
+            writer.add_scalar('mIoU', miou, epoch)
+            writer.add_scalar('Acc', acc, epoch)
         writer.flush()
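
The inverted guard is more than style: assuming writer.flush() sits inside the epoch loop as reconstructed above, the old `continue` skipped it on every epoch that did not validate, whereas nesting the validation under a positive condition lets the loop body run to the end each epoch. Schematically:

    def is_multiple(num, multiple):
        return num != 0 and num % multiple == 0

    for epoch in range(1, 13):
        # training for this epoch would run here
        if is_multiple(epoch, 4):           # validate only on checkpoint epochs
            print(f'epoch {epoch}: validate')
        print(f'epoch {epoch}: flush')      # with the old `continue`, skipped on most epochs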