mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
support spconv2
This commit is contained in:
parent
a93a1c58bc
commit
4090322ae0
@ -1,8 +1,8 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import spconv
|
import spconv.pytorch as spconv
|
||||||
import torch
|
import torch
|
||||||
from spconv.modules import SparseModule
|
from spconv.pytorch.modules import SparseModule
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +26,20 @@ class MLP(nn.Sequential):
|
|||||||
nn.init.constant_(self[-1].bias, 0)
|
nn.init.constant_(self[-1].bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# current 1x1 conv in spconv2x has a bug. It will be removed after the bug is fixed
|
||||||
|
class Custom1x1Subm3d(spconv.SparseConv3d):
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
features = torch.mm(input.features, self.weight.view(self.in_channels, self.out_channels))
|
||||||
|
if self.bias is not None:
|
||||||
|
features += self.bias
|
||||||
|
out_tensor = spconv.SparseConvTensor(features, input.indices, input.spatial_shape,
|
||||||
|
input.batch_size)
|
||||||
|
out_tensor.indice_dict = input.indice_dict
|
||||||
|
out_tensor.grid = input.grid
|
||||||
|
return out_tensor
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(SparseModule):
|
class ResidualBlock(SparseModule):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
|
def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
|
||||||
@ -35,7 +49,7 @@ class ResidualBlock(SparseModule):
|
|||||||
self.i_branch = spconv.SparseSequential(nn.Identity())
|
self.i_branch = spconv.SparseSequential(nn.Identity())
|
||||||
else:
|
else:
|
||||||
self.i_branch = spconv.SparseSequential(
|
self.i_branch = spconv.SparseSequential(
|
||||||
spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False))
|
Custom1x1Subm3d(in_channels, out_channels, kernel_size=1, bias=False))
|
||||||
|
|
||||||
self.conv_branch = spconv.SparseSequential(
|
self.conv_branch = spconv.SparseSequential(
|
||||||
norm_fn(in_channels), nn.ReLU(),
|
norm_fn(in_channels), nn.ReLU(),
|
||||||
@ -58,7 +72,8 @@ class ResidualBlock(SparseModule):
|
|||||||
identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
|
identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape,
|
||||||
input.batch_size)
|
input.batch_size)
|
||||||
output = self.conv_branch(input)
|
output = self.conv_branch(input)
|
||||||
output.features += self.i_branch(identity).features
|
out_feats = output.features + self.i_branch(identity).features
|
||||||
|
output = output.replace_feature(out_feats)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -121,6 +136,7 @@ class UBlock(nn.Module):
|
|||||||
output_decoder = self.conv(output)
|
output_decoder = self.conv(output)
|
||||||
output_decoder = self.u(output_decoder)
|
output_decoder = self.u(output_decoder)
|
||||||
output_decoder = self.deconv(output_decoder)
|
output_decoder = self.deconv(output_decoder)
|
||||||
output.features = torch.cat((identity.features, output_decoder.features), dim=1)
|
out_feats = torch.cat((identity.features, output_decoder.features), dim=1)
|
||||||
|
output = output.replace_feature(out_feats)
|
||||||
output = self.blocks_tail(output)
|
output = self.blocks_tail(output)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import functools
|
import functools
|
||||||
|
|
||||||
import spconv
|
import spconv.pytorch as spconv
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|||||||
88
train.py
88
train.py
@ -2,12 +2,9 @@ import argparse
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import random
|
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from munch import Munch
|
from munch import Munch
|
||||||
@ -21,35 +18,6 @@ from tensorboardX import SummaryWriter
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def eval_epoch(val_loader, model, model_fn, epoch):
|
|
||||||
logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
|
|
||||||
am_dict = {}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
model.eval()
|
|
||||||
start_epoch = time.time()
|
|
||||||
for i, batch in enumerate(val_loader):
|
|
||||||
|
|
||||||
# prepare input and forward
|
|
||||||
loss, preds, visual_dict, meter_dict = model_fn(
|
|
||||||
batch, model, epoch, semantic_only=cfg.semantic_only)
|
|
||||||
|
|
||||||
for k, v in meter_dict.items():
|
|
||||||
if k not in am_dict.keys():
|
|
||||||
am_dict[k] = AverageMeter()
|
|
||||||
am_dict[k].update(v[0], v[1])
|
|
||||||
sys.stdout.write('\riter: {}/{} loss: {:.4f}({:.4f})'.format(
|
|
||||||
i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
|
|
||||||
|
|
||||||
logger.info('epoch: {}/{}, val loss: {:.4f}, time: {}s'.format(
|
|
||||||
epoch, cfg.epochs, am_dict['loss'].avg,
|
|
||||||
time.time() - start_epoch))
|
|
||||||
|
|
||||||
for k in am_dict.keys():
|
|
||||||
if k in visual_dict.keys():
|
|
||||||
writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser('SoftGroup')
|
parser = argparse.ArgumentParser('SoftGroup')
|
||||||
parser.add_argument('config', type=str, help='path to config file')
|
parser.add_argument('config', type=str, help='path to config file')
|
||||||
@ -60,14 +28,6 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# TODO remove these setup
|
|
||||||
torch.backends.cudnn.enabled = False
|
|
||||||
test_seed = 123
|
|
||||||
random.seed(test_seed)
|
|
||||||
np.random.seed(test_seed)
|
|
||||||
torch.manual_seed(test_seed)
|
|
||||||
torch.cuda.manual_seed_all(test_seed)
|
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
cfg_txt = open(args.config, 'r').read()
|
cfg_txt = open(args.config, 'r').read()
|
||||||
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
|
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
|
||||||
@ -144,7 +104,7 @@ if __name__ == '__main__':
|
|||||||
writer.add_scalar('learning_rate', lr, current_iter)
|
writer.add_scalar('learning_rate', lr, current_iter)
|
||||||
for k, v in meter_dict.items():
|
for k, v in meter_dict.items():
|
||||||
writer.add_scalar(k, v.val, current_iter)
|
writer.add_scalar(k, v.val, current_iter)
|
||||||
if i % 10 == 0:
|
if is_multiple(i, 10):
|
||||||
log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
|
log_str = f'Epoch [{epoch}/{cfg.epochs}][{i}/{len(train_loader)}] '
|
||||||
log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
|
log_str += f'lr: {lr:.2g}, eta: {remain_time}, mem: {get_max_memory()}, '\
|
||||||
f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
|
f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
|
||||||
@ -154,27 +114,27 @@ if __name__ == '__main__':
|
|||||||
checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
|
checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
|
||||||
|
|
||||||
# validation
|
# validation
|
||||||
if not (is_multiple(epoch, cfg.save_freq) or is_power2(epoch)):
|
if is_multiple(epoch, cfg.save_freq) or is_power2(epoch):
|
||||||
continue
|
all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
|
||||||
all_sem_preds, all_sem_labels, all_pred_insts, all_gt_insts = [], [], [], []
|
logger.info('Validation')
|
||||||
logger.info('Validation')
|
with torch.no_grad():
|
||||||
with torch.no_grad():
|
model = model.eval()
|
||||||
model = model.eval()
|
for batch in tqdm(val_loader, total=len(val_loader)):
|
||||||
for batch in tqdm(val_loader, total=len(val_loader)):
|
ret = model(batch)
|
||||||
ret = model(batch)
|
all_sem_preds.append(ret['semantic_preds'])
|
||||||
all_sem_preds.append(ret['semantic_preds'])
|
all_sem_labels.append(ret['semantic_labels'])
|
||||||
all_sem_labels.append(ret['semantic_labels'])
|
if not cfg.model.semantic_only:
|
||||||
|
all_pred_insts.append(ret['pred_instances'])
|
||||||
|
all_gt_insts.append(ret['gt_instances'])
|
||||||
if not cfg.model.semantic_only:
|
if not cfg.model.semantic_only:
|
||||||
all_pred_insts.append(ret['pred_instances'])
|
logger.info('Evaluate instance segmentation')
|
||||||
all_gt_insts.append(ret['gt_instances'])
|
scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
|
||||||
if not cfg.model.semantic_only:
|
scannet_eval.evaluate(all_pred_insts, all_gt_insts)
|
||||||
logger.info('Evaluate instance segmentation')
|
logger.info('Evaluate semantic segmentation')
|
||||||
scannet_eval = ScanNetEval(val_loader.dataset.CLASSES)
|
miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
||||||
scannet_eval.evaluate(all_pred_insts, all_gt_insts)
|
logger)
|
||||||
logger.info('Evaluate semantic segmentation')
|
acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
||||||
miou = evaluate_semantic_miou(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
logger)
|
||||||
logger)
|
writer.add_scalar('mIoU', miou, epoch)
|
||||||
acc = evaluate_semantic_acc(all_sem_preds, all_sem_labels, cfg.model.ignore_label,
|
writer.add_scalar('Acc', acc, epoch)
|
||||||
logger)
|
writer.flush()
|
||||||
writer.add_scalar('mIoU', miou, epoch)
|
|
||||||
writer.add_scalar('Acc', acc, epoch)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user