diff --git a/.gitignore b/.gitignore index d1b358f..8ac582c 100644 --- a/.gitignore +++ b/.gitignore @@ -93,4 +93,6 @@ models/* .vscode tum-3d-proj -reference \ No newline at end of file +reference + +vposer_v1_0 \ No newline at end of file diff --git a/camera_estimation.py b/camera_estimation.py index e263c42..bc1e005 100644 --- a/camera_estimation.py +++ b/camera_estimation.py @@ -37,7 +37,8 @@ class CameraEstimate: self.device = device self.image_path = image_path self.keypoints = keypoints - self.scale = torch.tensor([est_scale,est_scale,est_scale], requires_grad=False, dtype=self.dtype, device=self.device) + self.scale = torch.tensor([est_scale, est_scale, est_scale], + requires_grad=False, dtype=self.dtype, device=self.device) def get_torso_keypoints(self): smpl_keypoints = self.output_model.joints.detach().cpu().numpy().squeeze() @@ -73,7 +74,6 @@ class CameraEstimate: def setup_visualization(self, render_points, render_keypoints): self.transformed_points = render_points - def sum_of_squares(self, params, X, Y): y_pred = self.loss_model(params, X) loss = np.sum((y_pred - Y) ** 2) @@ -114,7 +114,7 @@ class CameraEstimate: class TorchCameraEstimate(CameraEstimate): - def estimate_camera_pos(self): + def estimate_camera_pos(self): self.memory = None translation = torch.zeros( 1, 3, requires_grad=True, dtype=self.dtype, device=self.device) @@ -167,8 +167,10 @@ class TorchCameraEstimate(CameraEstimate): pbar.update(per - current) current = per stop = loss > tol - if stop == True: - stop = self.patience_module(loss, 5) + + # FIXME: same error as below + # if stop == True: + # stop = self.patience_module(loss, 5) pbar.update(abs(100 - current)) pbar.close() self.memory = None @@ -204,31 +206,45 @@ class TorchCameraEstimate(CameraEstimate): stop = True first = True - cam_tol = 6e-5 + cam_tol = 6e-3 print("Estimating Camera transformations...") pbar = tqdm(total=100) current = 0 + while stop: y_pred = self.transform_3d_to_2d( params, init_points_3d_prepared) loss = torch.nn.SmoothL1Loss()(init_points_2d.float(), y_pred.float()) loss.requres_grad = True opt2.zero_grad() + if first: loss.backward(retain_graph=True) else: loss.backward() opt2.step() - self.renderer.scene.set_pose( self.camera_renderer, self.torch_params_to_pose(params).detach().numpy()) + self.renderer.scene.set_pose( + self.camera_renderer, self.torch_params_to_pose(params).detach().numpy()) per = int((cam_tol/loss*100).item()) + if per > 100: pbar.update(100 - current) else: pbar.update(per - current) + current = per stop = loss > cam_tol - if stop == True: - stop = self.patience_module(loss, 5) + + # FIXME: this does not work for me, here is the error + # TypeError: eq() received an invalid combination of arguments - got (NoneType), but expected one of: + # * (Tensor other) + # didn't match because some of the arguments have invalid types: (NoneType) + # * (Number other) + # didn't match because some of the arguments have invalid types: (NoneType) + + # if stop == True: + # stop = self.patience_module(loss, 5) + pbar.update(100 - current) pbar.close() camera_transform_matrix = self.torch_params_to_pose( @@ -253,15 +269,15 @@ class TorchCameraEstimate(CameraEstimate): def torch_params_to_pose(self, params): transform = rtvec_to_pose( torch.cat((params[1], params[0])).view(-1).unsqueeze(0)) - for i in range(3): - transform[0,i,i] *= self.scale[i] + for i in range(3): + transform[0, i, i] *= self.scale[i] return transform[0, :, :] def C(self, params, X): Ext_mat = rtvec_to_pose( torch.cat((params[1], params[0])).view(-1).unsqueeze(0)) - for i in range(3): - Ext_mat[0,i,i] *= self.scale[i] + for i in range(3): + Ext_mat[0, i, i] *= self.scale[i] y_pred = Ext_mat @ X y_pred = y_pred.squeeze(2) y_pred = y_pred[:, :3] @@ -276,7 +292,7 @@ class TorchCameraEstimate(CameraEstimate): def patience_module(self, variable, counter: int): if self.memory == None: - self.memory=torch.clone(variable) + self.memory = torch.clone(variable) self.patience_count = 0 return True if self.patience_count >= counter: @@ -289,7 +305,7 @@ class TorchCameraEstimate(CameraEstimate): return True else: self.patience_count = 0 - self.memory=torch.clone(variable) + self.memory = torch.clone(variable) return True # sample_index = 0 diff --git a/example_camera.py b/example_camera.py index e19d3ae..fd7cd06 100644 --- a/example_camera.py +++ b/example_camera.py @@ -127,7 +127,7 @@ for t in range(5000): camera_transf = trans.get_transform_mat(with_translate=True).detach().cpu() print("final pose:", camera_transf.numpy()) -camera = SimpleCamera(dtype, device, z_scale=1, +camera = SimpleCamera(dtype, device, transform_mat=camera_transf) train_pose( diff --git a/example_fit.py b/example_fit.py index 4da1214..276b6b3 100644 --- a/example_fit.py +++ b/example_fit.py @@ -119,11 +119,11 @@ camera = TorchCameraEstimate( device=torch.device('cpu'), dtype=torch.float32, image_path=img_path, - est_scale= est_scale + est_scale=est_scale ) pose, transform, cam_trans = camera.estimate_camera_pos() -camera.setup_visualization(render_points, render_keypoints ) +camera.setup_visualization(render_points, render_keypoints) # start renderer @@ -135,9 +135,9 @@ camera_transformation = transform.clone().detach().to(device=device, dtype=dtype camera_int = pose.clone().detach().to(device=device, dtype=dtype) camera_params = cam_trans.clone().detach().to(device=device, dtype=dtype) -camera = SimpleCamera(dtype, device, z_scale=1, +camera = SimpleCamera(dtype, device, transform_mat=camera_transformation, - # camera_intrinsics=camera_int, camera_trans_rot=camera_params + # camera_intrinsics=camera_int, camera_trans_rot=camera_params ) r.set_group_pose("body", camera_transformation.detach().cpu().numpy()) diff --git a/model.py b/model.py index ad06117..e2247ff 100644 --- a/model.py +++ b/model.py @@ -1,15 +1,60 @@ import matplotlib.pyplot as plt import numpy as np import smplx +from human_body_prior.body_model.body_model_vposer import BodyModelWithPoser + + +class VPoserModel(): + def __init__( + self, + model_type='smpl', + vposer_model_path="./vposer_v1_0", + ext='npz', + gender='neutral', + create_body_pose=True, + plot_joints=True, + num_betas=10, + sample_shape=False, + sample_expression=False, + num_expression_coeffs=10, + use_face_contour=False + ): + self.vposer_model_path = vposer_model_path + self.model_type = model_type + self.ext = ext + self.gender = gender + self.plot_joints = plot_joints + self.num_betas = num_betas + self.sample_shape = sample_shape + self.sample_expression = sample_expression + self.num_expression_coeffs = num_expression_coeffs + self.create_body_pose = create_body_pose + + self.create_model() + + def create_model(self): + self.model = BodyModelWithPoser( + bm_path="./models/smplx/SMPLX_MALE.npz", + batch_size=1, + poser_type="vposer", + smpl_exp_dir=self.vposer_model_path + ) + return self.model + + def get_vposer_latens(self): + return self.model.poZ_body + + def get_pose(self): + return self.model.pose_body class SMPLyModel(): def __init__( self, model_folder, - model_type='smpl', + model_type='smplx', ext='npz', - gender='neutral', + gender='male', create_body_pose=True, plot_joints=True, num_betas=10, diff --git a/modules/pose.py b/modules/pose.py index 1582b3b..de093ee 100644 --- a/modules/pose.py +++ b/modules/pose.py @@ -1,3 +1,4 @@ +from model import VPoserModel from modules.camera import SimpleCamera from renderer import Renderer from utils.mapping import get_mapping_arr @@ -15,23 +16,27 @@ class BodyPose(nn.Module): def __init__( self, model: SMPL, + keypoint_conf=None, dtype=torch.float32, device=None, + model_type="smplx" + ): super(BodyPose, self).__init__() self.dtype = dtype self.device = device self.model = model + self.model_type = model_type # create valid joint filter filter = self.get_joint_filter() self.register_buffer("filter", filter) # attach SMPL pose tensor as parameter to the layer - body_pose = torch.zeros(model.body_pose.shape, - dtype=dtype, device=device) - body_pose = nn.Parameter(body_pose, requires_grad=True) - self.register_parameter("pose", body_pose) + # body_pose = torch.zeros(model.body_pose.shape, + # dtype=dtype, device=device) + # body_pose = nn.Parameter(body_pose, requires_grad=True) + # self.register_parameter("pose", body_pose) def get_joint_filter(self): """OpenPose and SMPL do not have fully matching joint positions, @@ -42,7 +47,8 @@ class BodyPose(nn.Module): """ # create a list with 1s for used joints and 0 for ignored joints - mapping = get_mapping_arr() + mapping = get_mapping_arr(output_format=self.model_type) + print(mapping.shape) filter = torch.zeros( (len(mapping), 3), dtype=self.dtype, device=self.device) for index, valid in enumerate(mapping > -1): @@ -51,15 +57,15 @@ class BodyPose(nn.Module): return filter - def forward(self): + def forward(self, pose): bode_output = self.model( - body_pose=self.pose + body_pose=pose ) + # store model output for later renderer usage self.cur_out = bode_output joints = bode_output.joints - # return a list with invalid joints set to zero return joints * self.filter.unsqueeze(0) @@ -70,14 +76,17 @@ def train_pose( keypoint_conf, camera: SimpleCamera, loss_layer=torch.nn.MSELoss(), - learning_rate=1e-3, + learning_rate=1e-1, device=torch.device('cpu'), dtype=torch.float32, renderer: Renderer = None, optimizer=None, iterations=25 ): - + vposer = VPoserModel() + vposer_model = vposer.model + vposer_model.poZ_body.required_grad = True + vposer_params = vposer.get_vposer_latens() # setup keypoint data keypoints = torch.tensor(keypoints).to(device=device, dtype=dtype) keypoints_conf = torch.tensor(keypoint_conf).to(device) @@ -88,14 +97,19 @@ def train_pose( pose_layer = BodyPose(model, dtype=dtype, device=device).to(device) if optimizer is None: - optimizer = torch.optim.LBFGS([pose_layer.pose], learning_rate) + optimizer = torch.optim.LBFGS( + vposer_model.parameters(), learning_rate) #optimizer = torch.optim.Adam(pose_layer.parameters(), learning_rate) pbar = tqdm(total=iterations) def predict(): + body = vposer_model() + pose = body.pose_body + print(pose) + # return joints based on current model state - body_joints = pose_layer() + body_joints = pose_layer(pose) # compute homogeneous coordinates and project them to 2D space # TODO: create custom cost function