mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
add openpose conf toggle
This commit is contained in:
parent
3e40ff67ad
commit
6bf0cb4b7f
@ -1,11 +1,9 @@
|
|||||||
# library imports
|
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
|
from utils.graphs import render_loss_graph
|
||||||
from train import optimize_sample
|
from train import optimize_sample
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
# local imports
|
# local imports
|
||||||
from utils.general import load_config
|
from utils.general import get_output_path_from_conf, load_config
|
||||||
from dataset import SMPLyDataset
|
from dataset import SMPLyDataset
|
||||||
|
|
||||||
# load and select sample
|
# load and select sample
|
||||||
@ -25,25 +23,9 @@ pose, camera_transformation, loss_history, step_imgs, loss_components = optimize
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# color = r.get_snapshot()
|
filename = get_output_path_from_conf(config) + ".png"
|
||||||
# plt.imshow(color)
|
render_loss_graph(
|
||||||
# plt.show()
|
loss_history=loss_history,
|
||||||
|
loss_components=loss_components,
|
||||||
fig, ax = plt.subplots(1, 2)
|
save=True,
|
||||||
ax[0].plot(loss_history[1::], label='sgd')
|
filename=filename)
|
||||||
ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss')
|
|
||||||
plt_idx = 1
|
|
||||||
for name, loss in loss_components.items():
|
|
||||||
x = math.floor(plt_idx / 3)
|
|
||||||
y = plt_idx % 2
|
|
||||||
ax[1].plot(loss[1::], label=name)
|
|
||||||
ax[1].set(xlabel="Iteration",
|
|
||||||
ylabel="Loss", title="Component Loss")
|
|
||||||
|
|
||||||
plt_idx = plt_idx + 1
|
|
||||||
|
|
||||||
plt.legend(loc="upper left")
|
|
||||||
# name = getfilename_from_conf(config=config, index=sample_index)
|
|
||||||
# fig.savefig("results/" + name + ".png")
|
|
||||||
# ax.legend()
|
|
||||||
plt.show()
|
|
||||||
|
|||||||
@ -42,6 +42,7 @@ def train_pose(
|
|||||||
extra_loss_layers=[],
|
extra_loss_layers=[],
|
||||||
|
|
||||||
use_progress_bar=True,
|
use_progress_bar=True,
|
||||||
|
use_openpose_conf_loss=True,
|
||||||
loss_analysis=True
|
loss_analysis=True
|
||||||
):
|
):
|
||||||
if use_progress_bar:
|
if use_progress_bar:
|
||||||
@ -50,12 +51,19 @@ def train_pose(
|
|||||||
|
|
||||||
offscreen_step_output = []
|
offscreen_step_output = []
|
||||||
|
|
||||||
|
# is enabled will use openpose keypoint confidence
|
||||||
|
# as weights on the loss components
|
||||||
|
if use_openpose_conf_loss:
|
||||||
loss_layer = WeightedMSELoss(
|
loss_layer = WeightedMSELoss(
|
||||||
weights=keypoint_conf,
|
weights=keypoint_conf,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype
|
dtype=dtype
|
||||||
) # torch.nn.MSELoss(reduction="sum").to(
|
)
|
||||||
# device=device, dtype=dtype) # MSELoss()
|
else:
|
||||||
|
loss_layer = torch.nn.MSELoss(reduction="sum").to(
|
||||||
|
device=device,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
# make sure camera module is on the correct device
|
# make sure camera module is on the correct device
|
||||||
camera = camera.to(device=device, dtype=dtype)
|
camera = camera.to(device=device, dtype=dtype)
|
||||||
@ -69,9 +77,6 @@ def train_pose(
|
|||||||
# filter keypoints
|
# filter keypoints
|
||||||
keypoints = keypoint_filter(keypoints)
|
keypoints = keypoint_filter(keypoints)
|
||||||
|
|
||||||
# get a list of openpose conf values
|
|
||||||
# keypoints_conf = torch.tensor(keypoint_conf).to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
# create filter layer to ignore unused joints, keypoints during optimization
|
# create filter layer to ignore unused joints, keypoints during optimization
|
||||||
filter_layer = JointFilter(
|
filter_layer = JointFilter(
|
||||||
model_type=model_type, filter_dims=3).to(device=device, dtype=dtype)
|
model_type=model_type, filter_dims=3).to(device=device, dtype=dtype)
|
||||||
@ -240,6 +245,7 @@ def train_pose_with_conf(
|
|||||||
iterations=config['pose']['iterations'],
|
iterations=config['pose']['iterations'],
|
||||||
learning_rate=config['pose']['lr'],
|
learning_rate=config['pose']['lr'],
|
||||||
render_steps=render_steps,
|
render_steps=render_steps,
|
||||||
|
use_openpose_conf_loss=config['pose']['useOpenPoseConf'],
|
||||||
use_progress_bar=use_progress_bar,
|
use_progress_bar=use_progress_bar,
|
||||||
extra_loss_layers=loss_layers
|
extra_loss_layers=loss_layers
|
||||||
)
|
)
|
||||||
|
|||||||
@ -34,11 +34,15 @@ def getfilename_from_conf(config, index=None):
|
|||||||
if config['pose']['temporal']['enabled']:
|
if config['pose']['temporal']['enabled']:
|
||||||
name = name + "-temporal"
|
name = name + "-temporal"
|
||||||
|
|
||||||
print(name)
|
|
||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def get_output_path_from_conf(config, index=None):
|
||||||
|
name = getfilename_from_conf(config, index=index)
|
||||||
|
|
||||||
|
return os.path.join(config['output']['rootDir'], name)
|
||||||
|
|
||||||
|
|
||||||
def load_config(name=None):
|
def load_config(name=None):
|
||||||
if name is None:
|
if name is None:
|
||||||
config_name_env = os.getenv('CONFIG_PATH')
|
config_name_env = os.getenv('CONFIG_PATH')
|
||||||
|
|||||||
29
utils/graphs.py
Normal file
29
utils/graphs.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def render_loss_graph(
|
||||||
|
loss_history,
|
||||||
|
loss_components,
|
||||||
|
save=False,
|
||||||
|
show=True,
|
||||||
|
filename="untitled.png"):
|
||||||
|
fig, ax = plt.subplots(1, 2)
|
||||||
|
ax[0].plot(loss_history[1::], label='sgd')
|
||||||
|
ax[0].set(xlabel="Iterations", ylabel="Loss", title='Total Loss')
|
||||||
|
plt_idx = 1
|
||||||
|
for name, loss in loss_components.items():
|
||||||
|
x = math.floor(plt_idx / 3)
|
||||||
|
y = plt_idx % 2
|
||||||
|
ax[1].plot(loss[1::], label=name)
|
||||||
|
ax[1].set(xlabel="Iteration",
|
||||||
|
ylabel="Loss", title="Component Loss")
|
||||||
|
|
||||||
|
plt_idx = plt_idx + 1
|
||||||
|
|
||||||
|
plt.legend(loc="upper left")
|
||||||
|
|
||||||
|
if save:
|
||||||
|
fig.savefig(filename)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
Loading…
Reference in New Issue
Block a user