mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
add further angles
This commit is contained in:
parent
9ad19c3569
commit
a4abf23d68
@ -7,9 +7,9 @@ class AnglePriorsLoss(nn.Module):
|
|||||||
self,
|
self,
|
||||||
device=torch.device('cpu'),
|
device=torch.device('cpu'),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
angle_idx=[56, 53, 12, 9],
|
angle_idx=[56, 53, 12, 9, 37, 40],
|
||||||
directions=[1, -1, -1, -1],
|
directions=[1, -1, -1, -1, 1, -1],
|
||||||
weights=[1.0, 1.0, 1.0, 1.0]
|
weights=[1.0, 1.0, 0.8, 0.8, 0.02, 0.02]
|
||||||
):
|
):
|
||||||
super(AnglePriorsLoss, self).__init__()
|
super(AnglePriorsLoss, self).__init__()
|
||||||
|
|
||||||
@ -40,4 +40,4 @@ class AnglePriorsLoss(nn.Module):
|
|||||||
angles = pose[:, self.angle_idx]
|
angles = pose[:, self.angle_idx]
|
||||||
|
|
||||||
# compute cost based not exponential of angle * direction
|
# compute cost based not exponential of angle * direction
|
||||||
return torch.exp(angles * self.angle_directions).pow(2).sum()
|
return (torch.exp(angles * self.angle_directions) * self.weights).pow(2).sum()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user