merged camera

This commit is contained in:
Wlad 2021-02-07 17:40:13 +01:00
parent ca3b4e7167
commit 26401f959e
6 changed files with 96 additions and 76 deletions

View File

@ -23,7 +23,7 @@ pose:
lr: 0.01 lr: 0.01
optimizer: Adam # currently supported Adam, LBFGS optimizer: Adam # currently supported Adam, LBFGS
iterations: 100 iterations: 100
useCameraIntrinsics: false useCameraIntrinsics: true
bodyMeanLoss: bodyMeanLoss:
enabled: false enabled: false
weight: 0.1 weight: 0.1

View File

@ -42,16 +42,6 @@ camera = TorchCameraEstimate(
# render camera to the scene # render camera to the scene
camera.setup_visualization(r.init_keypoints, r.keypoints) camera.setup_visualization(r.init_keypoints, r.keypoints)
# run camera optimizer
cam, cam_trans, cam_int, cam_params = SimpleCamera.from_estimation_cam(
camera,
dtype=dtype,
device=device,
)
# apply transform to scene
r.set_group_pose("body", cam_trans.cpu().numpy())
# train for pose # train for pose
train_pose_with_conf( train_pose_with_conf(
@ -59,7 +49,7 @@ train_pose_with_conf(
model=model, model=model,
keypoints=keypoints, keypoints=keypoints,
keypoint_conf=conf, keypoint_conf=conf,
camera=cam, camera=camera,
renderer=r, renderer=r,
device=device, device=device,
) )

View File

@ -3,12 +3,8 @@ import numpy as np
# local imports # local imports
from renderer import DefaultRenderer from renderer import DefaultRenderer
from train_pose import train_pose_with_conf
from modules.camera import SimpleCamera
from model import SMPLyModel from model import SMPLyModel
from utils.general import load_config, setup_training from utils.general import load_config
from camera_estimation import TorchCameraEstimate
from dataset import SMPLyDataset
# this a simple pose playground with a async renderer for quick prototyping # this a simple pose playground with a async renderer for quick prototyping

View File

@ -2,12 +2,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
class AngleClipper(nn.Module): class AngleClipper(nn.Module):
def __init__( def __init__(
self, self,
device=torch.device('cpu'), device=torch.device('cpu'),
dtype=torch.float32, dtype=torch.float32,
angle_idx=[24, 10 , 9], angle_idx=[24, 10, 9],
# directions=[-1, 1, 1, 1], # directions=[-1, 1, 1, 1],
weights=[1.0, 1.0, 1.0] weights=[1.0, 1.0, 1.0]
): ):
@ -35,8 +36,7 @@ class AngleClipper(nn.Module):
angles = pose[:, self.angle_idx] angles = pose[:, self.angle_idx]
penalty = angles[torch.abs(angles) > self.limit] penalty = angles[torch.abs(angles) > self.limit]
# get relevant angles # get relevant angles
return penalty.pow(2).sum() * 0.01 return penalty.pow(2).sum() * 0.01

View File

@ -8,34 +8,63 @@ from model import *
from dataset import * from dataset import *
class SimpleCamera(nn.Module): class TransformCamera(nn.Module):
def __init__( def __init__(
self, self,
transform_mat: torch.Tensor,
dtype=torch.float32, dtype=torch.float32,
device=None, device=None,
transform_mat=None,
camera_intrinsics=None,
camera_trans_rot=None
): ):
super(SimpleCamera, self).__init__() super(TransformCamera, self).__init__()
self.hasTransform = False
self.hasCameraTransform = False
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.model_type = "smplx"
if camera_intrinsics is not None: self.register_buffer("trans", transform_mat.to(
self.hasCameraTransform = True device=device, dtype=dtype))
self.register_buffer("cam_int", camera_intrinsics)
self.register_buffer("cam_trans_rot", camera_trans_rot)
self.register_buffer("trans", transform_mat)
# self.register_buffer("disp_trans", camera_trans_rot)
elif transform_mat is not None:
self.hasTransform = True
self.register_buffer("trans", transform_mat)
# self.register_buffer("disp_trans", transform_mat)
def from_estimation_cam(cam: TorchCameraEstimate, device=None, dtype=None): def forward(self, points):
proj_points = self.trans @ points.reshape(-1, 4, 1)
proj_points = proj_points.reshape(1, -1, 4)[:, :, :2] * 1
proj_points = F.pad(proj_points, (0, 1, 0, 0), value=0)
return proj_points
class IntrinsicsCamera(nn.Module):
def __init__(
self,
transform_mat: torch.Tensor,
camera_intrinsics: torch.Tensor,
camera_trans_rot: torch.Tensor,
dtype=torch.float32,
device=None
):
super(IntrinsicsCamera, self).__init__()
self.dtype = dtype
self.device = device
self.register_buffer("cam_int", camera_intrinsics.to(
device=device, dtype=dtype))
self.register_buffer("cam_trans_rot", camera_trans_rot.to(
device=device, dtype=dtype))
self.register_buffer("trans", transform_mat.to(
device=device, dtype=dtype))
def forward(self, points):
proj_points = self.cam_int[:3, :3] @ self.cam_trans_rot[:3,
:] @ self.trans @ points.reshape(-1, 4, 1)
result = proj_points.squeeze(2)
denomiator = torch.zeros(points.shape[1], 3)
for i in range(points.shape[1]):
denomiator[i, :] = result[i, 2]
result = result/denomiator
result[:, 2] = 0
return result
class SimpleCamera(nn.Module):
def from_estimation_cam(cam: TorchCameraEstimate, use_intrinsics=False, device=None, dtype=None):
"""utility to create camera module from estimation camera """utility to create camera module from estimation camera
Args: Args:
@ -44,29 +73,21 @@ class SimpleCamera(nn.Module):
cam_trans, cam_int, cam_params = cam.get_results( cam_trans, cam_int, cam_params = cam.get_results(
device=device, dtype=dtype) device=device, dtype=dtype)
return SimpleCamera( cam_layer = None
dtype,
device,
transform_mat=cam_trans,
camera_intrinsics=cam_int, camera_trans_rot=cam_params
), cam_trans, cam_int, cam_params
def forward(self, points): if use_intrinsics:
if self.hasTransform: cam_layer = IntrinsicsCamera(
proj_points = self.trans @ points.reshape(-1, 4, 1) transform_mat=cam_trans,
proj_points = proj_points.reshape(1, -1, 4)[:, :, :2] * 1 camera_intrinsics=cam_int,
proj_points = F.pad(proj_points, (0, 1, 0, 0), value=0) camera_trans_rot=cam_params,
return proj_points device=device,
if self.hasCameraTransform: dtype=dtype,
proj_points = self.cam_int[:3, :3] @ self.cam_trans_rot[:3, )
:] @ self.trans @ points.reshape(-1, 4, 1) else:
result = proj_points.squeeze(2) cam_layer = TransformCamera(
denomiator = torch.zeros(points.shape[1], 3) transform_mat=cam_trans,
for i in range(points.shape[1]): device=device,
denomiator[i, :] = result[i, 2] dtype=dtype,
result = result/denomiator )
result[:, 2] = 0
return result
# scale = (points[:, :, 2] / self.z_scale) return cam_layer, cam_trans, cam_int, cam_params
# print(points.shape, scale.shape)

View File

@ -1,3 +1,4 @@
from camera_estimation import TorchCameraEstimate
from modules.angle_clip import AngleClipper from modules.angle_clip import AngleClipper
from modules.angle import AnglePriorsLoss from modules.angle import AnglePriorsLoss
import smplx import smplx
@ -41,7 +42,7 @@ def train_pose(
print("[pose] starting training") print("[pose] starting training")
print("[pose] dtype=", dtype) print("[pose] dtype=", dtype)
loss_layer = torch.nn.MSELoss().to(device=device, dtype=dtype) #MSELoss() loss_layer = torch.nn.MSELoss().to(device=device, dtype=dtype) # MSELoss()
clip_loss_layer = AngleClipper().to(device=device, dtype=dtype) clip_loss_layer = AngleClipper().to(device=device, dtype=dtype)
@ -91,10 +92,9 @@ def train_pose(
pose_extra = None pose_extra = None
# if useBodyPrior: # if useBodyPrior:
# body = vposer_layer() # body = vposer_layer()
# poZ = body.poZ_body # poZ = body.poZ_body
# pose_extra = body.pose_body # pose_extra = body.pose_body
# return joints based on current model state # return joints based on current model state
body_joints, cur_pose = pose_layer() body_joints, cur_pose = pose_layer()
@ -115,12 +115,13 @@ def train_pose(
body_mean_loss = 0.0 body_mean_loss = 0.0
if body_mean_loss: if body_mean_loss:
body_mean_loss = (cur_pose - body_mean_loss = (cur_pose -
body_mean_pose).pow(2).sum() * body_mean_weight body_mean_pose).pow(2).sum() * body_mean_weight
body_prior_loss = 0.0 body_prior_loss = 0.0
if useBodyPrior: if useBodyPrior:
# apply pose prior loss. # apply pose prior loss.
body_prior_loss = latent_body.pose_body.pow(2).sum() * body_prior_weight body_prior_loss = latent_body.pose_body.pow(
2).sum() * body_prior_weight
angle_prior_loss = 0.0 angle_prior_loss = 0.0
if useAnglePrior: if useAnglePrior:
@ -130,11 +131,11 @@ def train_pose(
angle_sum_loss = 0.0 angle_sum_loss = 0.0
if use_angle_sum_loss: if use_angle_sum_loss:
angle_sum_loss = clip_loss_layer(cur_pose) * angle_sum_weight angle_sum_loss = clip_loss_layer(cur_pose) # * angle_sum_weight
loss = loss + body_mean_loss + body_prior_loss + angle_prior_loss + angle_sum_loss loss = loss + body_mean_loss + body_prior_loss + angle_prior_loss + angle_sum_loss
return loss return loss
def optim_closure(): def optim_closure():
if torch.is_grad_enabled(): if torch.is_grad_enabled():
@ -191,32 +192,44 @@ def train_pose(
pbar.close() pbar.close()
print("Final result:", loss.item()) print("Final result:", loss.item())
return pose_layer.cur_out return pose_layer.cur_out, best_pose
def train_pose_with_conf( def train_pose_with_conf(
config, config,
camera: TorchCameraEstimate,
model: smplx.SMPL, model: smplx.SMPL,
keypoints, keypoints,
keypoint_conf, keypoint_conf,
camera: SimpleCamera,
device=torch.device('cpu'), device=torch.device('cpu'),
dtype=torch.float32, dtype=torch.float32,
renderer: Renderer = None, renderer: Renderer = None,
): ):
# configure PyTorch device and format # configure PyTorch device and format
dtype = torch.float64 # dtype = torch.float64
if 'device' in config['pose'] is not None: if 'device' in config['pose'] is not None:
device = torch.device(config['pose']['device']) device = torch.device(config['pose']['device'])
else: else:
device = torch.device('cpu') device = torch.device('cpu')
# create camera module
pose_camera, cam_trans, cam_int, cam_params = SimpleCamera.from_estimation_cam(
cam=camera,
use_intrinsics=config['pose']['useCameraIntrinsics'],
dtype=dtype,
device=device,
)
# apply transform to scene
if renderer is not None:
renderer.set_group_pose("body", cam_trans.cpu().numpy())
return train_pose( return train_pose(
model=model.to(dtype=dtype), model=model.to(dtype=dtype),
keypoints=keypoints, keypoints=keypoints,
keypoint_conf=keypoint_conf, keypoint_conf=keypoint_conf,
camera=camera, camera=pose_camera,
device=device, device=device,
dtype=dtype, dtype=dtype,
renderer=renderer, renderer=renderer,