From a4abf23d68c3f95f99da7770458921d1ae6438df Mon Sep 17 00:00:00 2001 From: Wlad <9556979+gosticks@users.noreply.github.com> Date: Mon, 8 Feb 2021 21:02:37 +0100 Subject: [PATCH] add further angles --- modules/angle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/angle.py b/modules/angle.py index 8f82aba..6a32d44 100644 --- a/modules/angle.py +++ b/modules/angle.py @@ -7,9 +7,9 @@ class AnglePriorsLoss(nn.Module): self, device=torch.device('cpu'), dtype=torch.float32, - angle_idx=[56, 53, 12, 9], - directions=[1, -1, -1, -1], - weights=[1.0, 1.0, 1.0, 1.0] + angle_idx=[56, 53, 12, 9, 37, 40], + directions=[1, -1, -1, -1, 1, -1], + weights=[1.0, 1.0, 0.8, 0.8, 0.02, 0.02] ): super(AnglePriorsLoss, self).__init__() @@ -40,4 +40,4 @@ class AnglePriorsLoss(nn.Module): angles = pose[:, self.angle_idx] # compute cost based not exponential of angle * direction - return torch.exp(angles * self.angle_directions).pow(2).sum() + return (torch.exp(angles * self.angle_directions) * self.weights).pow(2).sum()