support spconv2

This commit is contained in:
Thang Vu 2022-04-09 04:42:43 +00:00
parent a93a1c58bc
commit 4090322ae0
3 changed files with 46 additions and 70 deletions

View File

@ -1,8 +1,8 @@
from collections import OrderedDict from collections import OrderedDict
import spconv import spconv.pytorch as spconv
import torch import torch
from spconv.modules import SparseModule from spconv.pytorch.modules import SparseModule
from torch import nn from torch import nn
@ -26,6 +26,20 @@ class MLP(nn.Sequential):
nn.init.constant_(self[-1].bias, 0) nn.init.constant_(self[-1].bias, 0)
# current 1x1 conv in spconv2x has a bug. It will be removed after the bug is fixed
class Custom1x1Subm3d(spconv.SparseConv3d):
def forward(self, input):
features = torch.mm(input.features, self.weight.view(self.in_channels, self.out_channels))
if self.bias is not None:
features += self.bias
out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape,
input.batch_size)
out_tensor.indice_dict = input.indice_dict
out_tensor.grid = input.grid
return out_tensor
class ResidualBlock(SparseModule): class ResidualBlock(SparseModule):
def __init__(self, in_channels, out_channels, norm_fn, indice_key=None): def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
@ -35,7 +49,7 @@ class ResidualBlock(SparseModule):
self.i_branch = spconv.SparseSequential(nn.Identity()) self.i_branch = spconv.SparseSequential(nn.Identity())
else: else:
self.i_branch = spconv.SparseSequential( self.i_branch = spconv.SparseSequential(
spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False)) Custom1x1Subm3d(in_channels, out_channels, kernel_size=1, bias=False))
self.conv_branch = spconv.SparseSequential( self.conv_branch = spconv.SparseSequential(
norm_fn(in_channels), nn.ReLU(), norm_fn(in_channels), nn.ReLU(),
@ -58,7 +72,8 @@ class ResidualBlock(SparseModule):
identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
input.batch_size) input.batch_size)
output = self.conv_branch(input) output = self.conv_branch(input)
output.features += self.i_branch(identity).features out_feats = output.features + self.i_branch(identity).features
output = output.replace_feature(out_feats)
return output return output
@ -121,6 +136,7 @@ class UBlock(nn.Module):
output_decoder = self.conv(output) output_decoder = self.conv(output)
output_decoder = self.u(output_decoder) output_decoder = self.u(output_decoder)
output_decoder = self.deconv(output_decoder) output_decoder = self.deconv(output_decoder)
output.features = torch.cat((identity.features, output_decoder.features), dim=1) out_feats = torch.cat((identity.features, output_decoder.features), dim=1)
output = output.replace_feature(out_feats)
output = self.blocks_tail(output) output = self.blocks_tail(output)
return output return output

View File

@ -1,6 +1,6 @@
import functools import functools
import spconv import spconv.pytorch as spconv
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -2,12 +2,9 @@ import argparse
import datetime import datetime
import os import os
import os.path as osp import os.path as osp
import random
import shutil import shutil
import sys
import time import time
import numpy as np
import torch import torch
import yaml import yaml
from munch import Munch from munch import Munch
@ -21,35 +18,6 @@ from tensorboardX import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
def eval_epoch(val_loader, model, model_fn, epoch):
logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
am_dict = {}
with torch.no_grad():
model.eval()
start_epoch = time.time()
for i, batch in enumerate(val_loader):
# prepare input and forward
loss, preds, visual_dict, meter_dict = model_fn(
batch, model, epoch, semantic_only=cfg.semantic_only)
for k, v in meter_dict.items():
if k not in am_dict.keys():
am_dict[k] = AverageMeter()
am_dict[k].update(v[0], v[1])
sys.stdout.write('\riter: {}/{} loss: {:.4f}({:.4f})'.format(
i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
logger.info('epoch: {}/{}, val loss: {:.4f}, time: {}s'.format(
epoch, cfg.epochs, am_dict['loss'].avg,
time.time() - start_epoch))
for k in am_dict.keys():
if k in visual_dict.keys():
writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
def get_args(): def get_args():
parser = argparse.ArgumentParser('SoftGroup') parser = argparse.ArgumentParser('SoftGroup')
parser.add_argument('config', type=str, help='path to config file') parser.add_argument('config', type=str, help='path to config file')
@ -60,14 +28,6 @@ def get_args():
if __name__ == '__main__': if __name__ == '__main__':
# TODO remove these setup
torch.backends.cudnn.enabled = False
test_seed = 123
random.seed(test_seed)
np.random.seed(test_seed)
torch.manual_seed(test_seed)
torch.cuda.manual_seed_all(test_seed)
args = get_args() args = get_args()
cfg_txt = open(args.config, 'r').read() cfg_txt = open(args.config, 'r').read()
cfg = Munch.fromDict(yaml.safe_load(cfg_txt)) cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
@ -144,7 +104,7 @@ if __name__ == '__main__':
writer.add_scalar('learning_rate', lr, current_iter) writer.add_scalar('learning_rate', lr, current_iter)
for k, v in meter_dict.items(): for k, v in meter_dict.items():
writer.add_scalar(k, v.val, current_iter) writer.add_scalar(k, v.val, current_iter)
if i % 10 == 0: if is_multiple(i, 10):
log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] ' log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\ log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}' f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
@ -154,27 +114,27 @@ if __name__ == '__main__':
checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq) checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
# validation # validation
if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)): if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
continue all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], [] logger.info('Validation')
logger.info('Validation') with torch.no_grad():
with torch.no_grad(): model = model.eval()
model = model.eval() for batch in tqdm(val_loader, total=len(val_loader)):
for batch in tqdm(val_loader, total=len(val_loader)): ret = model(batch)
ret = model(batch) all_sem_preds.append(ret['semantic_preds'])
all_sem_preds.append(ret['semantic_preds']) all_sem_labels.append(ret['semantic_labels'])
all_sem_labels.append(ret['semantic_labels']) if not cfg.model.semantic_only:
all_pred_insts.append(ret['pred_instances'])
all_gt_insts.append(ret['gt_instances'])
if not cfg.model.semantic_only: if not cfg.model.semantic_only:
all_pred_insts.append(ret['pred_instances']) logger.info('Evaluate instance segmentation')
all_gt_insts.append(ret['gt_instances']) scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
if not cfg.model.semantic_only: scannet_eval.evaluate(all_pred_insts, all_gt_insts)
logger.info('Evaluate instance segmentation') logger.info('Evaluate semantic segmentation')
scannet_eval = ScanNetEval(val_loader.dataset.CLASSES) miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
scannet_eval.evaluate(all_pred_insts, all_gt_insts) logger)
logger.info('Evaluate semantic segmentation') acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, logger)
logger) writer.add_scalar('mIoU', miou, epoch)
acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, writer.add_scalar('Acc', acc, epoch)
logger) writer.flush()
writer.add_scalar('mIoU', miou, epoch)
writer.add_scalar('Acc', acc, epoch)