mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
import torchgeometry as tgm
|
|
|
|
|
|
class Transform(nn.Module):
    """Learnable rigid transform: an axis-angle rotation plus a translation.

    Both parameters are initialised with ``torch.rand`` (uniform in [0, 1)),
    i.e. to a random pose, and are registered as trainable parameters so an
    optimiser can fit them.
    """

    def __init__(self, dtype, device) -> None:
        """Create the ``translation`` and ``orientation`` parameters.

        Args:
            dtype: torch dtype used for both parameters.
            device: torch device on which both parameters are allocated.
        """
        super(Transform, self).__init__()

        self.dtype = dtype
        self.device = device

        # Translation vector, shape (3,).
        translation = torch.rand(3, device=device, dtype=dtype)
        translation = nn.Parameter(translation, requires_grad=True)
        self.register_parameter("translation", translation)

        # Axis-angle rotation, shape (1, 3). The leading dim of 1 matches the
        # batched (N, 3) input that tgm.angle_axis_to_rotation_matrix takes.
        orientation = torch.rand((1, 3), device=device, dtype=dtype)
        orientation = nn.Parameter(orientation, requires_grad=True)
        self.register_parameter("orientation", orientation)

    def get_transform_mat(self, with_translate=False):
        """Return the 4x4 homogeneous transform built from the parameters.

        Args:
            with_translate: if True, ``self.translation`` is written into the
                last column (``[:, :3, 3]``), which is the layout used with
                column vectors (``M @ p``). If False, the matrix is a pure
                rotation with a zero translation column.

        Returns:
            Tensor of shape (1, 4, 4), as produced by
            ``tgm.angle_axis_to_rotation_matrix`` for the (1, 3) orientation.
        """
        transform = tgm.angle_axis_to_rotation_matrix(self.orientation)
        if with_translate:
            # Written in place into the matrix freshly created above.
            transform[:, :3, 3] = self.translation
        return transform

    def forward(self, joints):
        """Rotate ``joints`` and then add the translation.

        Args:
            joints: tensor whose last dimension is 4, as required by the
                matmul with the 4x4 matrix. These are presumably homogeneous
                points with w = 1 (TODO confirm against callers).

        Returns:
            ``joints @ R + [tx, ty, tz, 0]``, where R is the rotation-only
            matrix of shape (1, 4, 4). The result therefore broadcasts to
            carry R's leading dim of 1. The homogeneous (last) component is
            returned unchanged.

        Note:
            ``joints @ R`` treats points as row vectors, so it applies R
            transposed (the inverse rotation) relative to the column-vector
            layout built by ``get_transform_mat(with_translate=True)``.
        """
        R = self.get_transform_mat()
        # Pad with 0, not 1. R[3, 3] == 1 already carries w through the
        # product, so padding with 1 would turn w into w + 1 (w = 1 -> 2)
        # and corrupt the homogeneous coordinate.
        translation = F.pad(self.translation, (0, 1), value=0)
        return joints @ R + translation
|