From 2a9f7f167656c239381a21fb30625d913eb3eaf2 Mon Sep 17 00:00:00 2001 From: Wlad <9556979+gosticks@users.noreply.github.com> Date: Fri, 5 Feb 2021 22:15:57 +0100 Subject: [PATCH] add body mean loss --- config.yaml | 9 ++++++--- train_pose.py | 18 ++++++++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/config.yaml b/config.yaml index 13d8b0c..4b89971 100644 --- a/config.yaml +++ b/config.yaml @@ -19,15 +19,18 @@ camera: patience: 10 optimizer: Adam pose: - lr: 0.01 + lr: 0.05 optimizer: Adam # currently supported Adam, LBFGS iterations: 300 useCameraIntrinsics: false + bodyMeanLoss: + enabled: false + weight: 0.001 bodyPrior: enabled: true - weight: 4.0 + weight: 0.5 anglePrior: - enabled: false + enabled: true weight: 0.01 # idxWeights: # - 0.5 diff --git a/train_pose.py b/train_pose.py index d20ec85..b6b2c7f 100644 --- a/train_pose.py +++ b/train_pose.py @@ -30,9 +30,14 @@ def train_pose( useConfWeights=True, patience=10, body_prior_weight=2, - angle_prior_weight=0.5 + angle_prior_weight=0.5, + body_mean_loss=False, + body_mean_weight=0.01 ): + print("[pose] starting training") + print("[pose] ") + loss_layer = torch.nn.MSELoss() # setup keypoint data @@ -95,9 +100,16 @@ def train_pose( else: loss = loss_layer(points, keypoints) + if body_mean_loss: + # apply pose prior loss. + # poZ.pow(2).sum() * body_prior_weight + loss = loss + (pose_layer.body_pose - + pose_extra).pow(2).sum() * body_mean_weight + if useBodyPrior: # apply pose prior loss. loss = loss + poZ.pow(2).sum() * body_prior_weight + if useAnglePrior: loss = loss + \ angle_prior_layer(pose_layer.body_pose) * angle_prior_weight @@ -187,5 +199,7 @@ def train_pose_with_conf( optimizer_type=config['pose']['optimizer'], iterations=config['pose']['iterations'], body_prior_weight=config['pose']['bodyPrior']['weight'], - angle_prior_weight=config['pose']['anglePrior']['weight'] + angle_prior_weight=config['pose']['anglePrior']['weight'], + body_mean_loss=config['pose']['bodyMeanLoss']['enabled'], + body_mean_weight=config['pose']['bodyMeanLoss']['weight'] )