mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
Hack Patience implemented
This commit is contained in:
parent
403992baae
commit
10f7ed5215
@ -107,7 +107,8 @@ class CameraEstimate:
|
||||
|
||||
|
||||
class TorchCameraEstimate(CameraEstimate):
|
||||
def estimate_camera_pos(self):
|
||||
def estimate_camera_pos(self):
|
||||
self.memory = None
|
||||
translation = torch.zeros(
|
||||
1, 3, requires_grad=True, dtype=self.dtype, device=self.device)
|
||||
rotation = torch.rand(1, 3, requires_grad=True,
|
||||
@ -133,9 +134,12 @@ class TorchCameraEstimate(CameraEstimate):
|
||||
|
||||
loss_layer = torch.nn.MSELoss()
|
||||
|
||||
loss = 10000
|
||||
|
||||
while loss > 3e-4:
|
||||
stop = True
|
||||
tol = 3e-4
|
||||
print("Estimating Initial transform...")
|
||||
pbar = tqdm(total=100)
|
||||
current = 0
|
||||
while stop:
|
||||
y_pred = self.C(params, init_points_3d_prepared)
|
||||
loss = loss_layer(init_points_2d, y_pred)
|
||||
|
||||
@ -148,7 +152,19 @@ class TorchCameraEstimate(CameraEstimate):
|
||||
current_pose = current_pose.detach().numpy()
|
||||
|
||||
self.renderer.set_group_pose("body", current_pose)
|
||||
|
||||
per = int((tol/loss*100).item())
|
||||
if per > 100:
|
||||
pbar.update(abs(100 - current))
|
||||
current = 100
|
||||
else:
|
||||
pbar.update(per - current)
|
||||
current = per
|
||||
stop = loss > tol
|
||||
if stop == True:
|
||||
stop = self.patience_module(loss, 5)
|
||||
pbar.update(abs(100 - current))
|
||||
pbar.close()
|
||||
self.memory = None
|
||||
transform_matrix = self.torch_params_to_pose(params)
|
||||
current_pose = transform_matrix.detach().numpy()
|
||||
|
||||
@ -196,7 +212,6 @@ class TorchCameraEstimate(CameraEstimate):
|
||||
else:
|
||||
loss.backward()
|
||||
opt2.step()
|
||||
stop = loss > cam_tol
|
||||
self.renderer.scene.set_pose(
|
||||
self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
|
||||
per = int((cam_tol/loss*100).item())
|
||||
@ -205,7 +220,10 @@ class TorchCameraEstimate(CameraEstimate):
|
||||
else:
|
||||
pbar.update(per - current)
|
||||
current = per
|
||||
# print(camera_translation, camera_rotation, cam_tol/loss*100)
|
||||
stop = loss > cam_tol
|
||||
if stop == True:
|
||||
stop = self.patience_module(loss, 5)
|
||||
pbar.update(100 - current)
|
||||
pbar.close()
|
||||
camera_transform_matrix = camera_intrinsics @ self.torch_params_to_pose(
|
||||
params)
|
||||
@ -246,6 +264,23 @@ class TorchCameraEstimate(CameraEstimate):
|
||||
y_pred = points @ rotation.as_matrix() + translation
|
||||
return y_pred
|
||||
|
||||
def patience_module(self, variable, counter: int):
|
||||
if self.memory == None:
|
||||
self.memory=torch.clone(variable)
|
||||
self.patience_count = 0
|
||||
return True
|
||||
if self.patience_count >= counter:
|
||||
self.memory == None
|
||||
self.patience_count = 0
|
||||
return False
|
||||
else:
|
||||
if torch.isclose(variable, self.memory).item():
|
||||
self.patience_count += 1
|
||||
return True
|
||||
else:
|
||||
self.patience_count = 0
|
||||
self.memory=torch.clone(variable)
|
||||
return True
|
||||
|
||||
sample_index = 0
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user