body-pose-animation/example_fit.py
2021-02-07 19:23:55 +01:00

65 lines
1.5 KiB
Python

# standard library imports
from pathlib import Path

# library imports
import torch
import matplotlib.pyplot as plt

# local imports
from train_pose import train_pose_with_conf
from modules.camera import SimpleCamera
from model import SMPLyModel
from utils.general import getfilename_from_conf, load_config, setup_training
from camera_estimation import TorchCameraEstimate
from dataset import SMPLyDataset
# Example: fit the SMPL model to one dataset sample, then plot the training
# loss, save the plot to results/<name>.png and display it.

# load the configuration and dataset, and select the sample to fit
config = load_config()
dataset = SMPLyDataset.from_config(config=config)
sample_index = 0

# prepare data and SMPL model.
# Of the values returned by setup_training, this script only uses keypoints,
# conf (keypoint confidences), est_scale, r (the renderer) and img_path;
# init_keypoints / init_joints are unpacked but unused here.
# NOTE(review): renderer=True presumably asks setup_training to create the
# renderer returned as `r` — confirm.
model = SMPLyModel.model_from_conf(config)
init_keypoints, init_joints, keypoints, conf, est_scale, r, img_path = setup_training(
    model=model,
    renderer=True,
    dataset=dataset,
    sample_index=sample_index
)

# run on the CPU with float32 tensors
dtype = torch.float32
device = torch.device('cpu')
camera = TorchCameraEstimate(
    model,
    dataset=dataset,
    keypoints=keypoints,
    renderer=r,
    device=device,
    dtype=dtype,
    image_path=img_path,
    est_scale=est_scale
)
# render camera to the scene
camera.setup_visualization(r.init_keypoints, r.keypoints)

# train for pose; only the per-iteration loss history is used below
result, best, train_loss = train_pose_with_conf(
    config=config,
    model=model,
    keypoints=keypoints,
    keypoint_conf=conf,
    camera=camera,
    renderer=r,
    device=device,
)

# plot the training loss curve, skipping the first recorded value
# NOTE(review): the first entry is presumably the pre-optimization loss,
# which would dominate the y-axis scale — confirm.
fig, ax = plt.subplots()
name = getfilename_from_conf(config=config, index=sample_index)
ax.plot(train_loss[1:], label='sgd')
ax.set(xlabel="Training iteration", ylabel="Loss", title='Training loss')
# the legend must be added BEFORE saving, otherwise it is missing from the PNG
ax.legend()
# create the output directory if needed so savefig does not fail on a
# fresh checkout
results_dir = Path("results")
results_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(results_dir / f"{name}.png")
plt.show()