import math import torch import torch.nn.functional as F import torch.nn as nn import torchgeometry as tgm class Transform(nn.Module): def __init__(self, dtype, device) -> None: super(Transform, self).__init__() self.dtype = dtype self.device = device # init parameters translation = torch.zeros(3, device=device, dtype=dtype) translation = nn.Parameter(translation, requires_grad=True) self.register_parameter("translation", translation) orientation = torch.rand((1, 3), device=device, dtype=dtype) orientation = nn.Parameter(orientation, requires_grad=True) self.register_parameter("orientation", orientation) # self.roll = torch.randn( # 1, device=device, dtype=dtype, requires_grad=True) # self.yaw = torch.randn( # 1, device=device, dtype=dtype, requires_grad=True) # self.pitch = torch.randn( # 1, device=device, dtype=dtype, requires_grad=True) # init addition buffers # tensor_0 = torch.zeros(1, device=device, dtype=dtype) # self.register_buffer("tensor_0", tensor_0) # tensor_1 = torch.ones(1, device=device, dtype=dtype) # self.register_buffer("tensor_1", tensor_1) def get_transform_mat(self): # tensor_1 = self.tensor_1.squeeze() # tensor_0 = self.tensor_0.squeeze() # roll = self.orientation[0] # pitch = self.orientation[1] # yaw = self.orientation[2] # RX = torch.stack([ # torch.stack([tensor_1, tensor_0, tensor_0]), # torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]), # torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3) # RY = torch.stack([ # torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), # torch.stack([tensor_0, tensor_1, tensor_0]), # torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3) # RZ = torch.stack([ # torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]), # torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]), # torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) # R = torch.mm(RX, RY) # R = torch.mm(R, RZ) # R = torch.mm(RZ, RY) #R = torch.mm(R, RX) transform = tgm.angle_axis_to_rotation_matrix(self.orientation) return transform def forward(self, joints): R = self.get_transform_mat().squeeze() return joints @ R + F.pad(self.translation, (0,1), value=1)