From 4090322ae0f885bd6a754fed635badc178efb4ab Mon Sep 17 00:00:00 2001 From: Thang Vu Date: Sat, 9 Apr 2022 04:42:43 +0000 Subject: [PATCH] support spconv2 --- softgroup/model/blocks.py | 26 +++++++++-- softgroup/model/softgroup.py | 2 +- train.py | 88 ++++++++++-------------------------- 3 files changed, 46 insertions(+), 70 deletions(-) diff --git a/softgroup/model/blocks.py b/softgroup/model/blocks.py index ef3ed91..3b20646 100644 --- a/softgroup/model/blocks.py +++ b/softgroup/model/blocks.py @@ -1,8 +1,8 @@ from collections import OrderedDict -import spconv +import spconv.pytorch as spconv import torch -from spconv.modules import SparseModule +from spconv.pytorch.modules import SparseModule from torch import nn @@ -26,6 +26,20 @@ class MLP(nn.Sequential): nn.init.constant_(self[-1].bias, 0) +# current 1x1 conv in spconv2x has a bug. It will be removed after the bug is fixed +class Custom1x1Subm3d(spconv.SparseConv3d): + + def forward(self, input): + features = torch.mm(input.features, self.weight.view(self.in_channels, self.out_channels)) + if self.bias is not None: + features += self.bias + out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape, + input.batch_size) + out_tensor.indice_dict = input.indice_dict + out_tensor.grid = input.grid + return out_tensor + + class ResidualBlock(SparseModule): def __init__(self, in_channels, out_channels, norm_fn, indice_key=None): @@ -35,7 +49,7 @@ class ResidualBlock(SparseModule): self.i_branch = spconv.SparseSequential(nn.Identity()) else: self.i_branch = spconv.SparseSequential( - spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False)) + Custom1x1Subm3d(in_channels, out_channels, kernel_size=1, bias=False)) self.conv_branch = spconv.SparseSequential( norm_fn(in_channels), nn.ReLU(), @@ -58,7 +72,8 @@ class ResidualBlock(SparseModule): identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, input.batch_size) output = self.conv_branch(input) - output.features += self.i_branch(identity).features + out_feats = output.features + self.i_branch(identity).features + output = output.replace_feature(out_feats) return output @@ -121,6 +136,7 @@ class UBlock(nn.Module): output_decoder = self.conv(output) output_decoder = self.u(output_decoder) output_decoder = self.deconv(output_decoder) - output.features = torch.cat((identity.features, output_decoder.features), dim=1) + out_feats = torch.cat((identity.features, output_decoder.features), dim=1) + output = output.replace_feature(out_feats) output = self.blocks_tail(output) return output diff --git a/softgroup/model/softgroup.py b/softgroup/model/softgroup.py index 8a41f64..165ba99 100644 --- a/softgroup/model/softgroup.py +++ b/softgroup/model/softgroup.py @@ -1,6 +1,6 @@ import functools -import spconv +import spconv.pytorch as spconv import torch import torch.nn as nn import torch.nn.functional as F diff --git a/train.py b/train.py index 73cb83f..8c892f2 100644 --- a/train.py +++ b/train.py @@ -2,12 +2,9 @@ import argparse import datetime import os import os.path as osp -import random import shutil -import sys import time -import numpy as np import torch import yaml from munch import Munch @@ -21,35 +18,6 @@ from tensorboardX import SummaryWriter from tqdm import tqdm -def eval_epoch(val_loader, model, model_fn, epoch): - logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') - am_dict = {} - - with torch.no_grad(): - model.eval() - start_epoch = time.time() - for i, batch in enumerate(val_loader): - - # prepare input and forward - loss, preds, visual_dict, meter_dict = model_fn( - batch, model, epoch, semantic_only=cfg.semantic_only) - - for k, v in meter_dict.items(): - if k not in am_dict.keys(): - am_dict[k] = AverageMeter() - am_dict[k].update(v[0], v[1]) - sys.stdout.write('\riter: {}/{} loss: {:.4f}({:.4f})'.format( - i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg)) - - logger.info('epoch: {}/{}, val loss: {:.4f}, time: {}s'.format( - epoch, cfg.epochs, am_dict['loss'].avg, - time.time() - start_epoch)) - - for k in am_dict.keys(): - if k in visual_dict.keys(): - writer.add_scalar(k + '_eval', am_dict[k].avg, epoch) - - def get_args(): parser = argparse.ArgumentParser('SoftGroup') parser.add_argument('config', type=str, help='path to config file') @@ -60,14 +28,6 @@ def get_args(): if __name__ == '__main__': - # TODO remove these setup - torch.backends.cudnn.enabled = False - test_seed = 123 - random.seed(test_seed) - np.random.seed(test_seed) - torch.manual_seed(test_seed) - torch.cuda.manual_seed_all(test_seed) - args = get_args() cfg_txt = open(args.config, 'r').read() cfg = Munch.fromDict(yaml.safe_load(cfg_txt)) @@ -144,7 +104,7 @@ if __name__ == '__main__': writer.add_scalar('learning_rate', lr, current_iter) for k, v in meter_dict.items(): writer.add_scalar(k, v.val, current_iter) - if i % 10 == 0: + if is_multiple(i, 10): log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] ' log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\ f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}' @@ -154,27 +114,27 @@ if __name__ == '__main__': checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq) # validation - if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)): - continue - all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], [] - logger.info('Validation') - with torch.no_grad(): - model = model.eval() - for batch in tqdm(val_loader, total=len(val_loader)): - ret = model(batch) - all_sem_preds.append(ret['semantic_preds']) - all_sem_labels.append(ret['semantic_labels']) + if is_multiple(epoch, cfg.save_freq) or is_power2(epoch): + all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], [] + logger.info('Validation') + with torch.no_grad(): + model = model.eval() + for batch in tqdm(val_loader, total=len(val_loader)): + ret = model(batch) + all_sem_preds.append(ret['semantic_preds']) + all_sem_labels.append(ret['semantic_labels']) + if not cfg.model.semantic_only: + all_pred_insts.append(ret['pred_instances']) + all_gt_insts.append(ret['gt_instances']) if not cfg.model.semantic_only: - all_pred_insts.append(ret['pred_instances']) - all_gt_insts.append(ret['gt_instances']) - if not cfg.model.semantic_only: - logger.info('Evaluate instance segmentation') - scannet_eval = ScanNetEval(val_loader.dataset.CLASSES) - scannet_eval.evaluate(all_pred_insts, all_gt_insts) - logger.info('Evaluate semantic segmentation') - miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, - logger) - acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, - logger) - writer.add_scalar('mIoU', miou, epoch) - writer.add_scalar('Acc', acc, epoch) + logger.info('Evaluate instance segmentation') + scannet_eval = ScanNetEval(val_loader.dataset.CLASSES) + scannet_eval.evaluate(all_pred_insts, all_gt_insts) + logger.info('Evaluate semantic segmentation') + miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label, + logger) + acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label, + logger) + writer.add_scalar('mIoU', miou, epoch) + writer.add_scalar('Acc', acc, epoch) + writer.flush()