body-pose-animation/example_fit.py
2021-02-17 14:52:36 +01:00

50 lines
1.2 KiB
Python

# library imports
import math
import os
from train import optimize_sample
import matplotlib.pyplot as plt
# local imports
from utils.general import load_config
from dataset import SMPLyDataset
# Load the configuration and the dataset it describes, then choose which
# sample to fit.
config = load_config()
dataset = SMPLyDataset.from_config(config=config)
# Index of the sample to fit. It defaults to 55 and can be overridden with the
# SAMPLE_INDEX environment variable. A single lookup with a default replaces
# the earlier check-then-read, which queried the environment twice. A
# non-integer value raises ValueError from int().
sample_index = int(os.getenv('SAMPLE_INDEX', '55'))
# Fit the selected sample. optimize_sample returns five values. Of these, the
# plotting code only needs loss_history (total loss per iteration) and
# loss_components (a mapping of component name -> loss history).
pose, camera_transformation, loss_history, step_imgs, loss_components = optimize_sample(
    sample_index,
    dataset,
    config,
    interactive=True
)
# Plot the optimization losses side by side: the total loss on the left and
# every individual loss component on the right. The first recorded value of
# each history is skipped (slice [1:]).
fig, ax = plt.subplots(1, 2)

ax[0].plot(loss_history[1:], label='sgd')
ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss')
# Without a legend, the 'sgd' label given above is never displayed.
ax[0].legend(loc="upper left")

# The axis labels do not change between components, so set them once
# instead of on every loop iteration.
ax[1].set(xlabel="Iterations", ylabel="Loss", title="Component Loss")
for name, loss in loss_components.items():
    ax[1].plot(loss[1:], label=name)
# Attach the legend to the component axes explicitly. plt.legend() would act
# on whichever axes pyplot currently considers "current".
ax[1].legend(loc="upper left")

# NOTE: to also save the figure, something along the lines of
#   fig.savefig("results/" + getfilename_from_conf(config=config, index=sample_index) + ".png")
# could be used (getfilename_from_conf is not imported in this script).
plt.show()