import torch
import torch.nn as nn

from smplx.lbs import transform_mat


class PerspectiveCamera(nn.Module):
    """Simple perspective (pinhole) camera with optimizable extrinsics."""

    FOCAL_LENGTH = 5000

    def __init__(self, rotation=None, translation=None,
                 focal_length_x=None, focal_length_y=None,
                 batch_size=1,
                 center=None, dtype=torch.float32, **kwargs):
        super(PerspectiveCamera, self).__init__()
        self.batch_size = batch_size
        self.dtype = dtype
        # Make a buffer so that PyTorch does not complain when creating
        # the camera matrix
        self.register_buffer('zero',
                             torch.zeros([batch_size], dtype=dtype))

        # Use the default focal length when none is given and promote scalar
        # focal lengths to one value per batch element.
        if focal_length_x is None or isinstance(focal_length_x, (int, float)):
            focal_length_x = torch.full(
                [batch_size],
                self.FOCAL_LENGTH if focal_length_x is None else
                focal_length_x,
                dtype=dtype)

        if focal_length_y is None or isinstance(focal_length_y, (int, float)):
            focal_length_y = torch.full(
                [batch_size],
                self.FOCAL_LENGTH if focal_length_y is None else
                focal_length_y,
                dtype=dtype)

        self.register_buffer('focal_length_x', focal_length_x)
        self.register_buffer('focal_length_y', focal_length_y)

        # Principal point (image center offset), in pixels.
        if center is None:
            center = torch.zeros([batch_size, 2], dtype=dtype)
        self.register_buffer('center', center)

        # Rotation and translation are registered as parameters so the camera
        # extrinsics can be optimized directly.
        if rotation is None:
            rotation = torch.eye(
                3, dtype=dtype).unsqueeze(dim=0).repeat(batch_size, 1, 1)

        rotation = nn.Parameter(rotation, requires_grad=True)
        self.register_parameter('rotation', rotation)

        if translation is None:
            translation = torch.zeros([batch_size, 3], dtype=dtype)

        translation = nn.Parameter(translation,
                                   requires_grad=True)
        self.register_parameter('translation', translation)

    def forward(self, points):
        """Project 3D points of shape (B, N, 3) to 2D pixels of shape (B, N, 2)."""
        device = points.device

        # The intrinsic matrix is built without gradients: only the extrinsics
        # (rotation and translation) receive gradients.
        with torch.no_grad():
            camera_mat = torch.zeros([self.batch_size, 2, 2],
                                     dtype=self.dtype, device=points.device)
            camera_mat[:, 0, 0] = self.focal_length_x
            camera_mat[:, 1, 1] = self.focal_length_y

        # 4x4 rigid-body transform assembled from rotation and translation.
        camera_transform = transform_mat(self.rotation,
                                         self.translation.unsqueeze(dim=-1))
        homog_coord = torch.ones(list(points.shape)[:-1] + [1],
                                 dtype=points.dtype,
                                 device=device)
        # Convert the points to homogeneous coordinates
        points_h = torch.cat([points, homog_coord], dim=-1)

        # Apply the extrinsic transform to every point in the batch.
        projected_points = torch.einsum('bki,bji->bjk',
                                        [camera_transform, points_h])

        # Perspective division by depth, then scaling by the focal lengths and
        # shifting by the principal point.
        img_points = torch.div(projected_points[:, :, :2],
                               projected_points[:, :, 2].unsqueeze(dim=-1))
        img_points = torch.einsum('bki,bji->bjk', [camera_mat, img_points]) \
            + self.center.unsqueeze(dim=1)
        return img_points
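

# Minimal usage sketch, assuming `smplx` is installed so that `transform_mat`
# can be imported; the shapes and values below are illustrative only, not a
# prescribed calling convention.
if __name__ == '__main__':
    batch_size = 1
    camera = PerspectiveCamera(batch_size=batch_size,
                               focal_length_x=1000.0,
                               focal_length_y=1000.0)

    # A batch of 10 random 3D points pushed in front of the camera (z > 0).
    points = torch.rand([batch_size, 10, 3], dtype=torch.float32)
    points[..., 2] += 2.0

    # Project B x N x 3 world points to B x N x 2 pixel coordinates.
    img_points = camera(points)
    print(img_points.shape)  # expected: torch.Size([1, 10, 2])

    # Because rotation and translation are nn.Parameters, a loss on the
    # projected points can be backpropagated to refine the extrinsics.
    img_points.sum().backward()
    print(camera.translation.grad.shape)  # expected: torch.Size([1, 3])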