mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from smplx.joint_names import JOINT_NAMES
|
|
from smplx import SMPL, body_models
|
|
|
|
|
|
class BodyPose(nn.Module):
    """Learnable body-pose layer wrapped around a body model.

    Holds a trainable ``body_pose`` parameter, initialised to zeros with the
    same shape as ``model.body_pose``. On every forward pass that parameter
    (optionally plus an extra pose term) is fed into ``model``.

    Args:
        model: body model (presumably an ``smplx`` SMPL/SMPL-X instance --
            confirm against callers). It must expose a ``body_pose`` tensor
            and be callable as ``model(return_full_pose=..., body_pose=...)``,
            returning an object with ``joints`` and ``body_pose`` attributes.
        keypoint_conf: accepted for interface compatibility; not used here.
        dtype: dtype of the created pose parameter.
        device: device of the created pose parameter. When ``None`` the
            device of ``model.body_pose`` is used, so the parameter lives
            next to the model it is fed into.
        model_type: stored on the instance; not used by this class.
        useBodyMeanAngles: stored on the instance; not used by this class.
    """

    def __init__(
        self,
        model,
        keypoint_conf=None,
        dtype=torch.float32,
        device=None,
        model_type="smplx",
        useBodyMeanAngles=True
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.model = model
        self.model_type = model_type
        self.useBodyMeanAngles = useBodyMeanAngles

        # latest model output, set by forward(); None until the first call
        self.cur_out = None

        # Create the pose on the model's device unless told otherwise. A
        # CPU-only default would mismatch a model that lives on another
        # device once the pose is passed into it.
        param_device = model.body_pose.device if device is None else device

        # attach the pose tensor as a trainable parameter of this layer
        body_pose = torch.zeros(model.body_pose.shape,
                                dtype=dtype, device=param_device)
        body_pose = nn.Parameter(body_pose, requires_grad=True)
        self.register_parameter("body_pose", body_pose)

    def forward(self, pose_extra=None):
        """Run the body model with the current (optionally offset) body pose.

        Args:
            pose_extra: optional tensor added to the learnable ``body_pose``
                before it is passed to the model; it must be
                broadcast-compatible with ``body_pose``.

        Returns:
            tuple: ``(joints, body_pose)`` taken from the model output.
        """
        pose_in = self.body_pose
        if pose_extra is not None:
            pose_in = pose_in + pose_extra

        body_output = self.model(
            return_full_pose=False,
            body_pose=pose_in
        )

        # store model output for later renderer usage
        self.cur_out = body_output

        return body_output.joints, body_output.body_pose
|