mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
50 lines
1.2 KiB
Python
50 lines
1.2 KiB
Python
# library imports
|
|
import math
|
|
import os
|
|
from train import optimize_sample
|
|
import matplotlib.pyplot as plt
|
|
|
|
# local imports
|
|
from utils.general import load_config
|
|
from dataset import SMPLyDataset
|
|
|
|
# load and select sample
|
|
config = load_config()
|
|
dataset = SMPLyDataset.from_config(config=config)
|
|
sample_index = 55
|
|
|
|
if os.getenv('SAMPLE_INDEX') is not None:
|
|
sample_index = int(os.getenv('SAMPLE_INDEX'))
|
|
|
|
# train for pose
|
|
pose, camera_transformation, loss_history, step_imgs, loss_components = optimize_sample(
|
|
sample_index,
|
|
dataset,
|
|
config,
|
|
interactive=True
|
|
)
|
|
|
|
|
|
# color = r.get_snapshot()
|
|
# plt.imshow(color)
|
|
# plt.show()
|
|
|
|
fig, ax = plt.subplots(1, 2)
|
|
ax[0].plot(loss_history[1::], label='sgd')
|
|
ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss')
|
|
plt_idx = 1
|
|
for name, loss in loss_components.items():
|
|
x = math.floor(plt_idx / 3)
|
|
y = plt_idx % 2
|
|
ax[1].plot(loss[1::], label=name)
|
|
ax[1].set(xlabel="Iteration",
|
|
ylabel="Loss", title="Component Loss")
|
|
|
|
plt_idx = plt_idx + 1
|
|
|
|
plt.legend(loc="upper left")
|
|
# name = getfilename_from_conf(config=config, index=sample_index)
|
|
# fig.savefig("results/" + name + ".png")
|
|
# ax.legend()
|
|
plt.show()
|