diff --git a/modules/angle.py b/modules/angle.py index 8f82aba..6a32d44 100644 --- a/modules/angle.py +++ b/modules/angle.py @@ -7,9 +7,9 @@ class AnglePriorsLoss(nn.Module): self, device=torch.device('cpu'), dtype=torch.float32, - angle_idx=[56, 53, 12, 9], - directions=[1, -1, -1, -1], - weights=[1.0, 1.0, 1.0, 1.0] + angle_idx=[56, 53, 12, 9, 37, 40], + directions=[1, -1, -1, -1, 1, -1], + weights=[1.0, 1.0, 0.8, 0.8, 0.02, 0.02] ): super(AnglePriorsLoss, self).__init__() @@ -40,4 +40,4 @@ class AnglePriorsLoss(nn.Module): angles = pose[:, self.angle_idx] # compute cost based not exponential of angle * direction - return torch.exp(angles * self.angle_directions).pow(2).sum() + return (torch.exp(angles * self.angle_directions) * self.weights).pow(2).sum()