# body-pose-animation/modules/perspective_cam.py

import torch
import torch.nn as nn

from smplx.lbs import transform_mat
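

# PerspectiveCamera.forward() implements a standard pinhole projection:
#     [u, v] = [f_x, f_y] * (R @ X + t)_{x,y} / (R @ X + t)_z + [c_x, c_y]
# where the extrinsics (R, t) are optimizable nn.Parameters and the
# intrinsics (f_x, f_y, c_x, c_y) are stored as fixed buffers.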
class PerspectiveCamera(nn.Module):

    FOCAL_LENGTH = 5000

    def __init__(self, rotation=None, translation=None,
                 focal_length_x=None, focal_length_y=None,
                 batch_size=1,
                 center=None, dtype=torch.float32, **kwargs):
        super(PerspectiveCamera, self).__init__()
        self.batch_size = batch_size
        self.dtype = dtype
        # Make a buffer so that PyTorch does not complain when creating
        # the camera matrix
        self.register_buffer('zero',
                             torch.zeros([batch_size], dtype=dtype))

        # Fall back to the default focal length when none is given, and
        # promote scalar floats to per-batch tensors.
        if focal_length_x is None or isinstance(focal_length_x, float):
            focal_length_x = torch.full(
                [batch_size],
                self.FOCAL_LENGTH if focal_length_x is None else
                focal_length_x,
                dtype=dtype)

        if focal_length_y is None or isinstance(focal_length_y, float):
            focal_length_y = torch.full(
                [batch_size],
                self.FOCAL_LENGTH if focal_length_y is None else
                focal_length_y,
                dtype=dtype)

        self.register_buffer('focal_length_x', focal_length_x)
        self.register_buffer('focal_length_y', focal_length_y)

        # The principal point defaults to the image origin.
        if center is None:
            center = torch.zeros([batch_size, 2], dtype=dtype)
        self.register_buffer('center', center)

        # Extrinsics are optimizable parameters: rotation starts at the
        # identity, translation at zero.
        if rotation is None:
            rotation = torch.eye(
                3, dtype=dtype).unsqueeze(dim=0).repeat(batch_size, 1, 1)
        rotation = nn.Parameter(rotation, requires_grad=True)
        self.register_parameter('rotation', rotation)

        if translation is None:
            translation = torch.zeros([batch_size, 3], dtype=dtype)
        translation = nn.Parameter(translation, requires_grad=True)
        self.register_parameter('translation', translation)

    def forward(self, points):
        device = points.device

        # Intrinsic (focal length) matrix; kept out of the autograd graph
        # since the focal lengths are fixed buffers, not parameters.
        with torch.no_grad():
            camera_mat = torch.zeros([self.batch_size, 2, 2],
                                     dtype=self.dtype, device=device)
            camera_mat[:, 0, 0] = self.focal_length_x
            camera_mat[:, 1, 1] = self.focal_length_y

        # 4x4 rigid transform assembled from the rotation and translation
        camera_transform = transform_mat(self.rotation,
                                         self.translation.unsqueeze(dim=-1))
        homog_coord = torch.ones(list(points.shape)[:-1] + [1],
                                 dtype=points.dtype,
                                 device=device)
        # Convert the points to homogeneous coordinates
        points_h = torch.cat([points, homog_coord], dim=-1)

        # Apply the extrinsic transform to every point in the batch
        projected_points = torch.einsum('bki,bji->bjk',
                                        camera_transform, points_h)

        # Perspective division by the depth (z) coordinate
        img_points = torch.div(projected_points[:, :, :2],
                               projected_points[:, :, 2].unsqueeze(dim=-1))
        # Scale by the focal lengths and shift by the principal point
        img_points = torch.einsum('bki,bji->bjk', camera_mat, img_points) \
            + self.center.unsqueeze(dim=1)
        return img_points
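

# A minimal smoke test sketching how the camera might be used; the random
# point cloud, batch size, and z-offset below are illustrative assumptions,
# not part of the module.
if __name__ == '__main__':
    batch_size = 2
    camera = PerspectiveCamera(batch_size=batch_size)

    # 10 random 3D points per batch element, pushed in front of the camera
    # so that the perspective division by z is well defined.
    points = torch.rand([batch_size, 10, 3])
    points[..., 2] += 5.0

    img_points = camera(points)
    print(img_points.shape)  # torch.Size([2, 10, 2])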