body-pose-animation/modules/angle_clip.py
2021-02-17 13:18:15 +01:00

45 lines
1.3 KiB
Python

import torch
import torch.nn as nn
import numpy as np
class AngleClipper(nn.Module):
    """Loss term penalizing selected pose entries whose magnitude exceeds pi.

    For every selected column of ``pose`` whose absolute value is greater
    than pi, the squared value of that entry is added to the loss. The
    summed penalty is scaled by ``weight``.

    Note that the penalty is the square of the full entry value, not of the
    amount by which it exceeds the limit, so the loss jumps from zero to
    roughly ``pi ** 2 * weight`` as soon as an entry crosses the limit.

    The module registers only buffers (``angle_idx``, ``limit``, ``weight``)
    and no trainable parameters; ``has_parameters`` is set to ``False``.

    Args:
        device: Device on which the buffers are created.
        dtype: Floating point dtype used for the ``limit`` and ``weight``
            buffers.
        angle_idx: Sequence of integer column indices into ``pose`` that
            should be checked. The default indices were chosen for SMPL-X;
            other body models may use different indices.
        weight: Scalar multiplier applied to the summed penalty.
    """

    def __init__(
        self,
        device=torch.device('cpu'),
        dtype=torch.float32,
        # Immutable tuple instead of a list so the default cannot be mutated
        # and leak between instances.
        # NOTE(review): index 9 appears twice, so that entry is penalized
        # twice when it exceeds the limit -- confirm this is intended.
        angle_idx=(24, 25, 26, 10, 11, 13, 14, 9, 1, 56, 53, 12, 9, 37, 40),
        weight=1
    ):
        super().__init__()

        self.has_parameters = False

        # Indices currently only work with SMPL-X since indices may differ
        # between body models.
        self.register_buffer(
            "angle_idx",
            torch.tensor(angle_idx, dtype=torch.int64, device=device)
        )

        # Entries with an absolute value above this limit are penalized.
        self.register_buffer(
            "limit",
            torch.tensor(np.pi, dtype=dtype, device=device)
        )

        # Scalar weight applied to the summed penalty.
        self.register_buffer(
            "weight",
            torch.tensor(weight, dtype=dtype, device=device)
        )

    def forward(self, pose, joints, points, keypoints, raw_output):
        """Return the weighted sum of squared out-of-range pose entries.

        Args:
            pose: Tensor indexed as ``pose[:, angle_idx]``, i.e. the selected
                indices address its second dimension.
            joints: Unused by this module.
            points: Unused by this module.
            keypoints: Unused by this module.
            raw_output: Unused by this module.
            (The unused arguments are presumably part of a signature shared
            with other loss modules -- verify against the caller.)

        Returns:
            A scalar tensor: the sum of squares of all selected entries whose
            absolute value exceeds ``limit``, multiplied by ``weight``. Zero
            if no selected entry exceeds the limit.
        """
        # Pick only the monitored columns.
        angles = pose[:, self.angle_idx]

        # Boolean-mask selection flattens the out-of-range entries to 1-D.
        penalty = angles[torch.abs(angles) > self.limit]

        return penalty.pow(2).sum() * self.weight