mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
71 lines
2.6 KiB
Python
71 lines
2.6 KiB
Python
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
import torchgeometry as tgm
|
|
|
|
|
|
class Transform(nn.Module):
    """Learnable rotation plus translation applied to a set of joints.

    The rotation is stored as an axis-angle vector (``orientation``) and the
    offset as a 3-vector (``translation``). Both are trainable parameters of
    the module, so an optimizer can fit them. The axis-angle form replaces an
    earlier roll/pitch/yaw matrix composition.
    """

    def __init__(self, dtype, device) -> None:
        """Create the trainable ``translation`` and ``orientation`` parameters.

        Args:
            dtype: torch dtype used for both parameters.
            device: torch device on which both parameters are allocated.
        """
        super().__init__()

        self.dtype = dtype
        self.device = device

        # Assigning an nn.Parameter as an attribute registers it with the
        # module. Translation is assigned first so the parameter order stays
        # (translation, orientation).
        self.translation = nn.Parameter(
            torch.zeros(3, device=device, dtype=dtype),
            requires_grad=True,
        )

        # Axis-angle rotation of shape (1, 3), started from random values
        # drawn by torch.rand rather than from the identity rotation.
        self.orientation = nn.Parameter(
            torch.rand((1, 3), device=device, dtype=dtype),
            requires_grad=True,
        )

    def get_transform_mat(self):
        """Return the rotation matrix for the current ``orientation``.

        Returns:
            The result of ``tgm.angle_axis_to_rotation_matrix`` for the
            (1, 3) axis-angle parameter. Per the torchgeometry documentation
            this is a batch of 4x4 homogeneous matrices, i.e. shape
            (1, 4, 4) -- TODO confirm against the installed version.
        """
        return tgm.angle_axis_to_rotation_matrix(self.orientation)

    def forward(self, joints):
        """Rotate ``joints`` and add the translation offset.

        Args:
            joints: tensor whose trailing dimension must be compatible with
                the rotation matrix under ``@`` (presumably homogeneous
                4-component points -- verify against the caller).

        Returns:
            ``joints @ R + offset``, where ``offset`` is the translation
            extended to four components.
        """
        rotation = self.get_transform_mat()

        # Extend (tx, ty, tz) to (tx, ty, tz, 1) so it broadcasts against the
        # 4-component product below.
        # NOTE(review): value=1 means 1 is *added* to the fourth component of
        # every output point; confirm this is intended rather than value=0.
        offset = F.pad(self.translation, (0, 1), value=1)

        # Points are treated as row vectors (right-multiplied by the matrix).
        return joints @ rotation + offset