mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
143 lines
5.0 KiB
Python
143 lines
5.0 KiB
Python
from collections import OrderedDict
|
|
|
|
import spconv.pytorch as spconv
|
|
import torch
|
|
from spconv.pytorch.modules import SparseModule
|
|
from torch import nn
|
|
|
|
|
|
class MLP(nn.Sequential):
    """Multi-layer perceptron built as an ``nn.Sequential``.

    Stacks ``num_layers - 1`` hidden stages of ``Linear -> norm_fn -> ReLU``
    at constant width ``in_channels``, followed by one final
    ``Linear(in_channels, out_channels)`` projection.

    Args:
        in_channels (int): input (and hidden) feature dimension.
        out_channels (int): output feature dimension.
        norm_fn (callable): normalization-layer factory taking a channel
            count, e.g. ``nn.BatchNorm1d``.
        num_layers (int): total number of Linear layers (>= 1); with
            ``num_layers=1`` the MLP is a single Linear projection.
    """

    def __init__(self, in_channels, out_channels, norm_fn, num_layers=2):
        modules = []
        for _ in range(num_layers - 1):
            # Hidden Linears drop the bias because the following norm layer
            # supplies its own affine shift.
            modules.extend([
                nn.Linear(in_channels, in_channels, bias=False),
                norm_fn(in_channels),
                nn.ReLU(),
            ])
        modules.append(nn.Linear(in_channels, out_channels))
        # Fix: the original wrote `return super().__init__(*modules)`;
        # __init__ must return None, so the `return` was misleading.
        super().__init__(*modules)

    def init_weights(self):
        """Xavier-init every Linear weight, then re-init the final layer
        with N(0, 0.01) weights and zero bias (prediction-head style)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
        # Overrides the Xavier init above for the last (output) layer only.
        nn.init.normal_(self[-1].weight, 0, 0.01)
        nn.init.constant_(self[-1].bias, 0)
|
|
|
|
|
|
# current 1x1 conv in spconv2x has a bug. It will be removed after the bug is fixed
class Custom1x1Subm3d(spconv.SparseConv3d):
    """Workaround 1x1 submanifold convolution.

    A 1x1 conv on a sparse tensor never mixes voxels, so it reduces to a
    dense matrix multiply on the feature matrix; the sparse indices and
    spatial shape pass through unchanged.
    """

    def forward(self, input):
        # (N, C_in) @ (C_in, C_out) — the conv kernel flattened to a matrix.
        kernel = self.weight.view(self.in_channels, self.out_channels)
        new_feats = torch.mm(input.features, kernel)
        if self.bias is not None:
            new_feats = new_feats + self.bias
        output = spconv.SparseConvTensor(new_feats, input.indices,
                                         input.spatial_shape, input.batch_size)
        # Carry over cached index bookkeeping so later layers can reuse it.
        output.indice_dict = input.indice_dict
        output.grid = input.grid
        return output
|
|
|
|
|
|
class ResidualBlock(SparseModule):
    """Pre-activation sparse residual block.

    Main path: ``norm -> ReLU -> SubMConv3d(3x3x3)`` applied twice; skip
    path: identity when the channel counts match, otherwise a 1x1
    projection via ``Custom1x1Subm3d``. Outputs the elementwise sum of the
    two paths' features on the same voxel set.
    """

    def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
        super().__init__()

        # Skip branch: identity if widths agree, else a bias-free 1x1 conv.
        if in_channels == out_channels:
            self.i_branch = spconv.SparseSequential(nn.Identity())
        else:
            self.i_branch = spconv.SparseSequential(
                Custom1x1Subm3d(in_channels, out_channels, kernel_size=1,
                                bias=False))

        # Both 3x3x3 convs share `indice_key` so spconv reuses the index
        # pairs computed for this voxel set.
        self.conv_branch = spconv.SparseSequential(
            norm_fn(in_channels), nn.ReLU(),
            spconv.SubMConv3d(
                in_channels, out_channels, kernel_size=3, padding=1,
                bias=False, indice_key=indice_key),
            norm_fn(out_channels), nn.ReLU(),
            spconv.SubMConv3d(
                out_channels, out_channels, kernel_size=3, padding=1,
                bias=False, indice_key=indice_key))

    def forward(self, input):
        # Snapshot the input as a fresh SparseConvTensor so the skip branch
        # is insulated from any in-place changes along the conv branch.
        shortcut = spconv.SparseConvTensor(input.features, input.indices,
                                           input.spatial_shape,
                                           input.batch_size)
        out = self.conv_branch(input)
        fused = out.features + self.i_branch(shortcut).features
        return out.replace_feature(fused)
|
|
|
|
|
|
class UBlock(nn.Module):
    """Recursive sparse U-Net stage.

    Each level runs ``block_reps`` blocks at width ``nPlanes[0]``. While
    more widths remain, it downsamples with a stride-2 ``SparseConv3d``,
    recurses on ``nPlanes[1:]``, upsamples with a ``SparseInverseConv3d``
    that shares the downsampling ``indice_key`` (so it lands exactly back
    on the encoder's voxel set), concatenates encoder and decoder features
    along channels, and fuses them with ``block_reps`` tail blocks.

    Args:
        nPlanes (list[int]): channel width per level, coarsest last.
        norm_fn (callable): normalization-layer factory.
        block_reps (int): blocks per level on each side.
        block (type): block class, e.g. ``ResidualBlock``.
        indice_key_id (int): suffix for spconv indice keys; incremented at
            each recursion depth so levels don't share index caches.
    """

    def __init__(self, nPlanes, norm_fn, block_reps, block, indice_key_id=1):
        super().__init__()
        self.nPlanes = nPlanes

        # Encoder blocks at the current width.
        encoder = OrderedDict(
            (f'block{i}',
             block(nPlanes[0], nPlanes[0], norm_fn,
                   indice_key=f'subm{indice_key_id}'))
            for i in range(block_reps))
        self.blocks = spconv.SparseSequential(encoder)

        if len(nPlanes) > 1:
            # Stride-2 downsampling into the next (coarser, wider) level.
            self.conv = spconv.SparseSequential(
                norm_fn(nPlanes[0]), nn.ReLU(),
                spconv.SparseConv3d(
                    nPlanes[0], nPlanes[1], kernel_size=2, stride=2,
                    bias=False, indice_key=f'spconv{indice_key_id}'))

            # Recurse on the remaining widths.
            self.u = UBlock(nPlanes[1:], norm_fn, block_reps, block,
                            indice_key_id=indice_key_id + 1)

            # Inverse conv reuses the downsampling indice_key to invert it.
            self.deconv = spconv.SparseSequential(
                norm_fn(nPlanes[1]), nn.ReLU(),
                spconv.SparseInverseConv3d(
                    nPlanes[1], nPlanes[0], kernel_size=2, bias=False,
                    indice_key=f'spconv{indice_key_id}'))

            # Tail blocks: block 0 takes the concatenated (2x) channels,
            # per the (2 - i) factor below; later blocks run at 1x.
            tail = OrderedDict()
            for i in range(block_reps):
                tail[f'block{i}'] = block(
                    nPlanes[0] * (2 - i), nPlanes[0], norm_fn,
                    indice_key=f'subm{indice_key_id}')
            self.blocks_tail = spconv.SparseSequential(tail)

    def forward(self, input):
        output = self.blocks(input)
        # Untouched copy of the encoder output for the skip concatenation.
        skip = spconv.SparseConvTensor(output.features, output.indices,
                                       output.spatial_shape,
                                       output.batch_size)
        if len(self.nPlanes) > 1:
            down = self.conv(output)
            down = self.u(down)
            up = self.deconv(down)
            merged = torch.cat((skip.features, up.features), dim=1)
            output = output.replace_feature(merged)
            output = self.blocks_tail(output)
        return output
|