SoftGroup/softgroup/model/blocks.py
2022-04-09 03:17:01 +00:00

127 lines
4.3 KiB
Python

from collections import OrderedDict
import spconv
import torch
from spconv.modules import SparseModule
from torch import nn
class MLP(nn.Sequential):
    """Stack of linear layers built as an ``nn.Sequential``.

    The stack consists of ``num_layers - 1`` hidden groups, each being
    ``Linear(in_channels, in_channels, bias=False) -> norm_fn(in_channels)
    -> ReLU``, followed by one output ``Linear(in_channels, out_channels)``
    that keeps its bias.

    Args:
        in_channels: Width of the input and of every hidden layer.
        out_channels: Width of the final linear layer.
        norm_fn: Callable that receives a channel count and returns the
            normalisation module inserted after each hidden linear layer.
        num_layers: Total number of linear layers. Values below 2 produce
            only the output layer.
    """

    def __init__(self, in_channels, out_channels, norm_fn, num_layers=2):
        layers = []
        for _ in range(num_layers - 1):
            # Hidden linears carry no bias; the norm layer follows directly.
            layers.extend([
                nn.Linear(in_channels, in_channels, bias=False),
                norm_fn(in_channels),
                nn.ReLU(),
            ])
        layers.append(nn.Linear(in_channels, out_channels))
        # __init__ must not return a value, so the parent call is a plain
        # statement rather than a ``return`` expression.
        super().__init__(*layers)

    def init_weights(self):
        """Re-initialise the weights of every linear layer in place.

        All ``nn.Linear`` weights get Xavier-uniform initialisation; the
        output layer is then overridden with N(0, 0.01) weights and a zero
        bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
        # The last entry is always the biased output layer built in __init__.
        head = self[-1]
        nn.init.normal_(head.weight, 0, 0.01)
        nn.init.constant_(head.bias, 0)
class ResidualBlock(SparseModule):
    """Sparse residual unit with a shortcut branch and a two-conv branch.

    ``conv_branch`` applies ``norm_fn -> ReLU -> SubMConv3d(3x3x3)`` twice.
    ``i_branch`` is an identity when the channel counts match and a
    bias-free 1x1x1 ``SubMConv3d`` otherwise. The forward pass adds the
    shortcut features onto the conv-branch features.

    Args:
        in_channels: Channel count of the incoming features.
        out_channels: Channel count produced by the block.
        norm_fn: Callable that receives a channel count and returns a
            normalisation module.
        indice_key: Key handed to both 3x3x3 submanifold convolutions.
    """

    def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
        super().__init__()

        # Shortcut path: project only when the widths differ.
        if in_channels != out_channels:
            shortcut = spconv.SubMConv3d(
                in_channels, out_channels, kernel_size=1, bias=False)
        else:
            shortcut = nn.Identity()
        self.i_branch = spconv.SparseSequential(shortcut)

        def conv3(channels_in):
            # Both main-branch convolutions share every setting but the input width.
            return spconv.SubMConv3d(
                channels_in,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
                indice_key=indice_key)

        # Layers are created in the same order in which they are applied.
        layers = [norm_fn(in_channels), nn.ReLU(), conv3(in_channels)]
        layers += [norm_fn(out_channels), nn.ReLU(), conv3(out_channels)]
        self.conv_branch = spconv.SparseSequential(*layers)

    def forward(self, input):
        """Return ``conv_branch(input)`` with the shortcut features added in place."""
        # Build the shortcut tensor before conv_branch runs.
        # NOTE(review): presumably needed because the sequential container
        # may rebind ``input.features`` -- confirm against the spconv version in use.
        shortcut_in = spconv.SparseConvTensor(
            input.features, input.indices, input.spatial_shape, input.batch_size)

        output = self.conv_branch(input)
        output.features += self.i_branch(shortcut_in).features
        return output
class UBlock(nn.Module):
    """Recursive encoder/decoder stage over sparse tensors.

    Each stage runs ``block_reps`` blocks at width ``nPlanes[0]``. When
    ``nPlanes`` has more than one entry, the stage also:

    1. downsamples with a stride-2 ``SparseConv3d`` to ``nPlanes[1]``,
    2. recurses into a nested ``UBlock`` built from ``nPlanes[1:]``,
    3. upsamples with a ``SparseInverseConv3d`` that uses the same
       ``indice_key`` as the downsampling convolution,
    4. concatenates the pre-downsampling features with the upsampled ones
       along the channel dimension, and
    5. runs ``block_reps`` tail blocks to bring the width back to
       ``nPlanes[0]``.

    Args:
        nPlanes: Sequence of channel widths, one per nesting level.
        norm_fn: Callable that receives a channel count and returns a
            normalisation module.
        block_reps: Number of blocks before the downsample and again in
            the tail.
        block: Block constructor called as
            ``block(in_channels, out_channels, norm_fn, indice_key=...)``.
        indice_key_id: Integer suffix for the indice keys of this level;
            nested levels use ``indice_key_id + 1``.
    """

    def __init__(self, nPlanes, norm_fn, block_reps, block, indice_key_id=1):
        super().__init__()
        self.nPlanes = nPlanes

        width = nPlanes[0]
        subm_key = 'subm{}'.format(indice_key_id)
        spconv_key = 'spconv{}'.format(indice_key_id)

        blocks = OrderedDict()
        for i in range(block_reps):
            blocks['block{}'.format(i)] = block(
                width, width, norm_fn, indice_key=subm_key)
        self.blocks = spconv.SparseSequential(blocks)

        if len(nPlanes) > 1:
            self.conv = spconv.SparseSequential(
                norm_fn(width), nn.ReLU(),
                spconv.SparseConv3d(
                    width,
                    nPlanes[1],
                    kernel_size=2,
                    stride=2,
                    bias=False,
                    indice_key=spconv_key))

            self.u = UBlock(
                nPlanes[1:], norm_fn, block_reps, block, indice_key_id=indice_key_id + 1)

            self.deconv = spconv.SparseSequential(
                norm_fn(nPlanes[1]), nn.ReLU(),
                spconv.SparseInverseConv3d(
                    nPlanes[1],
                    width,
                    kernel_size=2,
                    bias=False,
                    indice_key=spconv_key))

            blocks_tail = OrderedDict()
            for i in range(block_reps):
                # Only the first tail block receives the concatenated
                # (skip + upsampled) features, so only it has a doubled
                # input width. The former ``width * (2 - i)`` expression
                # gave 0 or negative widths once block_reps exceeded 2.
                tail_in = width * 2 if i == 0 else width
                blocks_tail['block{}'.format(i)] = block(
                    tail_in, width, norm_fn, indice_key=subm_key)
            self.blocks_tail = spconv.SparseSequential(blocks_tail)

    def forward(self, input):
        """Run this stage, and any nested stages, on a sparse tensor."""
        output = self.blocks(input)
        # Snapshot the skip features before the downsampling path runs.
        identity = spconv.SparseConvTensor(output.features, output.indices,
                                           output.spatial_shape, output.batch_size)
        if len(self.nPlanes) > 1:
            output_decoder = self.conv(output)
            output_decoder = self.u(output_decoder)
            output_decoder = self.deconv(output_decoder)
            # Channel-wise concatenation of skip and upsampled features.
            output.features = torch.cat((identity.features, output_decoder.features), dim=1)
            output = self.blocks_tail(output)
        return output