WIP: vposer layer support

Wlad 2021-02-01 23:53:36 +01:00
parent 53f76db68a
commit a9e7f221cc
6 changed files with 112 additions and 35 deletions

.gitignore

@@ -94,3 +94,5 @@ models/*
 .vscode
 tum-3d-proj
 reference
+
+vposer_v1_0


@@ -37,7 +37,8 @@ class CameraEstimate:
         self.device = device
         self.image_path = image_path
         self.keypoints = keypoints
-        self.scale = torch.tensor([est_scale,est_scale,est_scale], requires_grad=False, dtype=self.dtype, device=self.device)
+        self.scale = torch.tensor([est_scale, est_scale, est_scale],
+                                  requires_grad=False, dtype=self.dtype, device=self.device)

     def get_torso_keypoints(self):
         smpl_keypoints = self.output_model.joints.detach().cpu().numpy().squeeze()
@@ -73,7 +74,6 @@ class CameraEstimate:
     def setup_visualization(self, render_points, render_keypoints):
         self.transformed_points = render_points

-
     def sum_of_squares(self, params, X, Y):
         y_pred = self.loss_model(params, X)
         loss = np.sum((y_pred - Y) ** 2)
@@ -167,8 +167,10 @@ class TorchCameraEstimate(CameraEstimate):
                 pbar.update(per - current)
             current = per
             stop = loss > tol
-            if stop == True:
-                stop = self.patience_module(loss, 5)
+            # FIXME: same error as below
+            # if stop == True:
+            #     stop = self.patience_module(loss, 5)
+
         pbar.update(abs(100 - current))
         pbar.close()
         self.memory = None
@@ -204,31 +206,45 @@ class TorchCameraEstimate(CameraEstimate):
         stop = True
         first = True
-        cam_tol = 6e-5
+        cam_tol = 6e-3
         print("Estimating Camera transformations...")
         pbar = tqdm(total=100)
         current = 0
         while stop:
             y_pred = self.transform_3d_to_2d(
                 params, init_points_3d_prepared)
             loss = torch.nn.SmoothL1Loss()(init_points_2d.float(), y_pred.float())
             loss.requres_grad = True
             opt2.zero_grad()
             if first:
                 loss.backward(retain_graph=True)
             else:
                 loss.backward()
             opt2.step()
-            self.renderer.scene.set_pose( self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
+            self.renderer.scene.set_pose(
+                self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
             per = int((cam_tol/loss*100).item())
             if per > 100:
                 pbar.update(100 - current)
             else:
                 pbar.update(per - current)
             current = per
             stop = loss > cam_tol
-            if stop == True:
-                stop = self.patience_module(loss, 5)
+            # FIXME: this does not work for me, here is the error
+            # TypeError: eq() received an invalid combination of arguments - got (NoneType), but expected one of:
+            #  * (Tensor other)
+            #       didn't match because some of the arguments have invalid types: (NoneType)
+            #  * (Number other)
+            #       didn't match because some of the arguments have invalid types: (NoneType)
+            # if stop == True:
+            #     stop = self.patience_module(loss, 5)
         pbar.update(100 - current)
         pbar.close()
         camera_transform_matrix = self.torch_params_to_pose(
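
Note on the loop above: `loss.requres_grad = True` misspells `requires_grad`, so it only attaches an unused attribute to the tensor; even spelled correctly it would be redundant, because a loss computed from parameters that require grad already carries a `grad_fn`. A minimal sketch of the loop body without that line, reusing the names from the diff (the if/else on `first` is folded into the `retain_graph` argument):

    y_pred = self.transform_3d_to_2d(params, init_points_3d_prepared)
    loss = torch.nn.SmoothL1Loss()(init_points_2d.float(), y_pred.float())
    opt2.zero_grad()
    # retain the graph only on the first pass, equivalent to the diff's if/else
    loss.backward(retain_graph=first)
    opt2.step()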
@@ -253,15 +269,15 @@ class TorchCameraEstimate(CameraEstimate):
     def torch_params_to_pose(self, params):
         transform = rtvec_to_pose(
             torch.cat((params[1], params[0])).view(-1).unsqueeze(0))
         for i in range(3):
-            transform[0,i,i] *= self.scale[i]
+            transform[0, i, i] *= self.scale[i]
         return transform[0, :, :]

     def C(self, params, X):
         Ext_mat = rtvec_to_pose(
             torch.cat((params[1], params[0])).view(-1).unsqueeze(0))
         for i in range(3):
-            Ext_mat[0,i,i] *= self.scale[i]
+            Ext_mat[0, i, i] *= self.scale[i]
         y_pred = Ext_mat @ X
         y_pred = y_pred.squeeze(2)
         y_pred = y_pred[:, :3]
@@ -276,7 +292,7 @@ class TorchCameraEstimate(CameraEstimate):

     def patience_module(self, variable, counter: int):
         if self.memory == None:
-            self.memory=torch.clone(variable)
+            self.memory = torch.clone(variable)
             self.patience_count = 0
             return True
         if self.patience_count >= counter:
@@ -289,7 +305,7 @@ class TorchCameraEstimate(CameraEstimate):
             return True
         else:
             self.patience_count = 0
-            self.memory=torch.clone(variable)
+            self.memory = torch.clone(variable)
             return True

 # sample_index = 0
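
The TypeError quoted in the FIXME above is most likely raised by the `self.memory == None` comparison: once `self.memory` holds a tensor, `==` dispatches to `Tensor.eq`, which does not accept `None`. Testing identity with `is None` avoids the dispatch entirely. A self-contained sketch of a fixed helper; the stagnation test is an assumption, since the middle of the original method is not part of this diff:

    import torch

    class Patience:
        def __init__(self):
            self.memory = None
            self.patience_count = 0

        def step(self, loss: torch.Tensor, counter: int) -> bool:
            # `is None` instead of `== None`: the latter calls Tensor.eq(None)
            # as soon as memory holds a tensor, raising the TypeError above
            if self.memory is None:
                self.memory = torch.clone(loss)
                self.patience_count = 0
                return True
            if self.patience_count >= counter:
                return False  # stop: no improvement for `counter` checks
            if torch.isclose(loss, self.memory):  # assumed stagnation criterion
                self.patience_count += 1
            else:
                self.patience_count = 0
                self.memory = torch.clone(loss)
            return True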


@@ -127,7 +127,7 @@ for t in range(5000):
 camera_transf = trans.get_transform_mat(with_translate=True).detach().cpu()
 print("final pose:", camera_transf.numpy())

-camera = SimpleCamera(dtype, device, z_scale=1,
+camera = SimpleCamera(dtype, device,
                       transform_mat=camera_transf)

 train_pose(


@@ -119,11 +119,11 @@ camera = TorchCameraEstimate(
     device=torch.device('cpu'),
     dtype=torch.float32,
     image_path=img_path,
-    est_scale= est_scale
+    est_scale=est_scale
 )
 pose, transform, cam_trans = camera.estimate_camera_pos()
-camera.setup_visualization(render_points, render_keypoints )
+camera.setup_visualization(render_points, render_keypoints)

 # start renderer
@@ -135,9 +135,9 @@ camera_transformation = transform.clone().detach().to(device=device, dtype=dtype)
 camera_int = pose.clone().detach().to(device=device, dtype=dtype)
 camera_params = cam_trans.clone().detach().to(device=device, dtype=dtype)

-camera = SimpleCamera(dtype, device, z_scale=1,
+camera = SimpleCamera(dtype, device,
                       transform_mat=camera_transformation,
                       # camera_intrinsics=camera_int, camera_trans_rot=camera_params
                       )

 r.set_group_pose("body", camera_transformation.detach().cpu().numpy())


@@ -1,15 +1,60 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import smplx
+from human_body_prior.body_model.body_model_vposer import BodyModelWithPoser
+
+
+class VPoserModel():
+    def __init__(
+        self,
+        model_type='smpl',
+        vposer_model_path="./vposer_v1_0",
+        ext='npz',
+        gender='neutral',
+        create_body_pose=True,
+        plot_joints=True,
+        num_betas=10,
+        sample_shape=False,
+        sample_expression=False,
+        num_expression_coeffs=10,
+        use_face_contour=False
+    ):
+        self.vposer_model_path = vposer_model_path
+        self.model_type = model_type
+        self.ext = ext
+        self.gender = gender
+        self.plot_joints = plot_joints
+        self.num_betas = num_betas
+        self.sample_shape = sample_shape
+        self.sample_expression = sample_expression
+        self.num_expression_coeffs = num_expression_coeffs
+        self.create_body_pose = create_body_pose
+
+        self.create_model()
+
+    def create_model(self):
+        self.model = BodyModelWithPoser(
+            bm_path="./models/smplx/SMPLX_MALE.npz",
+            batch_size=1,
+            poser_type="vposer",
+            smpl_exp_dir=self.vposer_model_path
+        )
+        return self.model
+
+    def get_vposer_latens(self):
+        return self.model.poZ_body
+
+    def get_pose(self):
+        return self.model.pose_body
+
+
 class SMPLyModel():
     def __init__(
         self,
         model_folder,
-        model_type='smpl',
+        model_type='smplx',
         ext='npz',
-        gender='neutral',
+        gender='male',
         create_body_pose=True,
         plot_joints=True,
         num_betas=10,
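
A short usage sketch of the new VPoserModel wrapper, assuming the `vposer_v1_0` checkpoint and the SMPL-X model file referenced above exist locally; the VPoser latent (`poZ_body`) is the quantity the pose optimizer tunes:

    from model import VPoserModel

    vposer = VPoserModel(vposer_model_path="./vposer_v1_0")
    latent = vposer.get_vposer_latens()  # learnable poZ_body latent
    body = vposer.model()                # forward pass decodes the latent
    pose = vposer.get_pose()             # body pose derived from the latent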


@@ -1,3 +1,4 @@
+from model import VPoserModel
 from modules.camera import SimpleCamera
 from renderer import Renderer
 from utils.mapping import get_mapping_arr
@@ -15,23 +16,27 @@ class BodyPose(nn.Module):
     def __init__(
         self,
         model: SMPL,
+        keypoint_conf=None,
         dtype=torch.float32,
         device=None,
+        model_type="smplx"
     ):
         super(BodyPose, self).__init__()
         self.dtype = dtype
         self.device = device
         self.model = model
+        self.model_type = model_type

         # create valid joint filter
         filter = self.get_joint_filter()
         self.register_buffer("filter", filter)

         # attach SMPL pose tensor as parameter to the layer
-        body_pose = torch.zeros(model.body_pose.shape,
-                                dtype=dtype, device=device)
-        body_pose = nn.Parameter(body_pose, requires_grad=True)
-        self.register_parameter("pose", body_pose)
+        # body_pose = torch.zeros(model.body_pose.shape,
+        #                         dtype=dtype, device=device)
+        # body_pose = nn.Parameter(body_pose, requires_grad=True)
+        # self.register_parameter("pose", body_pose)

     def get_joint_filter(self):
         """OpenPose and SMPL do not have fully matching joint positions,
@@ -42,7 +47,8 @@ class BodyPose(nn.Module):
         """
         # create a list with 1s for used joints and 0 for ignored joints
-        mapping = get_mapping_arr()
+        mapping = get_mapping_arr(output_format=self.model_type)
+        print(mapping.shape)
         filter = torch.zeros(
             (len(mapping), 3), dtype=self.dtype, device=self.device)
         for index, valid in enumerate(mapping > -1):
@@ -51,15 +57,15 @@ class BodyPose(nn.Module):
         return filter

-    def forward(self):
+    def forward(self, pose):
         bode_output = self.model(
-            body_pose=self.pose
+            body_pose=pose
         )

         # store model output for later renderer usage
         self.cur_out = bode_output
         joints = bode_output.joints

         # return a list with invalid joints set to zero
         return joints * self.filter.unsqueeze(0)
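
The multiplication at the end of `forward` zeroes out joints that have no OpenPose counterpart while keeping the tensor shape intact. A toy illustration with made-up values:

    import torch

    joints = torch.ones(1, 4, 3)            # (batch, joints, xyz)
    filter = torch.tensor([[1., 1., 1.],
                           [0., 0., 0.],    # joint without a mapping -> zeroed
                           [1., 1., 1.],
                           [0., 0., 0.]])
    masked = joints * filter.unsqueeze(0)   # rows 1 and 3 become zero vectors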
@@ -70,14 +76,17 @@ def train_pose(
     keypoint_conf,
     camera: SimpleCamera,
     loss_layer=torch.nn.MSELoss(),
-    learning_rate=1e-3,
+    learning_rate=1e-1,
     device=torch.device('cpu'),
     dtype=torch.float32,
     renderer: Renderer = None,
     optimizer=None,
     iterations=25
 ):
+    vposer = VPoserModel()
+    vposer_model = vposer.model
+    vposer_model.poZ_body.required_grad = True
+    vposer_params = vposer.get_vposer_latens()

     # setup keypoint data
     keypoints = torch.tensor(keypoints).to(device=device, dtype=dtype)
     keypoints_conf = torch.tensor(keypoint_conf).to(device)
@@ -88,14 +97,19 @@ def train_pose(
     pose_layer = BodyPose(model, dtype=dtype, device=device).to(device)

     if optimizer is None:
-        optimizer = torch.optim.LBFGS([pose_layer.pose], learning_rate)
+        optimizer = torch.optim.LBFGS(
+            vposer_model.parameters(), learning_rate)
     #optimizer = torch.optim.Adam(pose_layer.parameters(), learning_rate)

     pbar = tqdm(total=iterations)

     def predict():
+        body = vposer_model()
+        pose = body.pose_body
+        print(pose)
+
         # return joints based on current model state
-        body_joints = pose_layer()
+        body_joints = pose_layer(pose)

         # compute homogeneous coordinates and project them to 2D space
         # TODO: create custom cost function
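
Two WIP details in this last hunk: `vposer_model.poZ_body.required_grad = True` misspells `requires_grad` and is therefore a silent no-op (likely harmless, since module parameters require grad by default), and `torch.optim.LBFGS` only makes progress when `step()` receives a closure that re-evaluates the loss. A hedged sketch of the loop this presumably feeds into; the loss wiring is an assumption, since the remainder of `train_pose` is not shown in this diff:

    def closure():
        optimizer.zero_grad()
        joints_2d = predict()                  # projected 2D joints (assumed)
        loss = loss_layer(joints_2d, keypoints)
        loss.backward()
        return loss

    for _ in range(iterations):
        optimizer.step(closure)
        pbar.update(1)
    pbar.close()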