mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
merged camera
This commit is contained in:
parent
ca3b4e7167
commit
26401f959e
@ -23,7 +23,7 @@ pose:
|
||||
lr: 0.01
|
||||
optimizer: Adam # currently supported: Adam, LBFGS
|
||||
iterations: 100
|
||||
useCameraIntrinsics: false
|
||||
useCameraIntrinsics: true
|
||||
bodyMeanLoss:
|
||||
enabled: false
|
||||
weight: 0.1
|
||||
|
||||
@ -42,16 +42,6 @@ camera = TorchCameraEstimate(
|
||||
# render camera to the scene
|
||||
camera.setup_visualization(r.init_keypoints, r.keypoints)
|
||||
|
||||
# run camera optimizer
|
||||
cam, cam_trans, cam_int, cam_params = SimpleCamera.from_estimation_cam(
|
||||
camera,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# apply transform to scene
|
||||
r.set_group_pose("body", cam_trans.cpu().numpy())
|
||||
|
||||
|
||||
# train for pose
|
||||
train_pose_with_conf(
|
||||
@ -59,7 +49,7 @@ train_pose_with_conf(
|
||||
model=model,
|
||||
keypoints=keypoints,
|
||||
keypoint_conf=conf,
|
||||
camera=cam,
|
||||
camera=camera,
|
||||
renderer=r,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@ -3,12 +3,8 @@ import numpy as np
|
||||
|
||||
# local imports
|
||||
from renderer import DefaultRenderer
|
||||
from train_pose import train_pose_with_conf
|
||||
from modules.camera import SimpleCamera
|
||||
from model import SMPLyModel
|
||||
from utils.general import load_config, setup_training
|
||||
from camera_estimation import TorchCameraEstimate
|
||||
from dataset import SMPLyDataset
|
||||
from utils.general import load_config
|
||||
|
||||
# this is a simple pose playground with an async renderer for quick prototyping
|
||||
|
||||
|
||||
@ -2,12 +2,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AngleClipper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device=torch.device('cpu'),
|
||||
dtype=torch.float32,
|
||||
angle_idx=[24, 10 , 9],
|
||||
angle_idx=[24, 10, 9],
|
||||
# directions=[-1, 1, 1, 1],
|
||||
weights=[1.0, 1.0, 1.0]
|
||||
):
|
||||
@ -35,8 +36,7 @@ class AngleClipper(nn.Module):
|
||||
|
||||
angles = pose[:, self.angle_idx]
|
||||
|
||||
penalty = angles[torch.abs(angles) > self.limit]
|
||||
penalty = angles[torch.abs(angles) > self.limit]
|
||||
|
||||
# get relevant angles
|
||||
return penalty.pow(2).sum() * 0.01
|
||||
|
||||
|
||||
@ -8,34 +8,63 @@ from model import *
|
||||
from dataset import *
|
||||
|
||||
|
||||
class SimpleCamera(nn.Module):
|
||||
class TransformCamera(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
transform_mat: torch.Tensor,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
transform_mat=None,
|
||||
camera_intrinsics=None,
|
||||
camera_trans_rot=None
|
||||
):
|
||||
super(SimpleCamera, self).__init__()
|
||||
self.hasTransform = False
|
||||
self.hasCameraTransform = False
|
||||
super(TransformCamera, self).__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.model_type = "smplx"
|
||||
|
||||
if camera_intrinsics is not None:
|
||||
self.hasCameraTransform = True
|
||||
self.register_buffer("cam_int", camera_intrinsics)
|
||||
self.register_buffer("cam_trans_rot", camera_trans_rot)
|
||||
self.register_buffer("trans", transform_mat)
|
||||
# self.register_buffer("disp_trans", camera_trans_rot)
|
||||
elif transform_mat is not None:
|
||||
self.hasTransform = True
|
||||
self.register_buffer("trans", transform_mat)
|
||||
# self.register_buffer("disp_trans", transform_mat)
|
||||
self.register_buffer("trans", transform_mat.to(
|
||||
device=device, dtype=dtype))
|
||||
|
||||
def from_estimation_cam(cam: TorchCameraEstimate, device=None, dtype=None):
|
||||
def forward(self, points):
|
||||
proj_points = self.trans @ points.reshape(-1, 4, 1)
|
||||
proj_points = proj_points.reshape(1, -1, 4)[:, :, :2] * 1
|
||||
proj_points = F.pad(proj_points, (0, 1, 0, 0), value=0)
|
||||
return proj_points
|
||||
|
||||
|
||||
class IntrinsicsCamera(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
transform_mat: torch.Tensor,
|
||||
camera_intrinsics: torch.Tensor,
|
||||
camera_trans_rot: torch.Tensor,
|
||||
dtype=torch.float32,
|
||||
device=None
|
||||
):
|
||||
super(IntrinsicsCamera, self).__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.register_buffer("cam_int", camera_intrinsics.to(
|
||||
device=device, dtype=dtype))
|
||||
self.register_buffer("cam_trans_rot", camera_trans_rot.to(
|
||||
device=device, dtype=dtype))
|
||||
self.register_buffer("trans", transform_mat.to(
|
||||
device=device, dtype=dtype))
|
||||
|
||||
def forward(self, points):
|
||||
proj_points = self.cam_int[:3, :3] @ self.cam_trans_rot[:3,
|
||||
:] @ self.trans @ points.reshape(-1, 4, 1)
|
||||
result = proj_points.squeeze(2)
|
||||
denomiator = torch.zeros(points.shape[1], 3)
|
||||
for i in range(points.shape[1]):
|
||||
denomiator[i, :] = result[i, 2]
|
||||
result = result/denomiator
|
||||
result[:, 2] = 0
|
||||
return result
|
||||
|
||||
|
||||
class SimpleCamera(nn.Module):
|
||||
def from_estimation_cam(cam: TorchCameraEstimate, use_intrinsics=False, device=None, dtype=None):
|
||||
"""utility to create camera module from estimation camera
|
||||
|
||||
Args:
|
||||
@ -44,29 +73,21 @@ class SimpleCamera(nn.Module):
|
||||
cam_trans, cam_int, cam_params = cam.get_results(
|
||||
device=device, dtype=dtype)
|
||||
|
||||
return SimpleCamera(
|
||||
dtype,
|
||||
device,
|
||||
transform_mat=cam_trans,
|
||||
camera_intrinsics=cam_int, camera_trans_rot=cam_params
|
||||
), cam_trans, cam_int, cam_params
|
||||
cam_layer = None
|
||||
|
||||
def forward(self, points):
|
||||
if self.hasTransform:
|
||||
proj_points = self.trans @ points.reshape(-1, 4, 1)
|
||||
proj_points = proj_points.reshape(1, -1, 4)[:, :, :2] * 1
|
||||
proj_points = F.pad(proj_points, (0, 1, 0, 0), value=0)
|
||||
return proj_points
|
||||
if self.hasCameraTransform:
|
||||
proj_points = self.cam_int[:3, :3] @ self.cam_trans_rot[:3,
|
||||
:] @ self.trans @ points.reshape(-1, 4, 1)
|
||||
result = proj_points.squeeze(2)
|
||||
denomiator = torch.zeros(points.shape[1], 3)
|
||||
for i in range(points.shape[1]):
|
||||
denomiator[i, :] = result[i, 2]
|
||||
result = result/denomiator
|
||||
result[:, 2] = 0
|
||||
return result
|
||||
if use_intrinsics:
|
||||
cam_layer = IntrinsicsCamera(
|
||||
transform_mat=cam_trans,
|
||||
camera_intrinsics=cam_int,
|
||||
camera_trans_rot=cam_params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
cam_layer = TransformCamera(
|
||||
transform_mat=cam_trans,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# scale = (points[:, :, 2] / self.z_scale)
|
||||
# print(points.shape, scale.shape)
|
||||
return cam_layer, cam_trans, cam_int, cam_params
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from camera_estimation import TorchCameraEstimate
|
||||
from modules.angle_clip import AngleClipper
|
||||
from modules.angle import AnglePriorsLoss
|
||||
import smplx
|
||||
@ -41,7 +42,7 @@ def train_pose(
|
||||
print("[pose] starting training")
|
||||
print("[pose] dtype=", dtype)
|
||||
|
||||
loss_layer = torch.nn.MSELoss().to(device=device, dtype=dtype) #MSELoss()
|
||||
loss_layer = torch.nn.MSELoss().to(device=device, dtype=dtype) # MSELoss()
|
||||
|
||||
clip_loss_layer = AngleClipper().to(device=device, dtype=dtype)
|
||||
|
||||
@ -91,10 +92,9 @@ def train_pose(
|
||||
pose_extra = None
|
||||
|
||||
# if useBodyPrior:
|
||||
# body = vposer_layer()
|
||||
# poZ = body.poZ_body
|
||||
# pose_extra = body.pose_body
|
||||
|
||||
# body = vposer_layer()
|
||||
# poZ = body.poZ_body
|
||||
# pose_extra = body.pose_body
|
||||
|
||||
# return joints based on current model state
|
||||
body_joints, cur_pose = pose_layer()
|
||||
@ -115,12 +115,13 @@ def train_pose(
|
||||
body_mean_loss = 0.0
|
||||
if body_mean_loss:
|
||||
body_mean_loss = (cur_pose -
|
||||
body_mean_pose).pow(2).sum() * body_mean_weight
|
||||
body_mean_pose).pow(2).sum() * body_mean_weight
|
||||
|
||||
body_prior_loss = 0.0
|
||||
if useBodyPrior:
|
||||
# apply pose prior loss.
|
||||
body_prior_loss = latent_body.pose_body.pow(2).sum() * body_prior_weight
|
||||
body_prior_loss = latent_body.pose_body.pow(
|
||||
2).sum() * body_prior_weight
|
||||
|
||||
angle_prior_loss = 0.0
|
||||
if useAnglePrior:
|
||||
@ -130,11 +131,11 @@ def train_pose(
|
||||
|
||||
angle_sum_loss = 0.0
|
||||
if use_angle_sum_loss:
|
||||
angle_sum_loss = clip_loss_layer(cur_pose) * angle_sum_weight
|
||||
angle_sum_loss = clip_loss_layer(cur_pose) # * angle_sum_weight
|
||||
|
||||
loss = loss + body_mean_loss + body_prior_loss + angle_prior_loss + angle_sum_loss
|
||||
|
||||
return loss
|
||||
return loss
|
||||
|
||||
def optim_closure():
|
||||
if torch.is_grad_enabled():
|
||||
@ -191,32 +192,44 @@ def train_pose(
|
||||
|
||||
pbar.close()
|
||||
print("Final result:", loss.item())
|
||||
return pose_layer.cur_out
|
||||
return pose_layer.cur_out, best_pose
|
||||
|
||||
|
||||
def train_pose_with_conf(
|
||||
config,
|
||||
camera: TorchCameraEstimate,
|
||||
model: smplx.SMPL,
|
||||
keypoints,
|
||||
keypoint_conf,
|
||||
camera: SimpleCamera,
|
||||
device=torch.device('cpu'),
|
||||
dtype=torch.float32,
|
||||
renderer: Renderer = None,
|
||||
):
|
||||
|
||||
# configure PyTorch device and format
|
||||
dtype = torch.float64
|
||||
# dtype = torch.float64
|
||||
if 'device' in config['pose'] is not None:
|
||||
device = torch.device(config['pose']['device'])
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
# create camera module
|
||||
pose_camera, cam_trans, cam_int, cam_params = SimpleCamera.from_estimation_cam(
|
||||
cam=camera,
|
||||
use_intrinsics=config['pose']['useCameraIntrinsics'],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# apply transform to scene
|
||||
if renderer is not None:
|
||||
renderer.set_group_pose("body", cam_trans.cpu().numpy())
|
||||
|
||||
return train_pose(
|
||||
model=model.to(dtype=dtype),
|
||||
keypoints=keypoints,
|
||||
keypoint_conf=keypoint_conf,
|
||||
camera=camera,
|
||||
camera=pose_camera,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
renderer=renderer,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user