# mirror of https://github.com/botastic/SoftGroup.git
import math
import os.path as osp
from glob import glob

import numpy as np
import scipy.interpolate
import scipy.ndimage
import torch
from torch.utils.data import Dataset

from ..lib.softgroup_ops import voxelization_idx


class CustomDataset(Dataset):
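    """Base dataset for SoftGroup-style point-cloud instance segmentation.

    Scans are globbed from ``data_root/prefix/*suffix`` and loaded with
    ``torch.load``; each file is expected to hold
    ``(xyz, rgb, semantic_label, instance_label)``. Samples are transformed for
    training or testing and batched with :meth:`collate_fn`.
    """
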
    CLASSES = None

    def __init__(self,
                 data_root,
                 prefix,
                 suffix,
                 voxel_cfg=None,
                 training=True,
                 repeat=1,
                 logger=None):
        self.data_root = data_root
        self.prefix = prefix
        self.suffix = suffix
        self.voxel_cfg = voxel_cfg
        self.training = training
        self.repeat = repeat
        self.logger = logger
        self.mode = 'train' if training else 'test'
        self.filenames = self.get_filenames()
        self.logger.info(f'Load {self.mode} dataset: {len(self.filenames)} scans')

    def get_filenames(self):
        filenames = glob(osp.join(self.data_root, self.prefix, '*' + self.suffix))
        assert len(filenames) > 0, 'Empty dataset.'
        filenames = sorted(filenames * self.repeat)
        return filenames

    def load(self, filename):
        return torch.load(filename)

    def __len__(self):
        return len(self.filenames)

    def elastic(self, x, gran, mag):
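        """Apply random elastic distortion to point coordinates.

        Gaussian noise is sampled on a coarse grid with spacing ``gran``,
        smoothed by repeated axis-wise blurs, interpolated at each point, and
        added to ``x`` scaled by ``mag``.
        """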
        blur0 = np.ones((3, 1, 1)).astype('float32') / 3
        blur1 = np.ones((1, 3, 1)).astype('float32') / 3
        blur2 = np.ones((1, 1, 3)).astype('float32') / 3

        bb = np.abs(x).max(0).astype(np.int32) // gran + 3
        noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
        # Smooth the noise twice along each axis (scipy.ndimage.filters is
        # deprecated; scipy.ndimage.convolve is the same function).
        noise = [scipy.ndimage.convolve(n, blur0, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.convolve(n, blur1, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.convolve(n, blur2, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.convolve(n, blur0, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.convolve(n, blur1, mode='constant', cval=0) for n in noise]
        noise = [scipy.ndimage.convolve(n, blur2, mode='constant', cval=0) for n in noise]
        ax = [np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in bb]
        interp = [
            scipy.interpolate.RegularGridInterpolator(ax, n, bounds_error=False, fill_value=0)
            for n in noise
        ]

        def g(x_):
            return np.hstack([i(x_)[:, None] for i in interp])

        return x + g(x) * mag

    def getInstanceInfo(self, xyz, instance_label, semantic_label):
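        """Compute per-instance statistics and per-point offset targets.

        Returns the number of instances, each instance's point count and
        semantic class, and, for every point, the offset from the point to the
        mean coordinate of its instance (points without an instance keep the
        -100 placeholder mean).
        """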
        pt_mean = np.ones((xyz.shape[0], 3), dtype=np.float32) * -100.0
        instance_pointnum = []
        instance_cls = []
        instance_num = int(instance_label.max()) + 1
        for i_ in range(instance_num):
            inst_idx_i = np.where(instance_label == i_)
            xyz_i = xyz[inst_idx_i]
            pt_mean[inst_idx_i] = xyz_i.mean(0)
            instance_pointnum.append(inst_idx_i[0].size)
            cls_idx = inst_idx_i[0][0]
            instance_cls.append(semantic_label[cls_idx])
        pt_offset_label = pt_mean - xyz
        return instance_num, instance_pointnum, instance_cls, pt_offset_label

    def dataAugment(self, xyz, jitter=False, flip=False, rot=False, prob=0.9):
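        """Randomly jitter, flip (x axis), and rotate (around z) the coordinates.

        Each enabled augmentation is applied independently with probability
        ``prob``; the combined 3x3 transform is applied to ``xyz``.
        """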
        m = np.eye(3)
        if jitter and np.random.rand() < prob:
            m += np.random.randn(3, 3) * 0.1
        if flip and np.random.rand() < prob:
            m[0][0] *= np.random.randint(0, 2) * 2 - 1  # flip x randomly
        if rot and np.random.rand() < prob:
            theta = np.random.rand() * 2 * math.pi
            m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
                              [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])  # rotation
        return np.matmul(xyz, m)

    def crop(self, xyz, step=32):
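        """Randomly shift the scene and shrink the valid region until at most
        ``voxel_cfg.max_npoint`` points remain inside ``spatial_shape``.

        Returns the shifted coordinates and a boolean mask of the kept points.
        """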
        xyz_offset = xyz.copy()
        valid_idxs = xyz_offset.min(1) >= 0
        assert valid_idxs.sum() == xyz.shape[0]
        spatial_shape = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
        room_range = xyz.max(0) - xyz.min(0)
        while valid_idxs.sum() > self.voxel_cfg.max_npoint:
            step_temp = step
            if valid_idxs.sum() > 1e6:
                step_temp = step * 2
            offset = np.clip(spatial_shape - room_range + 0.001, None, 0) * np.random.rand(3)
            xyz_offset = xyz + offset
            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < spatial_shape).sum(1) == 3)
            spatial_shape[:2] -= step_temp
        return xyz_offset, valid_idxs

    def getCroppedInstLabel(self, instance_label, valid_idxs):
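        """Compact instance ids after cropping removes points.

        Ids that vanished in the crop are filled in by relabeling the points
        carrying the current maximum id, keeping labels consecutive.
        """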
        instance_label = instance_label[valid_idxs]
        j = 0
        while j < instance_label.max():
            if len(np.where(instance_label == j)[0]) == 0:
                instance_label[instance_label == instance_label.max()] = j
            j += 1
        return instance_label

    def transform_train(self, xyz, rgb, semantic_label, instance_label, aug_prob=0.9):
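        """Training-time transform: augment, elastically distort, voxel-scale,
        and crop a scan; returns None if too few points survive the crop.
        """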
        xyz_middle = self.dataAugment(xyz, True, True, True, aug_prob)
        xyz = xyz_middle * self.voxel_cfg.scale
        if np.random.rand() < aug_prob:
            xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50)
            xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50,
                               160 * self.voxel_cfg.scale / 50)
        xyz_middle = xyz / self.voxel_cfg.scale
        xyz = xyz - xyz.min(0)
        max_tries = 5
        while max_tries > 0:
            xyz_offset, valid_idxs = self.crop(xyz)
            if valid_idxs.sum() >= self.voxel_cfg.min_npoint:
                xyz = xyz_offset
                break
            max_tries -= 1
        if valid_idxs.sum() < self.voxel_cfg.min_npoint:
            return None
        xyz = xyz[valid_idxs]
        xyz_middle = xyz_middle[valid_idxs]
        rgb = rgb[valid_idxs]
        semantic_label = semantic_label[valid_idxs]
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
        return xyz, xyz_middle, rgb, semantic_label, instance_label

    def transform_test(self, xyz, rgb, semantic_label, instance_label):
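        """Test-time transform: no augmentation, voxel-scale, and shift to
        non-negative coordinates; all points are kept.
        """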
        xyz_middle = self.dataAugment(xyz, False, False, False)
        xyz = xyz_middle * self.voxel_cfg.scale
        xyz -= xyz.min(0)
        valid_idxs = np.ones(xyz.shape[0], dtype=bool)
        instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
        return xyz, xyz_middle, rgb, semantic_label, instance_label

    def __getitem__(self, index):
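        """Load one scan, transform it, and return a tuple of per-scan tensors
        (scan id, voxel coords, float coords, features, labels, instance info).
        Returns None if the training transform rejects the scan.
        """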
        filename = self.filenames[index]
        scan_id = osp.basename(filename).replace(self.suffix, '')
        data = self.load(filename)
        data = self.transform_train(*data) if self.training else self.transform_test(*data)
        if data is None:
            return None
        xyz, xyz_middle, rgb, semantic_label, instance_label = data
        info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), semantic_label)
        inst_num, inst_pointnum, inst_cls, pt_offset_label = info
        coord = torch.from_numpy(xyz).long()
        coord_float = torch.from_numpy(xyz_middle)
        feat = torch.from_numpy(rgb).float()
        if self.training:
            feat += torch.randn(3) * 0.1
        semantic_label = torch.from_numpy(semantic_label)
        instance_label = torch.from_numpy(instance_label)
        pt_offset_label = torch.from_numpy(pt_offset_label)
        return (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num,
                inst_pointnum, inst_cls, pt_offset_label)

    def collate_fn(self, batch):
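        """Merge per-scan samples into one batch dict, dropping None samples,
        offsetting instance ids so they stay unique across the batch, and
        voxelizing the concatenated coordinates.
        """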
        scan_ids = []
        coords = []
        coords_float = []
        feats = []
        semantic_labels = []
        instance_labels = []

        instance_pointnum = []  # (total_nInst), int
        instance_cls = []  # (total_nInst), long
        pt_offset_labels = []

        total_inst_num = 0
        batch_id = 0
        for data in batch:
            if data is None:
                continue
            (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num,
             inst_pointnum, inst_cls, pt_offset_label) = data
            instance_label[np.where(instance_label != -100)] += total_inst_num
            total_inst_num += inst_num
            scan_ids.append(scan_id)
            coords.append(torch.cat([coord.new_full((coord.size(0), 1), batch_id), coord], 1))
            coords_float.append(coord_float)
            feats.append(feat)
            semantic_labels.append(semantic_label)
            instance_labels.append(instance_label)
            instance_pointnum.extend(inst_pointnum)
            instance_cls.extend(inst_cls)
            pt_offset_labels.append(pt_offset_label)
            batch_id += 1
        assert batch_id > 0, 'empty batch'
        if batch_id < len(batch):
            self.logger.info(f'batch is truncated from size {len(batch)} to {batch_id}')

        # merge all the scenes in the batch
        coords = torch.cat(coords, 0)  # long (N, 1 + 3), the batch item idx is put in coords[:, 0]
        batch_idxs = coords[:, 0].int()
        coords_float = torch.cat(coords_float, 0).to(torch.float32)  # float (N, 3)
        feats = torch.cat(feats, 0)  # float (N, C)
        semantic_labels = torch.cat(semantic_labels, 0).long()  # long (N)
        instance_labels = torch.cat(instance_labels, 0).long()  # long (N)
        instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)
        instance_cls = torch.tensor(instance_cls, dtype=torch.long)  # long (total_nInst)
        pt_offset_labels = torch.cat(pt_offset_labels).float()

        spatial_shape = np.clip(
            coords.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
        voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, 1)
        return {
            'scan_ids': scan_ids,
            'coords': coords,
            'batch_idxs': batch_idxs,
            'voxel_coords': voxel_coords,
            'p2v_map': p2v_map,
            'v2p_map': v2p_map,
            'coords_float': coords_float,
            'feats': feats,
            'semantic_labels': semantic_labels,
            'instance_labels': instance_labels,
            'instance_pointnum': instance_pointnum,
            'instance_cls': instance_cls,
            'pt_offset_labels': pt_offset_labels,
            'spatial_shape': spatial_shape,
            'batch_size': batch_id,
        }
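
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical way to
# wire a CustomDataset subclass into a torch DataLoader. The class name, the
# paths, and the config object shown here are illustrative assumptions; only
# the constructor arguments and collate_fn come from this file. Each *.pth
# scan file is expected to hold (xyz, rgb, semantic_label, instance_label).
#
#   from torch.utils.data import DataLoader
#
#   class ExampleDataset(CustomDataset):
#       CLASSES = ('chair', 'table')   # hypothetical class names
#
#   dataset = ExampleDataset(
#       data_root='data/example',      # hypothetical path; expects
#       prefix='train',                # data/example/train/*_example.pth
#       suffix='_example.pth',
#       voxel_cfg=voxel_cfg,           # needs .scale, .spatial_shape,
#       training=True,                 # .min_npoint and .max_npoint
#       logger=logger)                 # any object with an .info() method
#   loader = DataLoader(
#       dataset, batch_size=4, shuffle=True, collate_fn=dataset.collate_fn)
# ---------------------------------------------------------------------------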