Mirror of https://github.com/gosticks/body-pose-animation.git (synced 2025-10-16 11:45:42 +00:00)
fix: minor formatting issues
This commit is contained in:
parent ec7261621d
commit dbcab9cffc

.gitignore (vendored): 3 changes
@@ -98,4 +98,5 @@ reference
 
 vposer_v1_0
 *.avi
 results/
+output/
@@ -25,7 +25,11 @@ class CameraEstimate:
                  image_path=None,
                  dtype=torch.float32,
                  device=torch.device("cpu"),
+                 verbose=True,
+                 use_progress_bar=True,
                  est_scale=1):
+        self.use_progress_bar = use_progress_bar
+        self.verbose = verbose
         self.model = model
         self.dataset = dataset
         self.output_model = model(return_verts=True)
@@ -45,7 +49,8 @@ class CameraEstimate:
         return torso_keypoints_2d, torso_keypoints_3d
 
     def visualize_mesh(self, keypoints, smpl_points):
+        if self.renderer is None:
+            return
         # hardcoded scaling factor
         # scaling_factor = 1
         # smpl_points /= scaling_factor
@@ -82,8 +87,9 @@ class CameraEstimate:
         current_pose = self.params_to_pose(params)
 
         # TODO: use renderer.py methods
-        self.renderer.scene.set_group_pose("body", current_pose)
-        # self.renderer.scene.set_pose(self.verts, current_pose)
+        if self.renderer is not None:
+            self.renderer.scene.set_group_pose("body", current_pose)
+            # self.renderer.scene.set_pose(self.verts, current_pose)
 
     def params_to_pose(self, params):
         pose = np.eye(4)
@@ -96,7 +102,7 @@ class CameraEstimate:
         translation = np.zeros(3)
         rotation = np.random.rand(3) * 2 * np.pi
         params = np.concatenate((translation, rotation))
-        print(params)
+        # print(params)
 
         init_points_2d, init_points_3d = self.get_torso_keypoints()
 
@@ -104,7 +110,7 @@ class CameraEstimate:
 
         res = minimize(self.sum_of_squares, x0=params, args=(init_points_3d, init_points_2d),
                        callback=self.iteration_callback, tol=1e-4, method="BFGS")
-        print(res)
+        # print(res)
 
         transform_matrix = self.params_to_pose(res.x)
         return transform_matrix
@@ -141,8 +147,10 @@ class TorchCameraEstimate(CameraEstimate):
 
         stop = True
         tol = 3e-4
-        print("Estimating Initial transform...")
-        pbar = tqdm(total=100)
+        if self.verbose:
+            print("Estimating Initial transform...")
+        if self.use_progress_bar:
+            pbar = tqdm(total=100)
         current = 0
         while stop:
             y_pred = self.C(params, init_points_3d_prepared)
@@ -159,18 +167,22 @@ class TorchCameraEstimate(CameraEstimate):
             if self.renderer is not None:
                 self.renderer.set_group_pose("body", current_pose)
             per = int((tol/loss*100).item())
-            if per > 100:
-                pbar.update(abs(100 - current))
-                current = 100
-            else:
-                pbar.update(per - current)
-                current = per
+            if self.use_progress_bar:
+                if per > 100:
+                    pbar.update(abs(100 - current))
+                    current = 100
+                else:
+                    pbar.update(per - current)
+                    current = per
             stop = loss > tol
 
             if stop == True:
                 stop = self.patience_module(loss, 5)
-        pbar.update(abs(100 - current))
-        pbar.close()
+        if self.use_progress_bar:
+            pbar.update(abs(100 - current))
+            pbar.close()
         self.memory = None
         transform_matrix = self.torch_params_to_pose(params)
         current_pose = transform_matrix.detach().numpy()
@@ -180,7 +192,7 @@ class TorchCameraEstimate(CameraEstimate):
         # camera_translation[0,2] = 5 * torch.ones(1)
 
         camera_rotation = torch.tensor(
-            [[0,0,0]], requires_grad=False, dtype=self.dtype, device=self.device)
+            [[0, 0, 0]], requires_grad=False, dtype=self.dtype, device=self.device)
         camera_intrinsics = torch.zeros(
             4, 4, dtype=self.dtype, device=self.device)
         camera_intrinsics[0, 0] = 5
@@ -205,8 +217,9 @@ class TorchCameraEstimate(CameraEstimate):
         stop = True
         first = True
         cam_tol = 6e-3
-        print("Estimating Camera transformations...")
-        pbar = tqdm(total=100)
+        # print("Estimating Camera transformations...")
+        if self.use_progress_bar:
+            pbar = tqdm(total=100)
         current = 0
 
         while stop:
@@ -221,15 +234,17 @@ class TorchCameraEstimate(CameraEstimate):
             else:
                 loss.backward()
                 opt2.step()
-            if visualize:
+            if visualize and self.renderer is not None:
                 self.renderer.scene.set_pose(
                     self.camera_renderer, self.torch_params_to_pose(params).detach().numpy())
             per = int((cam_tol/loss*100).item())
 
-            if per > 100:
-                pbar.update(100 - current)
-            else:
-                pbar.update(per - current)
+            if self.use_progress_bar:
+                if per > 100:
+                    pbar.update(100 - current)
+                else:
+                    pbar.update(per - current)
 
             current = per
             stop = loss > cam_tol
@@ -237,8 +252,9 @@ class TorchCameraEstimate(CameraEstimate):
             if stop == True:
                 stop = self.patience_module(loss, 5)
 
-        pbar.update(100 - current)
-        pbar.close()
+        if self.use_progress_bar:
+            pbar.update(100 - current)
+            pbar.close()
         camera_transform_matrix = self.torch_params_to_pose(
             params)
         return camera_intrinsics, transform_matrix, camera_transform_matrix
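
In the hunks above, every console message and tqdm call in the camera fitting loops is now gated behind the new verbose and use_progress_bar flags, and progress is reported as the ratio of the tolerance to the current loss. A minimal standalone sketch of that pattern, assuming a made-up fit() function and a fake loss update; only tqdm, verbose, use_progress_bar and the tol/loss percentage come from the diff:

from tqdm import tqdm

def fit(verbose=True, use_progress_bar=True, tol=3e-4):
    # hypothetical driver loop; the real optimization steps live in TorchCameraEstimate
    if verbose:
        print("Estimating Initial transform...")
    if use_progress_bar:
        pbar = tqdm(total=100)
    current = 0
    loss = 1.0
    while loss > tol:
        loss *= 0.9  # placeholder for one optimization step
        per = min(int(tol / loss * 100), 100)  # progress as tol/loss percentage
        if use_progress_bar:
            pbar.update(per - current)
            current = per
    if use_progress_bar:
        pbar.update(100 - current)
        pbar.close()

fit(verbose=False, use_progress_bar=True)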

@@ -9,10 +9,10 @@ smpl:
   useVposerInit: false
 data:
   renameFiles: false
-  rootDir: ./samples
+  rootDir: ./output
   personId: 0
-  sampleImageFormat: "%%%.png"
-  sampleNameFormat: "%%%.json"
+  sampleImageFormat: "input_%%%%%%%%%%%%_rendered.png"
+  sampleNameFormat: "input_%%%%%%%%%%%%_keypoints.json"
   sampleCoords: !!python/tuple [1080, 1080]
 camera:
   lr: 0.001
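
The changed keys now point the loader at OpenPose-style output names under ./output. A small sketch of reading such a config, assuming PyYAML and an assumed file name config.yaml; the !!python/tuple tag is Python-specific, so the default safe loader rejects it and an unsafe/full loader is needed:

import yaml

with open("config.yaml") as f:
    # UnsafeLoader understands !!python/tuple; yaml.safe_load would raise here
    config = yaml.load(f, Loader=yaml.UnsafeLoader)

print(config["data"]["rootDir"])            # ./output
print(config["data"]["sampleImageFormat"])  # input_%%%%%%%%%%%%_rendered.png
print(config["data"]["sampleCoords"])       # (1080, 1080)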

dataset.py: 16 changes
@@ -16,6 +16,7 @@ class SMPLyDataset(Dataset):
             model_type="smplx",
             person_id=0,
             sample_format="%%%.json",
+            image_format="%%%.png",
             start_index=1,
             sample_id_pad=None
     ):
@@ -23,6 +24,8 @@ class SMPLyDataset(Dataset):
         self.model_type = model_type
         self.size = size
         self.person_id = person_id
+        self.image_format = image_format
+        self.sample_format = sample_format
         if sample_id_pad:
             self.sample_id_pad = sample_format.count('%')
         else:
@@ -43,6 +46,8 @@ class SMPLyDataset(Dataset):
             size=config['data']['sampleCoords'],
             person_id=config['data']['personId'],
             model_type=config['smpl']['type'],
+            image_format=config['data']['sampleImageFormat'],
+            sample_format=config['data']['sampleNameFormat'],
             # img_format=config['data']['sampleImageFormat'],
             sample_id_pad=config['data']['sampleImageFormat'].count('%')
         )
@@ -55,8 +60,12 @@ class SMPLyDataset(Dataset):
             index + self.start_index).zfill(self.sample_id_pad)
         return name
 
+    def get_keypoint_name(self, index):
+        id = self.get_item_name(index)
+        return self.sample_format.replace("%" * len(id), id)
+
     def __getitem__(self, index):
-        name = self.get_item_name(index) + ".json"
+        name = self.get_keypoint_name(index)
         path = os.path.join(
             self.root_dir, name)
         if os.path.exists(path):
@@ -69,7 +78,7 @@ class SMPLyDataset(Dataset):
         return None
 
     def transform(self, data, origin_format="body_25"):
         """
         transform: transforms the order of an origin array to the target format
         """
 
@@ -85,6 +94,7 @@ class SMPLyDataset(Dataset):
         return sample_count
 
     def get_image_path(self, index):
-        name = self.get_item_name(index) + ".png"
+        id = self.get_item_name(index)
+        name = self.image_format.replace("%" * len(id), id)
         path = os.path.join(self.root_dir, name)
         return path
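
get_keypoint_name and get_image_path now fill the run of % placeholders in the configured format strings with the zero-padded frame id. A standalone sketch of that substitution, using the formats from the config change above; the index values are illustrative only:

sample_format = "input_%%%%%%%%%%%%_keypoints.json"   # run of '%' acts as the id placeholder
image_format = "input_%%%%%%%%%%%%_rendered.png"

index, start_index = 0, 1
pad = sample_format.count('%')                          # width of the numeric id
frame_id = str(index + start_index).zfill(pad)          # e.g. '000000000001'

keypoint_file = sample_format.replace("%" * len(frame_id), frame_id)
image_file = image_format.replace("%" * len(frame_id), frame_id)
# -> input_000000000001_keypoints.json, input_000000000001_rendered.png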

@@ -2,7 +2,7 @@ import pickle
 import time
 from utils.render import make_video
 import torch
-from tqdm.std import tqdm
+from tqdm.auto import trange
 
 from dataset import SMPLyDataset
 from model import *
@@ -16,7 +16,7 @@ from utils.general import rename_files, get_new_filename
 START_IDX = 150  # starting index of the frame to optimize for
 FINISH_IDX = 200  # choose a big number to optimize for all frames in samples directory
 # if False, only run already saved animation without optimization
-RUN_OPTIMIZATION = False
+RUN_OPTIMIZATION = True
 
 final_poses = []  # optimized poses array that is saved for playing the animation
 result_image = []
@@ -60,7 +60,7 @@ joints = model_out.joints.detach().cpu().numpy().squeeze()
 Optimization part without visualization
 '''
 if RUN_OPTIMIZATION:
-    for idx in dataset:
+    for idx in trange(100, desc='Optimizing'):
 
         init_keypoints, init_joints, keypoints, conf, est_scale, r, img_path = setup_training(
             model=model,
@@ -70,7 +70,9 @@ if RUN_OPTIMIZATION:
             sample_index=idx
         )
 
-        camera = TorchCameraEstimate(
+        r.start()
+
+        cam = TorchCameraEstimate(
             model,
             dataset=dataset,
             keypoints=keypoints,
@@ -78,24 +80,28 @@ if RUN_OPTIMIZATION:
             device=torch.device('cpu'),
             dtype=torch.float32,
             image_path=img_path,
-            est_scale=est_scale
+            est_scale=est_scale,
+            use_progress_bar=False,
+            verbose=False
         )
 
-        print("\nCamera optimization of frame", idx, "is finished.")
-        camera = SimpleCamera.from_estimation_cam(camera)
+        # print("\nCamera optimization of frame", idx, "is finished.")
 
         cur_pose, final_pose, loss, frames = train_pose_with_conf(
             config=config,
             model=model,
             keypoints=keypoints,
             keypoint_conf=conf,
-            camera=camera,
+            camera=cam,
             renderer=r,
-            device=device
+            device=device,
+            use_progress_bar=False
         )
 
-        print("\nPose optimization of frame", idx, "is finished.")
-        R = camera.trans.numpy().squeeze()
+        camera_transformation, camera_int, camera_params = cam.get_results()
+
+        # print("\nPose optimization of frame", idx, "is finished.")
+        R = camera_transformation.numpy().squeeze()
         idx += 1
 
         # append optimized pose and camera transformation to the array
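
With the inner camera and pose optimizers silenced (use_progress_bar=False, verbose=False), the only progress indicator left is the outer trange bar over the frames. A reduced sketch of that outer loop; the frame count of 100 comes from the diff, while the loop body here is just a placeholder:

from tqdm.auto import trange

for idx in trange(100, desc='Optimizing'):
    # per-frame camera estimation and pose optimization would run here,
    # with their own progress bars and prints disabled so they do not
    # interleave with the outer bar
    pass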

model.py: 19 changes
@@ -8,7 +8,7 @@ from human_body_prior.tools.model_loader import load_vposer
 
 
 def get_model_path(type, gender, dir):
-    return os.path.joint(
+    return os.path.join(
         dir,
         type,
         type.upper() + "_" +
@@ -24,6 +24,8 @@ def get_model_path_from_conf(config):
 
 
 class VPoserModel():
+    global_vposer = None
+
     def __init__(
             self,
             model_type='smpl',
@@ -70,14 +72,17 @@ class VPoserModel():
     def get_pose(self):
         return self.model.pose_body
 
-    def from_conf(config):
+    def from_conf(config, use_global=True):
         model_path = get_model_path_from_conf(config)
 
-        return VPoserModel(
-            model_type=config['smpl']['type'],
-            gender=config['smpl']['gender'],
-            vposer_model_path=config['pose']['vposerPath'],
-            body_model_path=model_path)
+        if VPoserModel.global_vposer is None:
+            VPoserModel.global_vposer = VPoserModel(
+                model_type=config['smpl']['type'],
+                gender=config['smpl']['gender'],
+                vposer_model_path=config['pose']['vposerPath'],
+                body_model_path=model_path)
+
+        return VPoserModel.global_vposer
 
 
 class SMPLyModel():
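
from_conf now caches one VPoserModel on the class (global_vposer), so later calls reuse the already loaded VPoser instead of reloading it. A reduced sketch of the same class-level caching pattern, with a made-up Expensive class standing in for VPoserModel; only the None check and the class attribute mirror the diff:

class Expensive:
    _instance = None  # class-level cache, analogous to VPoserModel.global_vposer

    def __init__(self, config):
        self.config = config  # stands in for loading heavy model weights

    @staticmethod
    def from_conf(config):
        # build the shared instance on first use, then reuse it
        if Expensive._instance is None:
            Expensive._instance = Expensive(config)
        return Expensive._instance

a = Expensive.from_conf({"gender": "neutral"})
b = Expensive.from_conf({"gender": "male"})
assert a is b  # the second config is ignored because the cache already exists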

@@ -53,11 +53,14 @@ def train_pose(
     # renderer options
     renderer: Renderer = None,
     render_steps=True,
-    render_offscreen=True
-):
+    render_offscreen=True,
 
-    print("[pose] starting training")
-    print("[pose] dtype=", dtype)
+    vposer=None,
+    use_progress_bar=True
+):
+    # print("[pose] starting training")
+    # print("[pose] dtype=", dtype)
 
     offscreen_step_output = []
 
@@ -85,7 +88,6 @@ def train_pose(
 
     # loss layers
     if useBodyPrior:
-        vposer = VPoserModel()
         # TODO: handle this in vposer model
         vposer.model.to(device=device, dtype=dtype)
         latent_body = vposer.get_pose()
@@ -105,7 +107,8 @@ def train_pose(
 
     optimizer = optimizer(parameters, learning_rate)
 
-    pbar = tqdm(total=iterations)
+    if use_progress_bar:
+        pbar = tqdm(total=iterations)
 
     def predict():
         # pose_extra = None
@@ -197,8 +200,9 @@ def train_pose(
         if patience == 0:
             print("[train] aborted due to patience limit reached")
 
-        pbar.set_description("Error %f" % cur_loss)
-        pbar.update(1)
+        if use_progress_bar:
+            pbar.set_description("Error %f" % cur_loss)
+            pbar.update(1)
 
         if renderer is not None and render_steps:
             R = camera.trans.detach().cpu().numpy().squeeze()
@@ -209,8 +213,9 @@ def train_pose(
             offscreen_step_output.append(renderer.get_snapshot())
             # renderer.set_group_pose("body", R)
 
-    pbar.close()
-    print("Final result:", loss.item())
+    if use_progress_bar:
+        pbar.close()
+        print("Final result:", loss.item())
     return pose_layer.cur_out, best_pose, loss_history, offscreen_step_output
 
 
@@ -223,7 +228,8 @@ def train_pose_with_conf(
     device=torch.device('cpu'),
     dtype=torch.float32,
     renderer: Renderer = None,
-    render_steps=True
+    render_steps=True,
+    use_progress_bar=True
 ):
 
     # configure PyTorch device and format
@@ -245,6 +251,8 @@ def train_pose_with_conf(
     if renderer is not None:
         renderer.set_group_pose("body", cam_trans.cpu().numpy())
 
+    vposer = VPoserModel.from_conf(config)
+
     return train_pose(
         model=model.to(dtype=dtype),
         keypoints=keypoints,
@@ -259,11 +267,13 @@ def train_pose_with_conf(
         learning_rate=config['pose']['lr'],
         optimizer_type=config['pose']['optimizer'],
         iterations=config['pose']['iterations'],
+        vposer=vposer,
         body_prior_weight=config['pose']['bodyPrior']['weight'],
         angle_prior_weight=config['pose']['anglePrior']['weight'],
         body_mean_loss=config['pose']['bodyMeanLoss']['enabled'],
         body_mean_weight=config['pose']['bodyMeanLoss']['weight'],
         use_angle_sum_loss=config['pose']['angleSumLoss']['enabled'],
         angle_sum_weight=config['pose']['angleSumLoss']['weight'],
-        render_steps=render_steps
+        render_steps=render_steps,
+        use_progress_bar=use_progress_bar
     )
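
train_pose_with_conf now threads use_progress_bar down into train_pose, so the tqdm bar is only created, updated, and closed when the flag is set. A condensed sketch of that gating with a fake training loop; only the tqdm(total=iterations), set_description, update, and close calls mirror the diff:

from tqdm import tqdm

def train(iterations=50, use_progress_bar=True):
    if use_progress_bar:
        pbar = tqdm(total=iterations)
    cur_loss = 1.0
    for _ in range(iterations):
        cur_loss *= 0.95  # placeholder for one optimization step
        if use_progress_bar:
            pbar.set_description("Error %f" % cur_loss)
            pbar.update(1)
    if use_progress_bar:
        pbar.close()
        print("Final result:", cur_loss)
    return cur_loss

train(use_progress_bar=False)  # runs silently, e.g. when an outer loop owns the bar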

@@ -140,8 +140,8 @@ def rename_files(dir):
 
 def get_new_filename():
     conf = load_config()
-    results_dir = conf['resultsPath']
-    result_prefix = conf['resultPrefix']
+    results_dir = conf['output']['rootDir']
+    result_prefix = conf['output']['prefix']
 
     results = glob.glob(results_dir + "*.pkl")
     if len(results) == 0:

@@ -10,7 +10,6 @@ from tqdm import tqdm
 def make_video(images, video_name: str, fps=5):
 
     images = np.array(images)
-    print(images.shape)
     width = images.shape[2]
     height = images.shape[1]
     video = cv2.VideoWriter(