diff --git a/softgroup/data/custom.py b/softgroup/data/custom.py index d62ce70..0c34fea 100644 --- a/softgroup/data/custom.py +++ b/softgroup/data/custom.py @@ -85,13 +85,13 @@ class CustomDataset(Dataset): pt_offset_label = pt_mean - xyz return instance_num, instance_pointnum, instance_cls, pt_offset_label - def dataAugment(self, xyz, jitter=False, flip=False, rot=False): + def dataAugment(self, xyz, jitter=False, flip=False, rot=False, prob=0.9): m = np.eye(3) - if jitter: + if jitter and np.random.rand() < prob: m += np.random.randn(3, 3) * 0.1 - if flip: + if flip and np.random.rand() < prob: m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly - if rot: + if rot and np.random.rand() < prob: theta = np.random.rand() * 2 * math.pi m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0], [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation @@ -122,12 +122,15 @@ class CustomDataset(Dataset): j += 1 return instance_label - def transform_train(self, xyz, rgb, semantic_label, instance_label): - xyz_middle = self.dataAugment(xyz, True, True, True) + def transform_train(self, xyz, rgb, semantic_label, instance_label, aug_prob=0.9): + xyz_middle = self.dataAugment(xyz, True, True, True, aug_prob) xyz = xyz_middle * self.voxel_cfg.scale - xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50) - xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50, 160 * self.voxel_cfg.scale / 50) - xyz -= xyz.min(0) + if np.random.rand() < aug_prob: + xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50) + xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50, + 160 * self.voxel_cfg.scale / 50) + xyz_middle = xyz / self.voxel_cfg.scale + xyz = xyz - xyz.min(0) max_tries = 5 while (max_tries > 0): xyz_offset, valid_idxs = self.crop(xyz) @@ -145,11 +148,11 @@ class CustomDataset(Dataset): return xyz, xyz_middle, rgb, semantic_label, instance_label def transform_test(self, xyz, rgb, semantic_label, instance_label): - xyz_middle = self.dataAugment(xyz, False, True, True) + xyz_middle = self.dataAugment(xyz, False, False, False) xyz = xyz_middle * self.voxel_cfg.scale xyz -= xyz.min(0) valid_idxs = np.ones(xyz.shape[0], dtype=bool) - instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) # TODO remove this + instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) return xyz, xyz_middle, rgb, semantic_label, instance_label def __getitem__(self, index):