diff --git a/example_fit.py b/example_fit.py index 6c9c28a..af84a96 100644 --- a/example_fit.py +++ b/example_fit.py @@ -1,11 +1,9 @@ -# library imports -import math import os +from utils.graphs import render_loss_graph from train import optimize_sample -import matplotlib.pyplot as plt # local imports -from utils.general import load_config +from utils.general import get_output_path_from_conf, load_config from dataset import SMPLyDataset # load and select sample @@ -25,25 +23,9 @@ pose, camera_transformation, loss_history, step_imgs, loss_components = optimize ) -# color = r.get_snapshot() -# plt.imshow(color) -# plt.show() - -fig, ax = plt.subplots(1, 2) -ax[0].plot(loss_history[1::], label='sgd') -ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss') -plt_idx = 1 -for name, loss in loss_components.items(): - x = math.floor(plt_idx / 3) - y = plt_idx % 2 - ax[1].plot(loss[1::], label=name) - ax[1].set(xlabel="Iteration", - ylabel="Loss", title="Component Loss") - - plt_idx = plt_idx + 1 - -plt.legend(loc="upper left") -# name = getfilename_from_conf(config=config, index=sample_index) -# fig.savefig("results/" + name + ".png") -# ax.legend() -plt.show() +filename = get_output_path_from_conf(config) + ".png" +render_loss_graph( + loss_history=loss_history, + loss_components=loss_components, + save=True, + filename=filename) diff --git a/train_pose.py b/train_pose.py index 743c291..596a245 100644 --- a/train_pose.py +++ b/train_pose.py @@ -42,6 +42,7 @@ def train_pose( extra_loss_layers=[], use_progress_bar=True, + use_openpose_conf_loss=True, loss_analysis=True ): if use_progress_bar: @@ -50,12 +51,19 @@ def train_pose( offscreen_step_output = [] - loss_layer = WeightedMSELoss( - weights=keypoint_conf, - device=device, - dtype=dtype - ) # torch.nn.MSELoss(reduction="sum").to( - # device=device, dtype=dtype) # MSELoss() + # is enabled will use openpose keypoint confidence + # as weights on the loss components + if use_openpose_conf_loss: + loss_layer = WeightedMSELoss( + weights=keypoint_conf, + device=device, + dtype=dtype + ) + else: + loss_layer = torch.nn.MSELoss(reduction="sum").to( + device=device, + dtype=dtype + ) # make sure camera module is on the correct device camera = camera.to(device=device, dtype=dtype) @@ -69,9 +77,6 @@ def train_pose( # filter keypoints keypoints = keypoint_filter(keypoints) - # get a list of openpose conf values - # keypoints_conf = torch.tensor(keypoint_conf).to(device=device, dtype=dtype) - # create filter layer to ignore unused joints, keypoints during optimization filter_layer = JointFilter( model_type=model_type, filter_dims=3).to(device=device, dtype=dtype) @@ -240,6 +245,7 @@ def train_pose_with_conf( iterations=config['pose']['iterations'], learning_rate=config['pose']['lr'], render_steps=render_steps, + use_openpose_conf_loss=config['pose']['useOpenPoseConf'], use_progress_bar=use_progress_bar, extra_loss_layers=loss_layers ) diff --git a/utils/general.py b/utils/general.py index 119b2af..144e5a6 100644 --- a/utils/general.py +++ b/utils/general.py @@ -34,11 +34,15 @@ def getfilename_from_conf(config, index=None): if config['pose']['temporal']['enabled']: name = name + "-temporal" - print(name) - return name +def get_output_path_from_conf(config, index=None): + name = getfilename_from_conf(config, index=index) + + return os.path.join(config['output']['rootDir'], name) + + def load_config(name=None): if name is None: config_name_env = os.getenv('CONFIG_PATH') diff --git a/utils/graphs.py b/utils/graphs.py new file mode 100644 index 0000000..ab02e6d --- /dev/null +++ b/utils/graphs.py @@ -0,0 +1,29 @@ +import matplotlib.pyplot as plt +import math + + +def render_loss_graph( + loss_history, + loss_components, + save=False, + show=True, + filename="untitled.png"): + fig, ax = plt.subplots(1, 2) + ax[0].plot(loss_history[1::], label='sgd') + ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss') + plt_idx = 1 + for name, loss in loss_components.items(): + x = math.floor(plt_idx / 3) + y = plt_idx % 2 + ax[1].plot(loss[1::], label=name) + ax[1].set(xlabel="Iteration", + ylabel="Loss", title="Component Loss") + + plt_idx = plt_idx + 1 + + plt.legend(loc="upper left") + + if save: + fig.savefig(filename) + if show: + plt.show()