Mirror of https://github.com/gosticks/body-pose-animation.git
WIP: vposer layer support

commit a9e7f221cc (parent 53f76db68a)
.gitignore (vendored), 2 lines changed

@@ -94,3 +94,5 @@ models/*
 .vscode
 tum-3d-proj
 reference
+
+vposer_v1_0
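The newly ignored vposer_v1_0 directory matches the default vposer_model_path="./vposer_v1_0" of the VPoserModel wrapper introduced in model.py below, keeping the downloaded VPoser checkpoint out of version control.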
@@ -37,7 +37,8 @@ class CameraEstimate:
         self.device = device
         self.image_path = image_path
         self.keypoints = keypoints
-        self.scale = torch.tensor([est_scale,est_scale,est_scale], requires_grad=False, dtype=self.dtype, device=self.device)
+        self.scale = torch.tensor([est_scale, est_scale, est_scale],
+                                  requires_grad=False, dtype=self.dtype, device=self.device)

     def get_torso_keypoints(self):
         smpl_keypoints = self.output_model.joints.detach().cpu().numpy().squeeze()

@@ -73,7 +74,6 @@ class CameraEstimate:
     def setup_visualization(self, render_points, render_keypoints):
         self.transformed_points = render_points
-

     def sum_of_squares(self, params, X, Y):
         y_pred = self.loss_model(params, X)
         loss = np.sum((y_pred - Y) ** 2)

@@ -167,8 +167,10 @@ class TorchCameraEstimate(CameraEstimate):
                 pbar.update(per - current)
             current = per
             stop = loss > tol
-            if stop == True:
-                stop = self.patience_module(loss, 5)
+            # FIXME: same error as below
+            # if stop == True:
+            #     stop = self.patience_module(loss, 5)
         pbar.update(abs(100 - current))
         pbar.close()
         self.memory = None

@@ -204,31 +206,45 @@ class TorchCameraEstimate(CameraEstimate):

         stop = True
         first = True
-        cam_tol = 6e-5
+        cam_tol = 6e-3
         print("Estimating Camera transformations...")
         pbar = tqdm(total=100)
         current = 0

         while stop:
             y_pred = self.transform_3d_to_2d(
                 params, init_points_3d_prepared)
             loss = torch.nn.SmoothL1Loss()(init_points_2d.float(), y_pred.float())
             loss.requres_grad = True
             opt2.zero_grad()

             if first:
                 loss.backward(retain_graph=True)
             else:
                 loss.backward()
             opt2.step()
-            self.renderer.scene.set_pose( self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
+            self.renderer.scene.set_pose(
+                self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
             per = int((cam_tol/loss*100).item())

             if per > 100:
                 pbar.update(100 - current)
             else:
                 pbar.update(per - current)

             current = per
             stop = loss > cam_tol
-            if stop == True:
-                stop = self.patience_module(loss, 5)
+            # FIXME: this does not work for me, here is the error
+            # TypeError: eq() received an invalid combination of arguments - got (NoneType), but expected one of:
+            #  * (Tensor other)
+            #    didn't match because some of the arguments have invalid types: (NoneType)
+            #  * (Number other)
+            #    didn't match because some of the arguments have invalid types: (NoneType)
+
+            # if stop == True:
+            #     stop = self.patience_module(loss, 5)
+
         pbar.update(100 - current)
         pbar.close()
         camera_transform_matrix = self.torch_params_to_pose(

@@ -253,15 +269,15 @@ class TorchCameraEstimate(CameraEstimate):
     def torch_params_to_pose(self, params):
         transform = rtvec_to_pose(
             torch.cat((params[1], params[0])).view(-1).unsqueeze(0))
         for i in range(3):
-            transform[0,i,i] *= self.scale[i]
+            transform[0, i, i] *= self.scale[i]
         return transform[0, :, :]

     def C(self, params, X):
         Ext_mat = rtvec_to_pose(
             torch.cat((params[1], params[0])).view(-1).unsqueeze(0))
         for i in range(3):
-            Ext_mat[0,i,i] *= self.scale[i]
+            Ext_mat[0, i, i] *= self.scale[i]
         y_pred = Ext_mat @ X
         y_pred = y_pred.squeeze(2)
         y_pred = y_pred[:, :3]

@@ -276,7 +292,7 @@ class TorchCameraEstimate(CameraEstimate):

     def patience_module(self, variable, counter: int):
         if self.memory == None:
-            self.memory=torch.clone(variable)
+            self.memory = torch.clone(variable)
             self.patience_count = 0
             return True
         if self.patience_count >= counter:

@@ -289,7 +305,7 @@ class TorchCameraEstimate(CameraEstimate):
             return True
         else:
             self.patience_count = 0
-            self.memory=torch.clone(variable)
+            self.memory = torch.clone(variable)
             return True

 # sample_index = 0
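Two small observations on the hunks above. First, loss.requres_grad misspells requires_grad, so that assignment has no effect. Second, the TypeError quoted in the FIXME is the standard symptom of comparing a tensor to None with ==: once self.memory holds a tensor, self.memory == None dispatches to Tensor.eq(None), which PyTorch rejects. A minimal sketch of a patience helper using an identity test instead (an illustration, not part of this commit):

    import torch

    class Patience:
        """Early-stopping helper; `is None` avoids Tensor.eq(None)."""

        def __init__(self):
            self.memory = None
            self.count = 0

        def keep_going(self, loss: torch.Tensor, patience: int) -> bool:
            # identity check instead of `==`, so a stored tensor never reaches eq(None)
            if self.memory is None or bool(loss < self.memory):
                self.memory = loss.detach().clone()
                self.count = 0
                return True
            self.count += 1
            return self.count < patience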
@@ -127,7 +127,7 @@ for t in range(5000):
 camera_transf = trans.get_transform_mat(with_translate=True).detach().cpu()
 print("final pose:", camera_transf.numpy())

-camera = SimpleCamera(dtype, device, z_scale=1,
+camera = SimpleCamera(dtype, device,
                       transform_mat=camera_transf)

 train_pose(
@@ -119,11 +119,11 @@ camera = TorchCameraEstimate(
     device=torch.device('cpu'),
     dtype=torch.float32,
     image_path=img_path,
-    est_scale= est_scale
+    est_scale=est_scale
 )
 pose, transform, cam_trans = camera.estimate_camera_pos()

-camera.setup_visualization(render_points, render_keypoints )
+camera.setup_visualization(render_points, render_keypoints)


 # start renderer

@@ -135,9 +135,9 @@ camera_transformation = transform.clone().detach().to(device=device, dtype=dtype
 camera_int = pose.clone().detach().to(device=device, dtype=dtype)
 camera_params = cam_trans.clone().detach().to(device=device, dtype=dtype)

-camera = SimpleCamera(dtype, device, z_scale=1,
+camera = SimpleCamera(dtype, device,
                       transform_mat=camera_transformation,
                       # camera_intrinsics=camera_int, camera_trans_rot=camera_params
                       )

 r.set_group_pose("body", camera_transformation.detach().cpu().numpy())
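Both scripts drop the z_scale=1 argument from SimpleCamera, which is consistent with the camera hunks above: torch_params_to_pose now bakes the estimated scale into the diagonal of the 4x4 transform, so the scale travels with transform_mat. A tiny illustration with hypothetical values:

    import torch

    T = torch.eye(4)                        # stand-in for an rtvec_to_pose result
    scale = torch.tensor([2.0, 2.0, 2.0])   # stand-in for self.scale
    for i in range(3):
        T[i, i] *= scale[i]                 # mirrors transform[0, i, i] *= self.scale[i]
    print(T.diagonal())                     # tensor([2., 2., 2., 1.])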
model.py, 49 lines changed

@@ -1,15 +1,60 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import smplx
+from human_body_prior.body_model.body_model_vposer import BodyModelWithPoser
+
+
+class VPoserModel():
+    def __init__(
+        self,
+        model_type='smpl',
+        vposer_model_path="./vposer_v1_0",
+        ext='npz',
+        gender='neutral',
+        create_body_pose=True,
+        plot_joints=True,
+        num_betas=10,
+        sample_shape=False,
+        sample_expression=False,
+        num_expression_coeffs=10,
+        use_face_contour=False
+    ):
+        self.vposer_model_path = vposer_model_path
+        self.model_type = model_type
+        self.ext = ext
+        self.gender = gender
+        self.plot_joints = plot_joints
+        self.num_betas = num_betas
+        self.sample_shape = sample_shape
+        self.sample_expression = sample_expression
+        self.num_expression_coeffs = num_expression_coeffs
+        self.create_body_pose = create_body_pose
+
+        self.create_model()
+
+    def create_model(self):
+        self.model = BodyModelWithPoser(
+            bm_path="./models/smplx/SMPLX_MALE.npz",
+            batch_size=1,
+            poser_type="vposer",
+            smpl_exp_dir=self.vposer_model_path
+        )
+        return self.model
+
+    def get_vposer_latens(self):
+        return self.model.poZ_body
+
+    def get_pose(self):
+        return self.model.pose_body
+
+
 class SMPLyModel():
     def __init__(
         self,
         model_folder,
-        model_type='smpl',
+        model_type='smplx',
         ext='npz',
-        gender='neutral',
+        gender='male',
         create_body_pose=True,
         plot_joints=True,
         num_betas=10,
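A short usage sketch of the new wrapper (assuming the SMPL-X model file and the vposer_v1_0 checkpoint exist at the paths hard-coded above):

    from model import VPoserModel

    vposer = VPoserModel()               # builds BodyModelWithPoser on construction
    latent = vposer.get_vposer_latens()  # poZ_body, the trainable VPoser latent code
    pose = vposer.get_pose()             # pose_body decoded from that latent
    print(latent.shape, pose.shape)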
@@ -1,3 +1,4 @@
+from model import VPoserModel
 from modules.camera import SimpleCamera
 from renderer import Renderer
 from utils.mapping import get_mapping_arr

@@ -15,23 +16,27 @@ class BodyPose(nn.Module):
     def __init__(
         self,
         model: SMPL,
+        keypoint_conf=None,
         dtype=torch.float32,
         device=None,
+        model_type="smplx"
     ):
         super(BodyPose, self).__init__()
         self.dtype = dtype
         self.device = device
         self.model = model
+        self.model_type = model_type
+
         # create valid joint filter
         filter = self.get_joint_filter()
         self.register_buffer("filter", filter)

         # attach SMPL pose tensor as parameter to the layer
-        body_pose = torch.zeros(model.body_pose.shape,
-                                dtype=dtype, device=device)
-        body_pose = nn.Parameter(body_pose, requires_grad=True)
-        self.register_parameter("pose", body_pose)
+        # body_pose = torch.zeros(model.body_pose.shape,
+        #                         dtype=dtype, device=device)
+        # body_pose = nn.Parameter(body_pose, requires_grad=True)
+        # self.register_parameter("pose", body_pose)

     def get_joint_filter(self):
         """OpenPose and SMPL do not have fully matching joint positions,

@@ -42,7 +47,8 @@ class BodyPose(nn.Module):
         """

         # create a list with 1s for used joints and 0 for ignored joints
-        mapping = get_mapping_arr()
+        mapping = get_mapping_arr(output_format=self.model_type)
+        print(mapping.shape)
         filter = torch.zeros(
             (len(mapping), 3), dtype=self.dtype, device=self.device)
         for index, valid in enumerate(mapping > -1):

@@ -51,15 +57,15 @@ class BodyPose(nn.Module):

         return filter

-    def forward(self):
+    def forward(self, pose):
         bode_output = self.model(
-            body_pose=self.pose
+            body_pose=pose
         )

         # store model output for later renderer usage
         self.cur_out = bode_output

         joints = bode_output.joints

         # return a list with invalid joints set to zero
         return joints * self.filter.unsqueeze(0)

@@ -70,14 +76,17 @@ def train_pose(
     keypoint_conf,
     camera: SimpleCamera,
     loss_layer=torch.nn.MSELoss(),
-    learning_rate=1e-3,
+    learning_rate=1e-1,
     device=torch.device('cpu'),
     dtype=torch.float32,
     renderer: Renderer = None,
     optimizer=None,
     iterations=25
 ):
+    vposer = VPoserModel()
+    vposer_model = vposer.model
+    vposer_model.poZ_body.required_grad = True
+    vposer_params = vposer.get_vposer_latens()
     # setup keypoint data
     keypoints = torch.tensor(keypoints).to(device=device, dtype=dtype)
     keypoints_conf = torch.tensor(keypoint_conf).to(device)

@@ -88,14 +97,19 @@ def train_pose(
     pose_layer = BodyPose(model, dtype=dtype, device=device).to(device)

     if optimizer is None:
-        optimizer = torch.optim.LBFGS([pose_layer.pose], learning_rate)
+        optimizer = torch.optim.LBFGS(
+            vposer_model.parameters(), learning_rate)
     #optimizer = torch.optim.Adam(pose_layer.parameters(), learning_rate)

     pbar = tqdm(total=iterations)

     def predict():
+        body = vposer_model()
+        pose = body.pose_body
+        print(pose)
+
         # return joints based on current model state
-        body_joints = pose_layer()
+        body_joints = pose_layer(pose)

         # compute homogeneous coordinates and project them to 2D space
         # TODO: create custom cost function
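Taken together, this moves the optimized parameters from BodyPose's own pose tensor to the VPoser model: the optimizer now updates vposer_model.parameters(), and predict() decodes the latent into a body pose on every evaluation. Note that poZ_body.required_grad in the hunk above is likely a typo for requires_grad; the misspelled attribute is silently ignored by autograd. Since torch.optim.LBFGS re-evaluates the objective internally, its step() expects a closure along these lines (a sketch using the names above; the projection and loss wiring are assumptions):

    def make_closure(optimizer, vposer_model, pose_layer, camera, keypoints, loss_layer):
        def closure():
            optimizer.zero_grad()
            pose = vposer_model().pose_body    # decode the latent into a body pose
            joints = pose_layer(pose)          # filtered 3D joints from the body model
            points = camera(joints)            # project to 2D image space
            loss = loss_layer(points, keypoints)
            loss.backward()                    # gradients flow back into the latent
            return loss
        return closure

    # one LBFGS iteration: optimizer.step(make_closure(...))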