mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
WIP: fix mapping issues for SMPLX
This commit is contained in:
parent
69d5256fab
commit
00d3dbe1c7
106
modules/pose.py
106
modules/pose.py
@ -1,21 +1,23 @@
|
|||||||
from model import VPoserModel
|
from model import VPoserModel
|
||||||
from modules.camera import SimpleCamera
|
from modules.camera import SimpleCamera
|
||||||
from renderer import Renderer
|
from renderer import Renderer
|
||||||
from utils.mapping import get_mapping_arr
|
from utils.mapping import get_mapping_arr, get_named_joint, get_named_joints
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from smplx.joint_names import JOINT_NAMES
|
||||||
from smplx import SMPL
|
from smplx import SMPL
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torchgeometry as tgm
|
import torchgeometry as tgm
|
||||||
|
from human_body_prior.tools.model_loader import load_vposer
|
||||||
|
|
||||||
|
|
||||||
class BodyPose(nn.Module):
|
class BodyPose(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: SMPL,
|
model,
|
||||||
keypoint_conf=None,
|
keypoint_conf=None,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=None,
|
device=None,
|
||||||
@ -31,12 +33,25 @@ class BodyPose(nn.Module):
|
|||||||
# create valid joint filter
|
# create valid joint filter
|
||||||
filter = self.get_joint_filter()
|
filter = self.get_joint_filter()
|
||||||
self.register_buffer("filter", filter)
|
self.register_buffer("filter", filter)
|
||||||
|
# vp, ps = load_vposer("./vposer_v1_0")
|
||||||
|
# vp = vp.to(device=device)
|
||||||
|
# vp.requires_grad = True
|
||||||
|
# self.vp = vp
|
||||||
|
# poZ_body_sample = torch.from_numpy(
|
||||||
|
# np.random.randn(1, 32).astype(np.float32)).to(device=device)
|
||||||
|
|
||||||
|
# poZ = nn.Parameter(poZ_body_sample, requires_grad=True)
|
||||||
|
# self.register_parameter("poZ", poZ)
|
||||||
|
|
||||||
# attach SMPL pose tensor as parameter to the layer
|
# attach SMPL pose tensor as parameter to the layer
|
||||||
|
body_pose = torch.zeros(model.body_pose.shape,
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
body_pose = nn.Parameter(body_pose, requires_grad=True)
|
||||||
|
self.register_parameter("body_pose", body_pose)
|
||||||
# body_pose = torch.zeros(model.body_pose.shape,
|
# body_pose = torch.zeros(model.body_pose.shape,
|
||||||
# dtype=dtype, device=device)
|
# dtype=dtype, device=device)
|
||||||
# body_pose = nn.Parameter(body_pose, requires_grad=True)
|
# body_pose = nn.Parameter(model.pose_body, requires_grad=True)
|
||||||
# self.register_parameter("pose", body_pose)
|
# self.register_parameter("body_pose", body_pose)
|
||||||
|
|
||||||
def get_joint_filter(self):
|
def get_joint_filter(self):
|
||||||
"""OpenPose and SMPL do not have fully matching joint positions,
|
"""OpenPose and SMPL do not have fully matching joint positions,
|
||||||
@ -48,18 +63,24 @@ class BodyPose(nn.Module):
|
|||||||
|
|
||||||
# create a list with 1s for used joints and 0 for ignored joints
|
# create a list with 1s for used joints and 0 for ignored joints
|
||||||
mapping = get_mapping_arr(output_format=self.model_type)
|
mapping = get_mapping_arr(output_format=self.model_type)
|
||||||
print(mapping.shape)
|
|
||||||
|
filter_shape = (len(mapping), 3)
|
||||||
|
|
||||||
filter = torch.zeros(
|
filter = torch.zeros(
|
||||||
(len(mapping), 3), dtype=self.dtype, device=self.device)
|
filter_shape, dtype=self.dtype, device=self.device)
|
||||||
for index, valid in enumerate(mapping > -1):
|
for index, valid in enumerate(mapping > -1):
|
||||||
if valid:
|
if valid:
|
||||||
filter[index] += 1
|
filter[index] += 1
|
||||||
|
|
||||||
|
# print("mapping:", get_named_joints(
|
||||||
|
# filter.detach().cpu().numpy(), ["shoulder-left", "hand-left", "elbow-left"]))
|
||||||
return filter
|
return filter
|
||||||
|
|
||||||
def forward(self, pose):
|
def forward(self, vpose_pose):
|
||||||
|
# pose_body = self.vp.decode(self.poZ, output_type='aa').view(-1, 63)
|
||||||
|
# pose_body.requires_grad = True
|
||||||
bode_output = self.model(
|
bode_output = self.model(
|
||||||
body_pose=pose
|
body_pose=self.body_pose + vpose_pose
|
||||||
)
|
)
|
||||||
|
|
||||||
# store model output for later renderer usage
|
# store model output for later renderer usage
|
||||||
@ -67,7 +88,11 @@ class BodyPose(nn.Module):
|
|||||||
|
|
||||||
joints = bode_output.joints
|
joints = bode_output.joints
|
||||||
# return a list with invalid joints set to zero
|
# return a list with invalid joints set to zero
|
||||||
return joints * self.filter.unsqueeze(0)
|
filtered_joints = joints * self.filter.unsqueeze(0)
|
||||||
|
|
||||||
|
# print("filtered:", filtered_joints.shape, get_named_joints(
|
||||||
|
# filtered_joints.detach().cpu().numpy().squeeze(), ["shoulder-left", "hand-left", "elbow-left"]))
|
||||||
|
return filtered_joints
|
||||||
|
|
||||||
|
|
||||||
def train_pose(
|
def train_pose(
|
||||||
@ -76,17 +101,31 @@ def train_pose(
|
|||||||
keypoint_conf,
|
keypoint_conf,
|
||||||
camera: SimpleCamera,
|
camera: SimpleCamera,
|
||||||
loss_layer=torch.nn.MSELoss(),
|
loss_layer=torch.nn.MSELoss(),
|
||||||
learning_rate=1e-1,
|
learning_rate=1e-3,
|
||||||
device=torch.device('cpu'),
|
device=torch.device('cuda'),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
renderer: Renderer = None,
|
renderer: Renderer = None,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
iterations=25
|
iterations=60
|
||||||
):
|
):
|
||||||
|
|
||||||
|
# filter keypoints to only include desired components
|
||||||
|
mapping = get_mapping_arr(output_format="smplx")
|
||||||
|
|
||||||
|
filter_shape = (len(mapping), 3)
|
||||||
|
|
||||||
|
filter = np.zeros(filter_shape)
|
||||||
|
for index, valid in enumerate(mapping > -1):
|
||||||
|
if valid:
|
||||||
|
filter[index] += 1
|
||||||
|
keypoints = keypoints * filter
|
||||||
vposer = VPoserModel()
|
vposer = VPoserModel()
|
||||||
vposer_model = vposer.model
|
vposer_layer = vposer.model
|
||||||
vposer_model.poZ_body.required_grad = True
|
|
||||||
vposer_params = vposer.get_vposer_latens()
|
vposer_params = vposer.get_vposer_latens()
|
||||||
|
|
||||||
|
index = JOINT_NAMES.index("left_middle1")
|
||||||
|
print(index)
|
||||||
|
print("Keypoint:", keypoints.squeeze()[index])
|
||||||
# setup keypoint data
|
# setup keypoint data
|
||||||
keypoints = torch.tensor(keypoints).to(device=device, dtype=dtype)
|
keypoints = torch.tensor(keypoints).to(device=device, dtype=dtype)
|
||||||
keypoints_conf = torch.tensor(keypoint_conf).to(device)
|
keypoints_conf = torch.tensor(keypoint_conf).to(device)
|
||||||
@ -97,27 +136,41 @@ def train_pose(
|
|||||||
pose_layer = BodyPose(model, dtype=dtype, device=device).to(device)
|
pose_layer = BodyPose(model, dtype=dtype, device=device).to(device)
|
||||||
|
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
optimizer = torch.optim.LBFGS(
|
parameters = [pose_layer.body_pose, vposer_params]
|
||||||
vposer_model.parameters(), learning_rate)
|
optimizer = torch.optim.LBFGS(parameters, learning_rate)
|
||||||
#optimizer = torch.optim.Adam(pose_layer.parameters(), learning_rate)
|
# optimizer = torch.optim.Adam(parameters, learning_rate)
|
||||||
|
|
||||||
pbar = tqdm(total=iterations)
|
pbar = tqdm(total=iterations)
|
||||||
|
|
||||||
def predict():
|
def predict():
|
||||||
body = vposer_model()
|
body = vposer_layer()
|
||||||
pose = body.pose_body
|
poZ = body.poZ_body
|
||||||
print(pose)
|
|
||||||
|
|
||||||
# return joints based on current model state
|
# return joints based on current model state
|
||||||
body_joints = pose_layer(pose)
|
body_joints = pose_layer(body.pose_body)
|
||||||
|
|
||||||
# compute homogeneous coordinates and project them to 2D space
|
# compute homogeneous coordinates and project them to 2D space
|
||||||
# TODO: create custom cost function
|
|
||||||
|
|
||||||
points = tgm.convert_points_to_homogeneous(body_joints)
|
points = tgm.convert_points_to_homogeneous(body_joints)
|
||||||
points = camera(points).squeeze()
|
points = camera(points).squeeze()
|
||||||
|
|
||||||
return loss_layer(points, keypoints)
|
# TODO: create custom cost function
|
||||||
|
|
||||||
|
a = points.detach().cpu().numpy().squeeze()[index]
|
||||||
|
b = keypoints.detach().cpu().numpy().squeeze()[index]
|
||||||
|
|
||||||
|
# print(points)
|
||||||
|
|
||||||
|
print("j:", a)
|
||||||
|
print("k:", b)
|
||||||
|
|
||||||
|
print("loss:", -np.mean(a - b))
|
||||||
|
|
||||||
|
joint_loss = loss_layer(points, keypoints)
|
||||||
|
|
||||||
|
# apply pose prior loss.
|
||||||
|
prior_loss = poZ.pow(2).sum()
|
||||||
|
|
||||||
|
return joint_loss + prior_loss
|
||||||
|
|
||||||
def optim_closure():
|
def optim_closure():
|
||||||
if torch.is_grad_enabled():
|
if torch.is_grad_enabled():
|
||||||
@ -133,8 +186,9 @@ def train_pose(
|
|||||||
optimizer.step(optim_closure)
|
optimizer.step(optim_closure)
|
||||||
|
|
||||||
# LBFGS does not return the result, therefore we should rerun the model to get it
|
# LBFGS does not return the result, therefore we should rerun the model to get it
|
||||||
pred = predict()
|
with torch.no_grad():
|
||||||
loss = optim_closure()
|
pred = predict()
|
||||||
|
loss = optim_closure()
|
||||||
|
|
||||||
# if t % 5 == 0:
|
# if t % 5 == 0:
|
||||||
# time.sleep(5)
|
# time.sleep(5)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import cv2
|
|||||||
import yaml
|
import yaml
|
||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
with open('./config.yaml') as file:
|
with open('./config.yaml') as file:
|
||||||
# The FullLoader parameter handles the conversion from YAML
|
# The FullLoader parameter handles the conversion from YAML
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from trimesh.triangles import normals
|
from trimesh.triangles import normals
|
||||||
|
|
||||||
openpose_to_smpl = np.array([
|
openpose_to_smpl = np.array([
|
||||||
8, # hip - middle
|
8, # hip - middle / pelvis
|
||||||
12, # hip - right
|
12, # hip - right
|
||||||
9, # hip - left
|
9, # hip - left
|
||||||
-1, # body center (belly, not present in body_25)
|
-1, # body center (belly, not present in body_25)
|
||||||
@ -64,9 +64,10 @@ def get_mapping_arr(
|
|||||||
return openpose_to_smpl
|
return openpose_to_smpl
|
||||||
if output_format == "smplx":
|
if output_format == "smplx":
|
||||||
# create a list of length 127 and pad all values beyond 47 with -1 since we do not perform face and finger detection
|
# create a list of length 127 and pad all values beyond 47 with -1 since we do not perform face and finger detection
|
||||||
new = np.pad(openpose_to_smpl,
|
new = np.pad(
|
||||||
(0, 127-len(openpose_to_smpl)), constant_values=(0, -1))
|
openpose_to_smpl,
|
||||||
print(openpose_to_smpl, new)
|
(0, 127-len(openpose_to_smpl)),
|
||||||
|
constant_values=(0, -1))
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user