mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
add body mean loss
This commit is contained in:
parent
a820443206
commit
2a9f7f1676
@ -19,15 +19,18 @@ camera:
|
||||
patience: 10
|
||||
optimizer: Adam
|
||||
pose:
|
||||
lr: 0.01
|
||||
lr: 0.05
|
||||
optimizer: Adam # currently supported Adam, LBFGS
|
||||
iterations: 300
|
||||
useCameraIntrinsics: false
|
||||
bodyMeanLoss:
|
||||
enabled: false
|
||||
weight: 0.001
|
||||
bodyPrior:
|
||||
enabled: true
|
||||
weight: 4.0
|
||||
weight: 0.5
|
||||
anglePrior:
|
||||
enabled: false
|
||||
enabled: true
|
||||
weight: 0.01
|
||||
# idxWeights:
|
||||
# - 0.5
|
||||
|
||||
@ -30,9 +30,14 @@ def train_pose(
|
||||
useConfWeights=True,
|
||||
patience=10,
|
||||
body_prior_weight=2,
|
||||
angle_prior_weight=0.5
|
||||
angle_prior_weight=0.5,
|
||||
body_mean_loss=False,
|
||||
body_mean_weight=0.01
|
||||
):
|
||||
|
||||
print("[pose] starting training")
|
||||
print("[pose] ")
|
||||
|
||||
loss_layer = torch.nn.MSELoss()
|
||||
|
||||
# setup keypoint data
|
||||
@ -95,9 +100,16 @@ def train_pose(
|
||||
else:
|
||||
loss = loss_layer(points, keypoints)
|
||||
|
||||
if body_mean_loss:
|
||||
# apply pose prior loss.
|
||||
# poZ.pow(2).sum() * body_prior_weight
|
||||
loss = loss + (pose_layer.body_pose -
|
||||
pose_extra).pow(2).sum() * body_mean_weight
|
||||
|
||||
if useBodyPrior:
|
||||
# apply pose prior loss.
|
||||
loss = loss + poZ.pow(2).sum() * body_prior_weight
|
||||
|
||||
if useAnglePrior:
|
||||
loss = loss + \
|
||||
angle_prior_layer(pose_layer.body_pose) * angle_prior_weight
|
||||
@ -187,5 +199,7 @@ def train_pose_with_conf(
|
||||
optimizer_type=config['pose']['optimizer'],
|
||||
iterations=config['pose']['iterations'],
|
||||
body_prior_weight=config['pose']['bodyPrior']['weight'],
|
||||
angle_prior_weight=config['pose']['anglePrior']['weight']
|
||||
angle_prior_weight=config['pose']['anglePrior']['weight'],
|
||||
body_mean_loss=config['pose']['bodyMeanLoss']['enabled'],
|
||||
body_mean_weight=config['pose']['bodyMeanLoss']['weight']
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user