body-pose-animation/modules/priors.py
2021-02-02 17:41:44 +01:00

63 lines
2.7 KiB
Python

from utils.mapping import get_mapping_arr, get_named_joint, get_named_joints
import time
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from smplx.joint_names import JOINT_NAMES
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems and the Max Planck Institute for Biological
# Cybernetics. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
class SMPLifyAnglePrior(nn.Module):
    """Exponential prior penalising unnatural bends of the elbows and knees.

    For four fixed entries of a flattened axis-angle pose vector the module
    returns ``exp(sign * angle) ** 2``. The per-joint ``sign`` makes the term
    grow quickly when the joint rotates against the bend direction listed in
    ``__init__``.
    """

    def __init__(self, dtype=torch.float32, **kwargs):
        """Register the index and sign buffers.

        Args:
            dtype: torch dtype of the ``angle_prior_signs`` buffer.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        super().__init__()

        # Flat indices of the constrained rotation angles, for a pose vector
        # that starts with the 3 global-orientation values:
        #   55: left elbow,  90deg bend at -np.pi/2
        #   58: right elbow, 90deg bend at np.pi/2
        #   12: left knee,   90deg bend at np.pi/2
        #   15: right knee,  90deg bend at np.pi/2
        self.register_buffer(
            'angle_prior_idxs',
            torch.tensor([55, 58, 12, 15], dtype=torch.long))

        # +1 penalises positive rotation (the left elbow bends towards
        # -pi/2); -1 penalises negative rotation (the others bend towards
        # +pi/2). Created directly in the requested dtype.
        self.register_buffer(
            'angle_prior_signs',
            torch.tensor([1, -1, -1, -1], dtype=dtype))

    def forward(self, pose, with_global_pose=False):
        """Return the per-joint angle prior terms for ``pose``.

        Args:
            pose: (B, D) tensor of flattened axis-angle rotations of the SMPL
                joints; D is (23 + 1) * 3 when the global orientation is
                included (presumably 23 * 3 otherwise -- confirm at callers).
            with_global_pose: Whether ``pose`` starts with the 3 global
                orientation values. If False, the stored indices are
                shifted down by 3 to compensate.

        Returns:
            A (B, 4) tensor holding one term ``exp(sign * angle) ** 2`` per
            constrained angle (left elbow, right elbow, left knee, right
            knee). The terms are not reduced; sum them at the call site if a
            per-sample loss is needed.
        """
        # The stored indices assume the 3 global-orientation values come
        # first; drop that offset when they are absent.
        offset = 0 if with_global_pose else 3
        angles = pose[:, self.angle_prior_idxs - offset]
        return torch.exp(angles * self.angle_prior_signs).pow(2)