add openpose conf toggle

This commit is contained in:
Wlad 2021-02-17 21:16:29 +01:00
parent 3e40ff67ad
commit 6bf0cb4b7f
4 changed files with 58 additions and 37 deletions

View File

@ -1,11 +1,9 @@
# library imports
import math
import os
from utils.graphs import render_loss_graph
from train import optimize_sample
import matplotlib.pyplot as plt
# local imports
from utils.general import load_config
from utils.general import get_output_path_from_conf, load_config
from dataset import SMPLyDataset
# load and select sample
@ -25,25 +23,9 @@ pose, camera_transformation, loss_history, step_imgs, loss_components = optimize
)
# color = r.get_snapshot()
# plt.imshow(color)
# plt.show()
fig, ax = plt.subplots(1, 2)
ax[0].plot(loss_history[1::], label='sgd')
ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss')
plt_idx = 1
for name, loss in loss_components.items():
x = math.floor(plt_idx / 3)
y = plt_idx % 2
ax[1].plot(loss[1::], label=name)
ax[1].set(xlabel="Iteration",
ylabel="Loss", title="Component Loss")
plt_idx = plt_idx + 1
plt.legend(loc="upper left")
# name = getfilename_from_conf(config=config, index=sample_index)
# fig.savefig("results/" + name + ".png")
# ax.legend()
plt.show()
filename = get_output_path_from_conf(config) + ".png"
render_loss_graph(
loss_history=loss_history,
loss_components=loss_components,
save=True,
filename=filename)

View File

@ -42,6 +42,7 @@ def train_pose(
extra_loss_layers=[],
use_progress_bar=True,
use_openpose_conf_loss=True,
loss_analysis=True
):
if use_progress_bar:
@ -50,12 +51,19 @@ def train_pose(
offscreen_step_output = []
loss_layer = WeightedMSELoss(
weights=keypoint_conf,
device=device,
dtype=dtype
) # torch.nn.MSELoss(reduction="sum").to(
# device=device, dtype=dtype) # MSELoss()
# is enabled will use openpose keypoint confidence
# as weights on the loss components
if use_openpose_conf_loss:
loss_layer = WeightedMSELoss(
weights=keypoint_conf,
device=device,
dtype=dtype
)
else:
loss_layer = torch.nn.MSELoss(reduction="sum").to(
device=device,
dtype=dtype
)
# make sure camera module is on the correct device
camera = camera.to(device=device, dtype=dtype)
@ -69,9 +77,6 @@ def train_pose(
# filter keypoints
keypoints = keypoint_filter(keypoints)
# get a list of openpose conf values
# keypoints_conf = torch.tensor(keypoint_conf).to(device=device, dtype=dtype)
# create filter layer to ignore unused joints, keypoints during optimization
filter_layer = JointFilter(
model_type=model_type, filter_dims=3).to(device=device, dtype=dtype)
@ -240,6 +245,7 @@ def train_pose_with_conf(
iterations=config['pose']['iterations'],
learning_rate=config['pose']['lr'],
render_steps=render_steps,
use_openpose_conf_loss=config['pose']['useOpenPoseConf'],
use_progress_bar=use_progress_bar,
extra_loss_layers=loss_layers
)

View File

@ -34,11 +34,15 @@ def getfilename_from_conf(config, index=None):
if config['pose']['temporal']['enabled']:
name = name + "-temporal"
print(name)
return name
def get_output_path_from_conf(config, index=None):
name = getfilename_from_conf(config, index=index)
return os.path.join(config['output']['rootDir'], name)
def load_config(name=None):
if name is None:
config_name_env = os.getenv('CONFIG_PATH')

29
utils/graphs.py Normal file
View File

@ -0,0 +1,29 @@
import matplotlib.pyplot as plt
import math
def render_loss_graph(
loss_history,
loss_components,
save=False,
show=True,
filename="untitled.png"):
fig, ax = plt.subplots(1, 2)
ax[0].plot(loss_history[1::], label='sgd')
ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss')
plt_idx = 1
for name, loss in loss_components.items():
x = math.floor(plt_idx / 3)
y = plt_idx % 2
ax[1].plot(loss[1::], label=name)
ax[1].set(xlabel="Iteration",
ylabel="Loss", title="Component Loss")
plt_idx = plt_idx + 1
plt.legend(loc="upper left")
if save:
fig.savefig(filename)
if show:
plt.show()