WIP: try using MSELoss instead of L2

This commit is contained in:
Wlad Meixner 2021-01-23 12:45:56 +01:00 committed by Wlad
parent 3aada5b694
commit ad7f658bf2
2 changed files with 45 additions and 38 deletions

View File

@ -110,17 +110,24 @@ learning_rate = 1e-3
trans = Transform(dtype, device)
proj = CameraProjSimple(dtype, device, -est_depth)
optimizer = torch.optim.Adam(trans.parameters(), lr=learning_rate)
loss_layer = torch.nn.MSELoss()
for t in range(50000):
homog_coord = torch.ones(list(smpl_torso.shape)[:-1] + [1],
dtype=smpl_torso.dtype,
device=device)
# Convert the points to homogeneous coordinates
points_h = torch.cat([smpl_torso, homog_coord], dim=-1)
points = trans(smpl_torso)
points = trans(points_h)
points_2d = proj(points)
# point wise differences
diff = points_2d - keyp_torso
# Compute cost function
loss = torch.norm(diff)
# loss = torch.norm(diff)
loss = loss_layer(keyp_torso, points_2d)
if t % 100 == 99:
print(t, loss.item())
@ -130,10 +137,10 @@ for t in range(50000):
with torch.no_grad():
R = trans.get_transform_mat().numpy()
translation = trans.translation.numpy()
translation = trans.translation.detach().numpy()
# update model rendering
r.set_group_transform("body", R, translation)
r.set_homog_group_transform("body", R, translation)
for t in range(50000):
print("do cost evaluation here")
# for t in range(50000):
# print("do cost evaluation here")

View File

@ -2,6 +2,7 @@ import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchgeometry as tgm
class Transform(nn.Module):
@ -16,11 +17,8 @@ class Transform(nn.Module):
translation = nn.Parameter(translation, requires_grad=True)
self.register_parameter("translation", translation)
orientation = torch.ones((1, 3), device=device, dtype=dtype)
orientation = nn.init.xavier_uniform_(
orientation, gain=1.0)
orientation = orientation.clamp(-math.pi * 0.25, math.pi * 0.25)
orientation = nn.Parameter(orientation.squeeze(), requires_grad=True)
orientation = torch.rand((1, 3), device=device, dtype=dtype)
orientation = nn.Parameter(orientation, requires_grad=True)
self.register_parameter("orientation", orientation)
# self.roll = torch.randn(
@ -31,40 +29,42 @@ class Transform(nn.Module):
# 1, device=device, dtype=dtype, requires_grad=True)
# init addition buffers
tensor_0 = torch.zeros(1, device=device, dtype=dtype)
self.register_buffer("tensor_0", tensor_0)
tensor_1 = torch.ones(1, device=device, dtype=dtype)
self.register_buffer("tensor_1", tensor_1)
# tensor_0 = torch.zeros(1, device=device, dtype=dtype)
# self.register_buffer("tensor_0", tensor_0)
# tensor_1 = torch.ones(1, device=device, dtype=dtype)
# self.register_buffer("tensor_1", tensor_1)
def get_transform_mat(self):
tensor_1 = self.tensor_1.squeeze()
tensor_0 = self.tensor_0.squeeze()
roll = self.orientation[0]
pitch = self.orientation[1]
yaw = self.orientation[2]
# tensor_1 = self.tensor_1.squeeze()
# tensor_0 = self.tensor_0.squeeze()
# roll = self.orientation[0]
# pitch = self.orientation[1]
# yaw = self.orientation[2]
RX = torch.stack([
torch.stack([tensor_1, tensor_0, tensor_0]),
torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),
torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3)
# RX = torch.stack([
# torch.stack([tensor_1, tensor_0, tensor_0]),
# torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),
# torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3)
RY = torch.stack([
torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
torch.stack([tensor_0, tensor_1, tensor_0]),
torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3)
# RY = torch.stack([
# torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
# torch.stack([tensor_0, tensor_1, tensor_0]),
# torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3)
RZ = torch.stack([
torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]),
torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]),
torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
# RZ = torch.stack([
# torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]),
# torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]),
# torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
R = torch.mm(RX, RY)
R = torch.mm(R, RZ)
# R = torch.mm(RX, RY)
# R = torch.mm(R, RZ)
# R = torch.mm(RZ, RY)
#R = torch.mm(R, RX)
return R
transform = tgm.angle_axis_to_rotation_matrix(self.orientation)
return transform
def forward(self, joints):
R = self.get_transform_mat()
return joints @ R + self.translation
R = self.get_transform_mat().squeeze()
return joints @ R + F.pad(self.translation, (0,1), value=1)