mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
63 lines
2.7 KiB
Python
63 lines
2.7 KiB
Python
from utils.mapping import get_mapping_arr, get_named_joint, get_named_joints
|
|
import time
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from smplx.joint_names import JOINT_NAMES
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
|
# holder of all proprietary rights on this computer program.
|
|
# You can only use this computer program if you have closed
|
|
# a license agreement with MPG or you get the right to use the computer
|
|
# program from someone who is authorized to grant you that right.
|
|
# Any use of the computer program without a valid license is prohibited and
|
|
# liable to prosecution.
|
|
#
|
|
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
|
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
|
# for Intelligent Systems and the Max Planck Institute for Biological
|
|
# Cybernetics. All rights reserved.
|
|
#
|
|
# Contact: ps-license@tuebingen.mpg.de
|
|
|
|
|
|
class SMPLifyAnglePrior(nn.Module):
    """Exponential prior that penalizes unnatural elbow and knee bending.

    For one axis-angle component of each of four joints (left/right elbow,
    left/right knee) the prior computes ``exp(sign * angle) ** 2``. Rotating
    a joint against its natural bending direction therefore grows the cost
    exponentially.
    """

    def __init__(self, dtype=torch.float32, **kwargs):
        """Register the joint indices and bend signs as module buffers.

        Args:
            dtype: torch dtype of the sign buffer; it should match the dtype
                of the poses passed to ``forward``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()

        # Indices of the rotation angle inside a flattened axis-angle pose
        # vector that includes the 3 global-orientation values first:
        #   55: left elbow,  90deg bend at -np.pi/2
        #   58: right elbow, 90deg bend at np.pi/2
        #   12: left knee,   90deg bend at np.pi/2
        #   15: right knee,  90deg bend at np.pi/2
        self.register_buffer(
            'angle_prior_idxs',
            torch.tensor([55, 58, 12, 15], dtype=torch.long))

        # One sign per index, chosen so that sign * angle is positive when
        # the joint bends opposite to its natural direction; exp() then
        # penalizes that direction heavily.
        self.register_buffer(
            'angle_prior_signs',
            torch.tensor([1, -1, -1, -1], dtype=dtype))

    def forward(self, pose, with_global_pose=False):
        """Return the per-joint angle prior loss for a batch of poses.

        Args:
            pose: (B x [23 + 1] * 3) tensor with the axis-angle
                representation of the rotations of the SMPL joints. When
                ``with_global_pose`` is False the 3 leading global
                orientation values are expected to be absent.
            with_global_pose: Whether ``pose`` also contains the global
                orientation of the SMPL model. If not, the stored indices
                are shifted down by 3.

        Returns:
            A (B x 4) tensor with one loss term per prior joint and batch
            element. Reduce it (e.g. ``.sum()``) to obtain a scalar loss.
        """
        # Without the global orientation every joint index moves back by one
        # axis-angle triple.
        offset = 0 if with_global_pose else 3
        angles = pose[:, self.angle_prior_idxs - offset]
        return torch.exp(angles * self.angle_prior_signs).pow(2)
|