body-pose-animation/modules/pose.py
2021-02-07 17:09:09 +01:00

44 lines
1.2 KiB
Python

import torch
import torch.nn as nn
from smplx.joint_names import JOINT_NAMES
from smplx import SMPL, body_models
class BodyPose(nn.Module):
    """Layer that owns a learnable body pose and feeds it to a body model.

    The pose is registered as the ``body_pose`` parameter, initialised to
    zeros with the same shape as ``model.body_pose``. Each forward pass calls
    the wrapped model with that pose and keeps the model output in
    ``cur_out``.
    """

    def __init__(
        self,
        model,
        keypoint_conf=None,
        dtype=torch.float32,
        device=None,
        model_type="smplx",
        useBodyMeanAngles=True
    ):
        """Set up the layer and its trainable pose parameter.

        Args:
            model: Callable body model. It must expose a ``body_pose``
                attribute with a ``shape`` and accept the keyword arguments
                ``return_full_pose`` and ``body_pose`` when called.
            keypoint_conf: Accepted but not used by this class.
            dtype: dtype of the created pose parameter.
            device: Device the pose parameter is created on.
            model_type: Stored on the instance; not otherwise used here.
            useBodyMeanAngles: Stored on the instance; not otherwise used
                here.
        """
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.model = model
        self.model_type = model_type
        self.useBodyMeanAngles = useBodyMeanAngles

        # Defined up front so that reading ``cur_out`` before the first
        # forward pass yields None instead of raising AttributeError.
        self.cur_out = None

        # attach SMPL pose tensor as parameter to the layer
        initial_pose = torch.zeros(
            model.body_pose.shape,
            dtype=dtype,
            device=device,
        )
        self.register_parameter(
            "body_pose",
            nn.Parameter(initial_pose, requires_grad=True),
        )

    def forward(self, pose_extra=None):
        """Run the wrapped model on the current pose.

        Args:
            pose_extra: Optional value added to the ``body_pose`` parameter
                before it is passed to the model. The parameter itself is
                not modified.

        Returns:
            A tuple ``(joints, body_pose)`` taken from the attributes of the
            same name on the model output.
        """
        pose_in = self.body_pose
        if pose_extra is not None:
            pose_in = pose_in + pose_extra

        body_output = self.model(
            return_full_pose=True,
            body_pose=pose_in,
        )

        # store model output for later renderer usage
        self.cur_out = body_output
        return body_output.joints, body_output.body_pose