mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
Hack Patience implemented
This commit is contained in:
parent
403992baae
commit
10f7ed5215
@ -107,7 +107,8 @@ class CameraEstimate:
|
|||||||
|
|
||||||
|
|
||||||
class TorchCameraEstimate(CameraEstimate):
|
class TorchCameraEstimate(CameraEstimate):
|
||||||
def estimate_camera_pos(self):
|
def estimate_camera_pos(self):
|
||||||
|
self.memory = None
|
||||||
translation = torch.zeros(
|
translation = torch.zeros(
|
||||||
1, 3, requires_grad=True, dtype=self.dtype, device=self.device)
|
1, 3, requires_grad=True, dtype=self.dtype, device=self.device)
|
||||||
rotation = torch.rand(1, 3, requires_grad=True,
|
rotation = torch.rand(1, 3, requires_grad=True,
|
||||||
@ -133,9 +134,12 @@ class TorchCameraEstimate(CameraEstimate):
|
|||||||
|
|
||||||
loss_layer = torch.nn.MSELoss()
|
loss_layer = torch.nn.MSELoss()
|
||||||
|
|
||||||
loss = 10000
|
stop = True
|
||||||
|
tol = 3e-4
|
||||||
while loss > 3e-4:
|
print("Estimating Initial transform...")
|
||||||
|
pbar = tqdm(total=100)
|
||||||
|
current = 0
|
||||||
|
while stop:
|
||||||
y_pred = self.C(params, init_points_3d_prepared)
|
y_pred = self.C(params, init_points_3d_prepared)
|
||||||
loss = loss_layer(init_points_2d, y_pred)
|
loss = loss_layer(init_points_2d, y_pred)
|
||||||
|
|
||||||
@ -148,7 +152,19 @@ class TorchCameraEstimate(CameraEstimate):
|
|||||||
current_pose = current_pose.detach().numpy()
|
current_pose = current_pose.detach().numpy()
|
||||||
|
|
||||||
self.renderer.set_group_pose("body", current_pose)
|
self.renderer.set_group_pose("body", current_pose)
|
||||||
|
per = int((tol/loss*100).item())
|
||||||
|
if per > 100:
|
||||||
|
pbar.update(abs(100 - current))
|
||||||
|
current = 100
|
||||||
|
else:
|
||||||
|
pbar.update(per - current)
|
||||||
|
current = per
|
||||||
|
stop = loss > tol
|
||||||
|
if stop == True:
|
||||||
|
stop = self.patience_module(loss, 5)
|
||||||
|
pbar.update(abs(100 - current))
|
||||||
|
pbar.close()
|
||||||
|
self.memory = None
|
||||||
transform_matrix = self.torch_params_to_pose(params)
|
transform_matrix = self.torch_params_to_pose(params)
|
||||||
current_pose = transform_matrix.detach().numpy()
|
current_pose = transform_matrix.detach().numpy()
|
||||||
|
|
||||||
@ -196,7 +212,6 @@ class TorchCameraEstimate(CameraEstimate):
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
opt2.step()
|
opt2.step()
|
||||||
stop = loss > cam_tol
|
|
||||||
self.renderer.scene.set_pose(
|
self.renderer.scene.set_pose(
|
||||||
self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
|
self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
|
||||||
per = int((cam_tol/loss*100).item())
|
per = int((cam_tol/loss*100).item())
|
||||||
@ -205,7 +220,10 @@ class TorchCameraEstimate(CameraEstimate):
|
|||||||
else:
|
else:
|
||||||
pbar.update(per - current)
|
pbar.update(per - current)
|
||||||
current = per
|
current = per
|
||||||
# print(camera_translation, camera_rotation, cam_tol/loss*100)
|
stop = loss > cam_tol
|
||||||
|
if stop == True:
|
||||||
|
stop = self.patience_module(loss, 5)
|
||||||
|
pbar.update(100 - current)
|
||||||
pbar.close()
|
pbar.close()
|
||||||
camera_transform_matrix = camera_intrinsics @ self.torch_params_to_pose(
|
camera_transform_matrix = camera_intrinsics @ self.torch_params_to_pose(
|
||||||
params)
|
params)
|
||||||
@ -246,6 +264,23 @@ class TorchCameraEstimate(CameraEstimate):
|
|||||||
y_pred = points @ rotation.as_matrix() + translation
|
y_pred = points @ rotation.as_matrix() + translation
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
|
def patience_module(self, variable, counter: int):
|
||||||
|
if self.memory == None:
|
||||||
|
self.memory=torch.clone(variable)
|
||||||
|
self.patience_count = 0
|
||||||
|
return True
|
||||||
|
if self.patience_count >= counter:
|
||||||
|
self.memory == None
|
||||||
|
self.patience_count = 0
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if torch.isclose(variable, self.memory).item():
|
||||||
|
self.patience_count += 1
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.patience_count = 0
|
||||||
|
self.memory=torch.clone(variable)
|
||||||
|
return True
|
||||||
|
|
||||||
sample_index = 0
|
sample_index = 0
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user