diff --git a/config.yaml b/config.yaml index f5d3eb9..f2d203b 100644 --- a/config.yaml +++ b/config.yaml @@ -23,7 +23,7 @@ pose: lr: 0.01 optimizer: Adam # currently supported Adam, LBFGS iterations: 100 - useCameraIntrinsics: false + useCameraIntrinsics: true bodyMeanLoss: enabled: false weight: 0.1 diff --git a/example_fit.py b/example_fit.py index 4407143..82c8e83 100644 --- a/example_fit.py +++ b/example_fit.py @@ -42,16 +42,6 @@ camera = TorchCameraEstimate( # render camera to the scene camera.setup_visualization(r.init_keypoints, r.keypoints) -# run camera optimizer -cam, cam_trans, cam_int, cam_params = SimpleCamera.from_estimation_cam( - camera, - dtype=dtype, - device=device, -) - -# apply transform to scene -r.set_group_pose("body", cam_trans.cpu().numpy()) - # train for pose train_pose_with_conf( @@ -59,7 +49,7 @@ train_pose_with_conf( model=model, keypoints=keypoints, keypoint_conf=conf, - camera=cam, + camera=camera, renderer=r, device=device, ) diff --git a/example_pose.py b/example_pose.py index d7fd13b..4e13c40 100644 --- a/example_pose.py +++ b/example_pose.py @@ -3,12 +3,8 @@ import numpy as np # local imports from renderer import DefaultRenderer -from train_pose import train_pose_with_conf -from modules.camera import SimpleCamera from model import SMPLyModel -from utils.general import load_config, setup_training -from camera_estimation import TorchCameraEstimate -from dataset import SMPLyDataset +from utils.general import load_config # this a simple pose playground with a async renderer for quick prototyping diff --git a/modules/angle_clip.py b/modules/angle_clip.py index b31a66a..476e529 100644 --- a/modules/angle_clip.py +++ b/modules/angle_clip.py @@ -2,12 +2,13 @@ import torch import torch.nn as nn import numpy as np + class AngleClipper(nn.Module): def __init__( self, device=torch.device('cpu'), dtype=torch.float32, - angle_idx=[24, 10 , 9], + angle_idx=[24, 10, 9], # directions=[-1, 1, 1, 1], weights=[1.0, 1.0, 1.0] ): @@ -35,8 +36,7 @@ class AngleClipper(nn.Module): angles = pose[:, self.angle_idx] - penalty = angles[torch.abs(angles) > self.limit] + penalty = angles[torch.abs(angles) > self.limit] # get relevant angles return penalty.pow(2).sum() * 0.01 - diff --git a/modules/camera.py b/modules/camera.py index eb92c96..56b6771 100644 --- a/modules/camera.py +++ b/modules/camera.py @@ -8,34 +8,63 @@ from model import * from dataset import * -class SimpleCamera(nn.Module): +class TransformCamera(nn.Module): def __init__( self, + transform_mat: torch.Tensor, dtype=torch.float32, device=None, - transform_mat=None, - camera_intrinsics=None, - camera_trans_rot=None ): - super(SimpleCamera, self).__init__() - self.hasTransform = False - self.hasCameraTransform = False + super(TransformCamera, self).__init__() + self.dtype = dtype self.device = device - self.model_type = "smplx" - if camera_intrinsics is not None: - self.hasCameraTransform = True - self.register_buffer("cam_int", camera_intrinsics) - self.register_buffer("cam_trans_rot", camera_trans_rot) - self.register_buffer("trans", transform_mat) - # self.register_buffer("disp_trans", camera_trans_rot) - elif transform_mat is not None: - self.hasTransform = True - self.register_buffer("trans", transform_mat) - # self.register_buffer("disp_trans", transform_mat) + self.register_buffer("trans", transform_mat.to( + device=device, dtype=dtype)) - def from_estimation_cam(cam: TorchCameraEstimate, device=None, dtype=None): + def forward(self, points): + proj_points = self.trans @ points.reshape(-1, 4, 1) + proj_points = proj_points.reshape(1, -1, 4)[:, :, :2] * 1 + proj_points = F.pad(proj_points, (0, 1, 0, 0), value=0) + return proj_points + + +class IntrinsicsCamera(nn.Module): + def __init__( + self, + transform_mat: torch.Tensor, + camera_intrinsics: torch.Tensor, + camera_trans_rot: torch.Tensor, + dtype=torch.float32, + device=None + ): + super(IntrinsicsCamera, self).__init__() + + self.dtype = dtype + self.device = device + + self.register_buffer("cam_int", camera_intrinsics.to( + device=device, dtype=dtype)) + self.register_buffer("cam_trans_rot", camera_trans_rot.to( + device=device, dtype=dtype)) + self.register_buffer("trans", transform_mat.to( + device=device, dtype=dtype)) + + def forward(self, points): + proj_points = self.cam_int[:3, :3] @ self.cam_trans_rot[:3, + :] @ self.trans @ points.reshape(-1, 4, 1) + result = proj_points.squeeze(2) + denomiator = torch.zeros(points.shape[1], 3) + for i in range(points.shape[1]): + denomiator[i, :] = result[i, 2] + result = result/denomiator + result[:, 2] = 0 + return result + + +class SimpleCamera(nn.Module): + def from_estimation_cam(cam: TorchCameraEstimate, use_intrinsics=False, device=None, dtype=None): """utility to create camera module from estimation camera Args: @@ -44,29 +73,21 @@ class SimpleCamera(nn.Module): cam_trans, cam_int, cam_params = cam.get_results( device=device, dtype=dtype) - return SimpleCamera( - dtype, - device, - transform_mat=cam_trans, - camera_intrinsics=cam_int, camera_trans_rot=cam_params - ), cam_trans, cam_int, cam_params + cam_layer = None - def forward(self, points): - if self.hasTransform: - proj_points = self.trans @ points.reshape(-1, 4, 1) - proj_points = proj_points.reshape(1, -1, 4)[:, :, :2] * 1 - proj_points = F.pad(proj_points, (0, 1, 0, 0), value=0) - return proj_points - if self.hasCameraTransform: - proj_points = self.cam_int[:3, :3] @ self.cam_trans_rot[:3, - :] @ self.trans @ points.reshape(-1, 4, 1) - result = proj_points.squeeze(2) - denomiator = torch.zeros(points.shape[1], 3) - for i in range(points.shape[1]): - denomiator[i, :] = result[i, 2] - result = result/denomiator - result[:, 2] = 0 - return result + if use_intrinsics: + cam_layer = IntrinsicsCamera( + transform_mat=cam_trans, + camera_intrinsics=cam_int, + camera_trans_rot=cam_params, + device=device, + dtype=dtype, + ) + else: + cam_layer = TransformCamera( + transform_mat=cam_trans, + device=device, + dtype=dtype, + ) - # scale = (points[:, :, 2] / self.z_scale) - # print(points.shape, scale.shape) + return cam_layer, cam_trans, cam_int, cam_params diff --git a/train_pose.py b/train_pose.py index ae7cf29..e4b83b9 100644 --- a/train_pose.py +++ b/train_pose.py @@ -1,3 +1,4 @@ +from camera_estimation import TorchCameraEstimate from modules.angle_clip import AngleClipper from modules.angle import AnglePriorsLoss import smplx @@ -41,7 +42,7 @@ def train_pose( print("[pose] starting training") print("[pose] dtype=", dtype) - loss_layer = torch.nn.MSELoss().to(device=device, dtype=dtype) #MSELoss() + loss_layer = torch.nn.MSELoss().to(device=device, dtype=dtype) # MSELoss() clip_loss_layer = AngleClipper().to(device=device, dtype=dtype) @@ -91,10 +92,9 @@ def train_pose( pose_extra = None # if useBodyPrior: - # body = vposer_layer() - # poZ = body.poZ_body - # pose_extra = body.pose_body - + # body = vposer_layer() + # poZ = body.poZ_body + # pose_extra = body.pose_body # return joints based on current model state body_joints, cur_pose = pose_layer() @@ -115,12 +115,13 @@ def train_pose( body_mean_loss = 0.0 if body_mean_loss: body_mean_loss = (cur_pose - - body_mean_pose).pow(2).sum() * body_mean_weight + body_mean_pose).pow(2).sum() * body_mean_weight body_prior_loss = 0.0 if useBodyPrior: # apply pose prior loss. - body_prior_loss = latent_body.pose_body.pow(2).sum() * body_prior_weight + body_prior_loss = latent_body.pose_body.pow( + 2).sum() * body_prior_weight angle_prior_loss = 0.0 if useAnglePrior: @@ -130,11 +131,11 @@ def train_pose( angle_sum_loss = 0.0 if use_angle_sum_loss: - angle_sum_loss = clip_loss_layer(cur_pose) * angle_sum_weight + angle_sum_loss = clip_loss_layer(cur_pose) # * angle_sum_weight loss = loss + body_mean_loss + body_prior_loss + angle_prior_loss + angle_sum_loss - return loss + return loss def optim_closure(): if torch.is_grad_enabled(): @@ -191,32 +192,44 @@ def train_pose( pbar.close() print("Final result:", loss.item()) - return pose_layer.cur_out + return pose_layer.cur_out, best_pose def train_pose_with_conf( config, + camera: TorchCameraEstimate, model: smplx.SMPL, keypoints, keypoint_conf, - camera: SimpleCamera, device=torch.device('cpu'), dtype=torch.float32, renderer: Renderer = None, ): # configure PyTorch device and format - dtype = torch.float64 + # dtype = torch.float64 if 'device' in config['pose'] is not None: device = torch.device(config['pose']['device']) else: device = torch.device('cpu') + # create camera module + pose_camera, cam_trans, cam_int, cam_params = SimpleCamera.from_estimation_cam( + cam=camera, + use_intrinsics=config['pose']['useCameraIntrinsics'], + dtype=dtype, + device=device, + ) + + # apply transform to scene + if renderer is not None: + renderer.set_group_pose("body", cam_trans.cpu().numpy()) + return train_pose( model=model.to(dtype=dtype), keypoints=keypoints, keypoint_conf=keypoint_conf, - camera=camera, + camera=pose_camera, device=device, dtype=dtype, renderer=renderer,