From ad7f658bf23f8b858c95f57743a8fa3f42f9c5f1 Mon Sep 17 00:00:00 2001 From: Wlad Meixner <9556979+gosticks@users.noreply.github.com> Date: Sat, 23 Jan 2021 12:45:56 +0100 Subject: [PATCH] WIP: try using MSELoss instead of L2 --- example_fitter.py | 19 ++++++++----- modules/transform.py | 64 ++++++++++++++++++++++---------------------- 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/example_fitter.py b/example_fitter.py index 20386c9..79f80ab 100644 --- a/example_fitter.py +++ b/example_fitter.py @@ -110,17 +110,24 @@ learning_rate = 1e-3 trans = Transform(dtype, device) proj = CameraProjSimple(dtype, device, -est_depth) optimizer = torch.optim.Adam(trans.parameters(), lr=learning_rate) +loss_layer = torch.nn.MSELoss() for t in range(50000): + homog_coord = torch.ones(list(smpl_torso.shape)[:-1] + [1], + dtype=smpl_torso.dtype, + device=device) + # Convert the points to homogeneous coordinates + points_h = torch.cat([smpl_torso, homog_coord], dim=-1) - points = trans(smpl_torso) + points = trans(points_h) points_2d = proj(points) # point wise differences diff = points_2d - keyp_torso # Compute cost function - loss = torch.norm(diff) + # loss = torch.norm(diff) + loss = loss_layer(keyp_torso, points_2d) if t % 100 == 99: print(t, loss.item()) @@ -130,10 +137,10 @@ for t in range(50000): with torch.no_grad(): R = trans.get_transform_mat().numpy() - translation = trans.translation.numpy() + translation = trans.translation.detach().numpy() # update model rendering - r.set_group_transform("body", R, translation) + r.set_homog_group_transform("body", R, translation) -for t in range(50000): - print("do cost evaluation here") +# for t in range(50000): +# print("do cost evaluation here") diff --git a/modules/transform.py b/modules/transform.py index bf03c8d..481a5ae 100644 --- a/modules/transform.py +++ b/modules/transform.py @@ -2,6 +2,7 @@ import math import torch import torch.nn.functional as F import torch.nn as nn +import torchgeometry as tgm class Transform(nn.Module): @@ -16,11 +17,8 @@ class Transform(nn.Module): translation = nn.Parameter(translation, requires_grad=True) self.register_parameter("translation", translation) - orientation = torch.ones((1, 3), device=device, dtype=dtype) - orientation = nn.init.xavier_uniform_( - orientation, gain=1.0) - orientation = orientation.clamp(-math.pi * 0.25, math.pi * 0.25) - orientation = nn.Parameter(orientation.squeeze(), requires_grad=True) + orientation = torch.rand((1, 3), device=device, dtype=dtype) + orientation = nn.Parameter(orientation, requires_grad=True) self.register_parameter("orientation", orientation) # self.roll = torch.randn( @@ -31,40 +29,42 @@ class Transform(nn.Module): # 1, device=device, dtype=dtype, requires_grad=True) # init addition buffers - tensor_0 = torch.zeros(1, device=device, dtype=dtype) - self.register_buffer("tensor_0", tensor_0) - tensor_1 = torch.ones(1, device=device, dtype=dtype) - self.register_buffer("tensor_1", tensor_1) + # tensor_0 = torch.zeros(1, device=device, dtype=dtype) + # self.register_buffer("tensor_0", tensor_0) + # tensor_1 = torch.ones(1, device=device, dtype=dtype) + # self.register_buffer("tensor_1", tensor_1) def get_transform_mat(self): - tensor_1 = self.tensor_1.squeeze() - tensor_0 = self.tensor_0.squeeze() - roll = self.orientation[0] - pitch = self.orientation[1] - yaw = self.orientation[2] + # tensor_1 = self.tensor_1.squeeze() + # tensor_0 = self.tensor_0.squeeze() + # roll = self.orientation[0] + # pitch = self.orientation[1] + # yaw = self.orientation[2] - RX = torch.stack([ - torch.stack([tensor_1, tensor_0, tensor_0]), - torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]), - torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3) + # RX = torch.stack([ + # torch.stack([tensor_1, tensor_0, tensor_0]), + # torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]), + # torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3) - RY = torch.stack([ - torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), - torch.stack([tensor_0, tensor_1, tensor_0]), - torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3) + # RY = torch.stack([ + # torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + # torch.stack([tensor_0, tensor_1, tensor_0]), + # torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3) - RZ = torch.stack([ - torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]), - torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]), - torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) + # RZ = torch.stack([ + # torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]), + # torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]), + # torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) - R = torch.mm(RX, RY) - R = torch.mm(R, RZ) + # R = torch.mm(RX, RY) + # R = torch.mm(R, RZ) # R = torch.mm(RZ, RY) #R = torch.mm(R, RX) - return R + + + transform = tgm.angle_axis_to_rotation_matrix(self.orientation) + return transform def forward(self, joints): - R = self.get_transform_mat() - - return joints @ R + self.translation + R = self.get_transform_mat().squeeze() + return joints @ R + F.pad(self.translation, (0,1), value=1)