# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems and the Max Planck Institute for Biological
# Cybernetics. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from smplx.joint_names import JOINT_NAMES

from utils.mapping import get_mapping_arr, get_named_joint, get_named_joints


class SMPLifyAnglePrior(nn.Module):
    """Exponential bending prior on four elbow/knee rotation components.

    For each selected axis-angle component ``theta`` the module produces
    ``exp(sign * theta) ** 2`` (equivalently ``exp(2 * sign * theta)``).
    The term grows quickly when the component moves in the direction of
    ``sign``. It decays toward zero in the opposite direction.
    """

    def __init__(self, dtype=torch.float32, **kwargs):
        """Register the index and sign buffers.

        Args:
            dtype: torch dtype of the ``angle_prior_signs`` buffer.
            **kwargs: Accepted and ignored; kept for call-site compatibility.
        """
        super(SMPLifyAnglePrior, self).__init__()

        # Flat indices into a pose vector that INCLUDES the global
        # orientation. With 3 axis-angle values per joint the index is
        # 3 * joint + axis:
        #   55: left elbow  (joint 18, axis 1), 90deg bend at -np.pi/2
        #   58: right elbow (joint 19, axis 1), 90deg bend at np.pi/2
        #   12: left knee   (joint 4,  axis 0), 90deg bend at np.pi/2
        #   15: right knee  (joint 5,  axis 0), 90deg bend at np.pi/2
        angle_prior_idxs = torch.tensor([55, 58, 12, 15], dtype=torch.long)
        self.register_buffer('angle_prior_idxs', angle_prior_idxs)

        # +1 penalizes positive rotation and -1 penalizes negative rotation.
        # Each sign is the opposite of the natural bend direction listed
        # above. The values are exact in every float dtype, so the tensor is
        # built directly in the requested dtype.
        angle_prior_signs = torch.tensor([1.0, -1.0, -1.0, -1.0], dtype=dtype)
        self.register_buffer('angle_prior_signs', angle_prior_signs)

    def forward(self, pose, with_global_pose=False):
        """Return the angle prior term for the given pose.

        Args:
            pose: (B x [23 + 1] * 3) torch tensor with the axis-angle
                representation of the rotations of the joints of the SMPL
                model. Columns are indexed along dim 1, so it is assumed to
                be 2-D: (B, 72) with the global orientation, or (B, 69)
                without it.
        Kwargs:
            with_global_pose: Whether the pose vector also contains the
                global orientation of the SMPL model. If not, the stored
                indices are shifted down by 3 to skip the missing global
                rotation.

        Returns:
            A (B, 4) tensor with ``exp(sign * theta) ** 2`` for the left
            elbow, right elbow, left knee and right knee of every batch
            element. It is not summed over the four joints here.
        """
        # Without the leading 3-value global orientation, every body-joint
        # entry sits 3 positions earlier in the flat pose vector.
        offset = 0 if with_global_pose else 3
        angle_prior_idxs = self.angle_prior_idxs - offset
        return torch.exp(pose[:, angle_prior_idxs] *
                         self.angle_prior_signs).pow(2)