Hack Patience implemented

This commit is contained in:
Umesh Ramchandani 2021-01-26 22:33:15 +01:00
parent 403992baae
commit 10f7ed5215

View File

@ -107,7 +107,8 @@ class CameraEstimate:
class TorchCameraEstimate(CameraEstimate):
def estimate_camera_pos(self):
def estimate_camera_pos(self):
self.memory = None
translation = torch.zeros(
1, 3, requires_grad=True, dtype=self.dtype, device=self.device)
rotation = torch.rand(1, 3, requires_grad=True,
@ -133,9 +134,12 @@ class TorchCameraEstimate(CameraEstimate):
loss_layer = torch.nn.MSELoss()
loss = 10000
while loss > 3e-4:
stop = True
tol = 3e-4
print("Estimating Initial transform...")
pbar = tqdm(total=100)
current = 0
while stop:
y_pred = self.C(params, init_points_3d_prepared)
loss = loss_layer(init_points_2d, y_pred)
@ -148,7 +152,19 @@ class TorchCameraEstimate(CameraEstimate):
current_pose = current_pose.detach().numpy()
self.renderer.set_group_pose("body", current_pose)
per = int((tol/loss*100).item())
if per > 100:
pbar.update(abs(100 - current))
current = 100
else:
pbar.update(per - current)
current = per
stop = loss > tol
if stop == True:
stop = self.patience_module(loss, 5)
pbar.update(abs(100 - current))
pbar.close()
self.memory = None
transform_matrix = self.torch_params_to_pose(params)
current_pose = transform_matrix.detach().numpy()
@ -196,7 +212,6 @@ class TorchCameraEstimate(CameraEstimate):
else:
loss.backward()
opt2.step()
stop = loss > cam_tol
self.renderer.scene.set_pose(
self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
per = int((cam_tol/loss*100).item())
@ -205,7 +220,10 @@ class TorchCameraEstimate(CameraEstimate):
else:
pbar.update(per - current)
current = per
# print(camera_translation, camera_rotation, cam_tol/loss*100)
stop = loss > cam_tol
if stop == True:
stop = self.patience_module(loss, 5)
pbar.update(100 - current)
pbar.close()
camera_transform_matrix = camera_intrinsics @ self.torch_params_to_pose(
params)
@ -246,6 +264,23 @@ class TorchCameraEstimate(CameraEstimate):
y_pred = points @ rotation.as_matrix() + translation
return y_pred
def patience_module(self, variable, counter: int):
if self.memory == None:
self.memory=torch.clone(variable)
self.patience_count = 0
return True
if self.patience_count >= counter:
self.memory == None
self.patience_count = 0
return False
else:
if torch.isclose(variable, self.memory).item():
self.patience_count += 1
return True
else:
self.patience_count = 0
self.memory=torch.clone(variable)
return True
sample_index = 0