mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
support spconv2
This commit is contained in:
parent
a93a1c58bc
commit
4090322ae0
@ -1,8 +1,8 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import spconv
|
||||
import spconv.pytorch as spconv
|
||||
import torch
|
||||
from spconv.modules import SparseModule
|
||||
from spconv.pytorch.modules import SparseModule
|
||||
from torch import nn
|
||||
|
||||
|
||||
@ -26,6 +26,20 @@ class MLP(nn.Sequential):
|
||||
nn.init.constant_(self[-1].bias, 0)
|
||||
|
||||
|
||||
# current 1x1 conv in spconv2x has a bug. It will be removed after the bug is fixed
|
||||
class Custom1x1Subm3d(spconv.SparseConv3d):
|
||||
|
||||
def forward(self, input):
|
||||
features = torch.mm(input.features, self.weight.view(self.in_channels, self.out_channels))
|
||||
if self.bias is not None:
|
||||
features += self.bias
|
||||
out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape,
|
||||
input.batch_size)
|
||||
out_tensor.indice_dict = input.indice_dict
|
||||
out_tensor.grid = input.grid
|
||||
return out_tensor
|
||||
|
||||
|
||||
class ResidualBlock(SparseModule):
|
||||
|
||||
def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
|
||||
@ -35,7 +49,7 @@ class ResidualBlock(SparseModule):
|
||||
self.i_branch = spconv.SparseSequential(nn.Identity())
|
||||
else:
|
||||
self.i_branch = spconv.SparseSequential(
|
||||
spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False))
|
||||
Custom1x1Subm3d(in_channels, out_channels, kernel_size=1, bias=False))
|
||||
|
||||
self.conv_branch = spconv.SparseSequential(
|
||||
norm_fn(in_channels), nn.ReLU(),
|
||||
@ -58,7 +72,8 @@ class ResidualBlock(SparseModule):
|
||||
identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
|
||||
input.batch_size)
|
||||
output = self.conv_branch(input)
|
||||
output.features += self.i_branch(identity).features
|
||||
out_feats = output.features + self.i_branch(identity).features
|
||||
output = output.replace_feature(out_feats)
|
||||
|
||||
return output
|
||||
|
||||
@ -121,6 +136,7 @@ class UBlock(nn.Module):
|
||||
output_decoder = self.conv(output)
|
||||
output_decoder = self.u(output_decoder)
|
||||
output_decoder = self.deconv(output_decoder)
|
||||
output.features = torch.cat((identity.features, output_decoder.features), dim=1)
|
||||
out_feats = torch.cat((identity.features, output_decoder.features), dim=1)
|
||||
output = output.replace_feature(out_feats)
|
||||
output = self.blocks_tail(output)
|
||||
return output
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import functools
|
||||
|
||||
import spconv
|
||||
import spconv.pytorch as spconv
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
88
train.py
88
train.py
@ -2,12 +2,9 @@ import argparse
|
||||
import datetime
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from munch import Munch
|
||||
@ -21,35 +18,6 @@ from tensorboardX import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def eval_epoch(val_loader, model, model_fn, epoch):
|
||||
logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
|
||||
am_dict = {}
|
||||
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
start_epoch = time.time()
|
||||
for i, batch in enumerate(val_loader):
|
||||
|
||||
# prepare input and forward
|
||||
loss, preds, visual_dict, meter_dict = model_fn(
|
||||
batch, model, epoch, semantic_only=cfg.semantic_only)
|
||||
|
||||
for k, v in meter_dict.items():
|
||||
if k not in am_dict.keys():
|
||||
am_dict[k] = AverageMeter()
|
||||
am_dict[k].update(v[0], v[1])
|
||||
sys.stdout.write('\riter: {}/{} loss: {:.4f}({:.4f})'.format(
|
||||
i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
|
||||
|
||||
logger.info('epoch: {}/{}, val loss: {:.4f}, time: {}s'.format(
|
||||
epoch, cfg.epochs, am_dict['loss'].avg,
|
||||
time.time() - start_epoch))
|
||||
|
||||
for k in am_dict.keys():
|
||||
if k in visual_dict.keys():
|
||||
writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser('SoftGroup')
|
||||
parser.add_argument('config', type=str, help='path to config file')
|
||||
@ -60,14 +28,6 @@ def get_args():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# TODO remove these setup
|
||||
torch.backends.cudnn.enabled = False
|
||||
test_seed = 123
|
||||
random.seed(test_seed)
|
||||
np.random.seed(test_seed)
|
||||
torch.manual_seed(test_seed)
|
||||
torch.cuda.manual_seed_all(test_seed)
|
||||
|
||||
args = get_args()
|
||||
cfg_txt = open(args.config, 'r').read()
|
||||
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
|
||||
@ -144,7 +104,7 @@ if __name__ == '__main__':
|
||||
writer.add_scalar('learning_rate', lr, current_iter)
|
||||
for k, v in meter_dict.items():
|
||||
writer.add_scalar(k, v.val, current_iter)
|
||||
if i % 10 == 0:
|
||||
if is_multiple(i, 10):
|
||||
log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
|
||||
log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
|
||||
f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
|
||||
@ -154,27 +114,27 @@ if __name__ == '__main__':
|
||||
checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
|
||||
|
||||
# validation
|
||||
if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
|
||||
continue
|
||||
all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
|
||||
logger.info('Validation')
|
||||
with torch.no_grad():
|
||||
model = model.eval()
|
||||
for batch in tqdm(val_loader, total=len(val_loader)):
|
||||
ret = model(batch)
|
||||
all_sem_preds.append(ret['semantic_preds'])
|
||||
all_sem_labels.append(ret['semantic_labels'])
|
||||
if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
|
||||
all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
|
||||
logger.info('Validation')
|
||||
with torch.no_grad():
|
||||
model = model.eval()
|
||||
for batch in tqdm(val_loader, total=len(val_loader)):
|
||||
ret = model(batch)
|
||||
all_sem_preds.append(ret['semantic_preds'])
|
||||
all_sem_labels.append(ret['semantic_labels'])
|
||||
if not cfg.model.semantic_only:
|
||||
all_pred_insts.append(ret['pred_instances'])
|
||||
all_gt_insts.append(ret['gt_instances'])
|
||||
if not cfg.model.semantic_only:
|
||||
all_pred_insts.append(ret['pred_instances'])
|
||||
all_gt_insts.append(ret['gt_instances'])
|
||||
if not cfg.model.semantic_only:
|
||||
logger.info('Evaluate instance segmentation')
|
||||
scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
|
||||
scannet_eval.evaluate(all_pred_insts, all_gt_insts)
|
||||
logger.info('Evaluate semantic segmentation')
|
||||
miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
||||
logger)
|
||||
acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
||||
logger)
|
||||
writer.add_scalar('mIoU', miou, epoch)
|
||||
writer.add_scalar('Acc', acc, epoch)
|
||||
logger.info('Evaluate instance segmentation')
|
||||
scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
|
||||
scannet_eval.evaluate(all_pred_insts, all_gt_insts)
|
||||
logger.info('Evaluate semantic segmentation')
|
||||
miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
||||
logger)
|
||||
acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
||||
logger)
|
||||
writer.add_scalar('mIoU', miou, epoch)
|
||||
writer.add_scalar('Acc', acc, epoch)
|
||||
writer.flush()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user