mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
WIP: try using MSELoss instead of L2
This commit is contained in:
parent
3aada5b694
commit
ad7f658bf2
@ -110,17 +110,24 @@ learning_rate = 1e-3
|
||||
trans = Transform(dtype, device)
|
||||
proj = CameraProjSimple(dtype, device, -est_depth)
|
||||
optimizer = torch.optim.Adam(trans.parameters(), lr=learning_rate)
|
||||
loss_layer = torch.nn.MSELoss()
|
||||
|
||||
for t in range(50000):
|
||||
homog_coord = torch.ones(list(smpl_torso.shape)[:-1] + [1],
|
||||
dtype=smpl_torso.dtype,
|
||||
device=device)
|
||||
# Convert the points to homogeneous coordinates
|
||||
points_h = torch.cat([smpl_torso, homog_coord], dim=-1)
|
||||
|
||||
points = trans(smpl_torso)
|
||||
points = trans(points_h)
|
||||
points_2d = proj(points)
|
||||
|
||||
# point wise differences
|
||||
diff = points_2d - keyp_torso
|
||||
|
||||
# Compute cost function
|
||||
loss = torch.norm(diff)
|
||||
# loss = torch.norm(diff)
|
||||
loss = loss_layer(keyp_torso, points_2d)
|
||||
if t % 100 == 99:
|
||||
print(t, loss.item())
|
||||
|
||||
@ -130,10 +137,10 @@ for t in range(50000):
|
||||
|
||||
with torch.no_grad():
|
||||
R = trans.get_transform_mat().numpy()
|
||||
translation = trans.translation.numpy()
|
||||
translation = trans.translation.detach().numpy()
|
||||
# update model rendering
|
||||
r.set_group_transform("body", R, translation)
|
||||
r.set_homog_group_transform("body", R, translation)
|
||||
|
||||
|
||||
for t in range(50000):
|
||||
print("do cost evaluation here")
|
||||
# for t in range(50000):
|
||||
# print("do cost evaluation here")
|
||||
|
||||
@ -2,6 +2,7 @@ import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import torchgeometry as tgm
|
||||
|
||||
|
||||
class Transform(nn.Module):
|
||||
@ -16,11 +17,8 @@ class Transform(nn.Module):
|
||||
translation = nn.Parameter(translation, requires_grad=True)
|
||||
self.register_parameter("translation", translation)
|
||||
|
||||
orientation = torch.ones((1, 3), device=device, dtype=dtype)
|
||||
orientation = nn.init.xavier_uniform_(
|
||||
orientation, gain=1.0)
|
||||
orientation = orientation.clamp(-math.pi * 0.25, math.pi * 0.25)
|
||||
orientation = nn.Parameter(orientation.squeeze(), requires_grad=True)
|
||||
orientation = torch.rand((1, 3), device=device, dtype=dtype)
|
||||
orientation = nn.Parameter(orientation, requires_grad=True)
|
||||
self.register_parameter("orientation", orientation)
|
||||
|
||||
# self.roll = torch.randn(
|
||||
@ -31,40 +29,42 @@ class Transform(nn.Module):
|
||||
# 1, device=device, dtype=dtype, requires_grad=True)
|
||||
|
||||
# init addition buffers
|
||||
tensor_0 = torch.zeros(1, device=device, dtype=dtype)
|
||||
self.register_buffer("tensor_0", tensor_0)
|
||||
tensor_1 = torch.ones(1, device=device, dtype=dtype)
|
||||
self.register_buffer("tensor_1", tensor_1)
|
||||
# tensor_0 = torch.zeros(1, device=device, dtype=dtype)
|
||||
# self.register_buffer("tensor_0", tensor_0)
|
||||
# tensor_1 = torch.ones(1, device=device, dtype=dtype)
|
||||
# self.register_buffer("tensor_1", tensor_1)
|
||||
|
||||
def get_transform_mat(self):
|
||||
tensor_1 = self.tensor_1.squeeze()
|
||||
tensor_0 = self.tensor_0.squeeze()
|
||||
roll = self.orientation[0]
|
||||
pitch = self.orientation[1]
|
||||
yaw = self.orientation[2]
|
||||
# tensor_1 = self.tensor_1.squeeze()
|
||||
# tensor_0 = self.tensor_0.squeeze()
|
||||
# roll = self.orientation[0]
|
||||
# pitch = self.orientation[1]
|
||||
# yaw = self.orientation[2]
|
||||
|
||||
RX = torch.stack([
|
||||
torch.stack([tensor_1, tensor_0, tensor_0]),
|
||||
torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),
|
||||
torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3)
|
||||
# RX = torch.stack([
|
||||
# torch.stack([tensor_1, tensor_0, tensor_0]),
|
||||
# torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),
|
||||
# torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3)
|
||||
|
||||
RY = torch.stack([
|
||||
torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
|
||||
torch.stack([tensor_0, tensor_1, tensor_0]),
|
||||
torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3)
|
||||
# RY = torch.stack([
|
||||
# torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
|
||||
# torch.stack([tensor_0, tensor_1, tensor_0]),
|
||||
# torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3)
|
||||
|
||||
RZ = torch.stack([
|
||||
torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]),
|
||||
torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]),
|
||||
torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
|
||||
# RZ = torch.stack([
|
||||
# torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]),
|
||||
# torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]),
|
||||
# torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
|
||||
|
||||
R = torch.mm(RX, RY)
|
||||
R = torch.mm(R, RZ)
|
||||
# R = torch.mm(RX, RY)
|
||||
# R = torch.mm(R, RZ)
|
||||
# R = torch.mm(RZ, RY)
|
||||
#R = torch.mm(R, RX)
|
||||
return R
|
||||
|
||||
|
||||
transform = tgm.angle_axis_to_rotation_matrix(self.orientation)
|
||||
return transform
|
||||
|
||||
def forward(self, joints):
|
||||
R = self.get_transform_mat()
|
||||
|
||||
return joints @ R + self.translation
|
||||
R = self.get_transform_mat().squeeze()
|
||||
return joints @ R + F.pad(self.translation, (0,1), value=1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user