mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class AnglePriorsLoss(nn.Module):
|
|
def __init__(
|
|
self,
|
|
device=torch.device('cpu'),
|
|
dtype=torch.float32,
|
|
angle_idx=[56, 53, 12, 9],
|
|
directions=[1, -1, -1, -1],
|
|
weights=[1.0, 1.0, 1.0, 1.0]
|
|
):
|
|
super(AnglePriorsLoss, self).__init__()
|
|
|
|
# angles determined based on
|
|
# angles currently only work with SMPL-X since indices may differ
|
|
angles_idx = torch.tensor(
|
|
angle_idx, dtype=torch.int64).to(device=device)
|
|
self.register_buffer("angle_idx",
|
|
angles_idx)
|
|
|
|
# list of proposed directions
|
|
angle_directions = torch.tensor(
|
|
directions, dtype=dtype).to(device=device)
|
|
self.register_buffer("angle_directions",
|
|
angle_directions)
|
|
|
|
# create buffer for weights
|
|
self.register_buffer(
|
|
"weights",
|
|
torch.tensor(weights, dtype=dtype).to(device=device)
|
|
)
|
|
|
|
def forward(self, pose):
|
|
# compute direction deviation from expected joint rotation directions,
|
|
# e.g. don't rotate the knee joint forwards. Broken knees are not fun.
|
|
|
|
# get relevant angles
|
|
angles = pose[:, self.angle_idx]
|
|
|
|
# compute cost based not exponential of angle * direction
|
|
return torch.exp(angles * self.angle_directions).pow(2).sum()
|