mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
add train aug_prob remove val aug
This commit is contained in:
parent
3332a478a7
commit
40572e2834
@ -85,13 +85,13 @@ class CustomDataset(Dataset):
|
|||||||
pt_offset_label = pt_mean - xyz
|
pt_offset_label = pt_mean - xyz
|
||||||
return instance_num, instance_pointnum, instance_cls, pt_offset_label
|
return instance_num, instance_pointnum, instance_cls, pt_offset_label
|
||||||
|
|
||||||
def dataAugment(self, xyz, jitter=False, flip=False, rot=False):
|
def dataAugment(self, xyz, jitter=False, flip=False, rot=False, prob=0.9):
|
||||||
m = np.eye(3)
|
m = np.eye(3)
|
||||||
if jitter:
|
if jitter and np.random.rand() < prob:
|
||||||
m += np.random.randn(3, 3) * 0.1
|
m += np.random.randn(3, 3) * 0.1
|
||||||
if flip:
|
if flip and np.random.rand() < prob:
|
||||||
m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly
|
m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly
|
||||||
if rot:
|
if rot and np.random.rand() < prob:
|
||||||
theta = np.random.rand() * 2 * math.pi
|
theta = np.random.rand() * 2 * math.pi
|
||||||
m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
|
m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0],
|
||||||
[-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation
|
[-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation
|
||||||
@ -122,12 +122,15 @@ class CustomDataset(Dataset):
|
|||||||
j += 1
|
j += 1
|
||||||
return instance_label
|
return instance_label
|
||||||
|
|
||||||
def transform_train(self, xyz, rgb, semantic_label, instance_label):
|
def transform_train(self, xyz, rgb, semantic_label, instance_label, aug_prob=0.9):
|
||||||
xyz_middle = self.dataAugment(xyz, True, True, True)
|
xyz_middle = self.dataAugment(xyz, True, True, True, aug_prob)
|
||||||
xyz = xyz_middle * self.voxel_cfg.scale
|
xyz = xyz_middle * self.voxel_cfg.scale
|
||||||
xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50)
|
if np.random.rand() < aug_prob:
|
||||||
xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50, 160 * self.voxel_cfg.scale / 50)
|
xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50)
|
||||||
xyz -= xyz.min(0)
|
xyz = self.elastic(xyz, 20 * self.voxel_cfg.scale // 50,
|
||||||
|
160 * self.voxel_cfg.scale / 50)
|
||||||
|
xyz_middle = xyz / self.voxel_cfg.scale
|
||||||
|
xyz = xyz - xyz.min(0)
|
||||||
max_tries = 5
|
max_tries = 5
|
||||||
while (max_tries > 0):
|
while (max_tries > 0):
|
||||||
xyz_offset, valid_idxs = self.crop(xyz)
|
xyz_offset, valid_idxs = self.crop(xyz)
|
||||||
@ -145,11 +148,11 @@ class CustomDataset(Dataset):
|
|||||||
return xyz, xyz_middle, rgb, semantic_label, instance_label
|
return xyz, xyz_middle, rgb, semantic_label, instance_label
|
||||||
|
|
||||||
def transform_test(self, xyz, rgb, semantic_label, instance_label):
|
def transform_test(self, xyz, rgb, semantic_label, instance_label):
|
||||||
xyz_middle = self.dataAugment(xyz, False, True, True)
|
xyz_middle = self.dataAugment(xyz, False, False, False)
|
||||||
xyz = xyz_middle * self.voxel_cfg.scale
|
xyz = xyz_middle * self.voxel_cfg.scale
|
||||||
xyz -= xyz.min(0)
|
xyz -= xyz.min(0)
|
||||||
valid_idxs = np.ones(xyz.shape[0], dtype=bool)
|
valid_idxs = np.ones(xyz.shape[0], dtype=bool)
|
||||||
instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) # TODO remove this
|
instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
|
||||||
return xyz, xyz_middle, rgb, semantic_label, instance_label
|
return xyz, xyz_middle, rgb, semantic_label, instance_label
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user