diff --git a/config.yaml b/config.yaml index 66c30fb..18c2741 100644 --- a/config.yaml +++ b/config.yaml @@ -19,26 +19,31 @@ camera: patience: 10 optimizer: Adam pose: - # device: cuda - lr: 0.2 - optimizer: LBFGS # currently supported Adam, LBFGS - iterations: 30 + device: cuda + lr: 0.01 + optimizer: Adam # currently supported Adam, LBFGS + iterations: 100 useCameraIntrinsics: true bodyMeanLoss: enabled: false weight: 0.1 bodyPrior: enabled: true - weight: 1.0 + weight: 0.1 anglePrior: enabled: true - weight: 0.001 + weight: 0.5 angleLimitLoss: enabled: true - weight: 0.5 + weight: 0.01 angleSumLoss: enabled: true - weight: 0.1 + weight: 0.01 + intersectLoss: + enabled: true + weight: 2.0 + maxCollisions: 8 + sigma: 0.5 confWeights: enabled: false vposerPath: "./vposer_v1_0" diff --git a/modules/angle_clip.py b/modules/angle_clip.py index 225dc12..36f3a9c 100644 --- a/modules/angle_clip.py +++ b/modules/angle_clip.py @@ -34,7 +34,7 @@ class AngleClipper(nn.Module): torch.tensor(weight, dtype=dtype).to(device=device) ) - def forward(self, pose, joints, points, keypoints): + def forward(self, pose, joints, points, keypoints, raw_output): angles = pose[:, self.angle_idx] diff --git a/modules/angle_prior.py b/modules/angle_prior.py index 9308d2c..0670b3d 100644 --- a/modules/angle_prior.py +++ b/modules/angle_prior.py @@ -40,7 +40,7 @@ class AnglePriorsLoss(nn.Module): torch.tensor(global_weight, dtype=dtype).to(device=device) ) - def forward(self, pose, joints, points, keypoints): + def forward(self, pose, joints, points, keypoints, raw_output): # compute direction deviation from expected joint rotation directions, # e.g. don't rotate the knee joint forwards. Broken knees are not fun. 
diff --git a/modules/angle_sum.py b/modules/angle_sum.py index dfa0a47..8bf3d69 100644 --- a/modules/angle_sum.py +++ b/modules/angle_sum.py @@ -21,6 +21,6 @@ class AngleSumLoss(nn.Module): torch.tensor(weight).to(device=device, dtype=dtype) ) - def forward(self, pose, joints, points, keypoints): + def forward(self, pose, joints, points, keypoints, raw_output): # get relevant angles return pose.pow(2).sum() * self.weight diff --git a/modules/body_prior.py b/modules/body_prior.py index b084ec9..0dbd8fc 100644 --- a/modules/body_prior.py +++ b/modules/body_prior.py @@ -27,7 +27,7 @@ class BodyPrior(nn.Module): torch.tensor(weight, dtype=dtype).to(device=device) ) - def forward(self, pose, joints, points, keypoints): + def forward(self, pose, joints, points, keypoints, raw_output): # get relevant angles return self.latent_pose.pow( 2).sum() * self.weight diff --git a/modules/intersect.py b/modules/intersect.py new file mode 100644 index 0000000..9407910 --- /dev/null +++ b/modules/intersect.py @@ -0,0 +1,82 @@ +import smplx +from model import VPoserModel +import torch +import torch.nn as nn +import numpy as np + + +from mesh_intersection.bvh_search_tree import BVH +import mesh_intersection.loss as collisions_loss + + +class IntersectLoss(nn.Module): + def __init__( + self, + model: smplx.SMPL, + device=torch.device('cpu'), + dtype=torch.float32, + batch_size=1, + weight=1, + sigma=0.5, + max_collisions=8, + point2plane=True + ): + """Intersections loss layer. + + Args: + device (torch.device, optional): Device on which the face-index and weight buffers are created. Defaults to torch.device('cpu'). + dtype (torch.dtype, optional): Data type of the weight buffer. Defaults to torch.float32. + weight (int, optional): Weight factor of the loss. Defaults to 1. + sigma (float, optional): The height of the cone used to calculate the distance field loss. Defaults to 0.5. + max_collisions (int, optional): The maximum number of bounding box collisions. Defaults to 8. 
+ """ + + super(IntersectLoss, self).__init__() + + self.has_parameters = False + + with torch.no_grad(): + output = model(get_skin=True) + verts = output.vertices + + face_tensor = torch.tensor( + model.faces.astype(np.int64), + dtype=torch.long, + device=device) \ + .unsqueeze_(0) \ + .repeat( + [batch_size, + 1, 1]) + + bs, nv = verts.shape[:2] + bs, nf = face_tensor.shape[:2] + + faces_idx = face_tensor + \ + (torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None] + + self.register_buffer("faces_idx", faces_idx) + + # Create the search tree + self.search_tree = BVH(max_collisions=max_collisions) + + self.pen_distance = \ + collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma, + point2plane=point2plane, + vectorized=True) + + # create buffer for weights + self.register_buffer( + "weight", + torch.tensor(weight, dtype=dtype).to(device=device) + ) + + def forward(self, pose, joints, points, keypoints, raw_output): + verts = raw_output.vertices + polygons = verts.view([-1, 3])[self.faces_idx] + + # find collision idx + with torch.no_grad(): + collision_idxs = self.search_tree(polygons) + + # compute penetration loss + return self.pen_distance(polygons, collision_idxs) * self.weight diff --git a/requirements.txt b/requirements.txt index ad89aa4..ca5b27a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ kiwisolver==1.3.1 lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707326711/work matplotlib==3.3.3 mccabe==0.6.1 +mesh-intersection @ git+https://github.com/gosticks/torch-mesh-isect.git@cd0dcfe1e8845de4e8ab2b241f2d32787debcf85 moviepy==1.0.3 networkx==2.5 numpy==1.19.5 @@ -39,7 +40,7 @@ six @ file:///tmp/build/80754af9/six_1605205335545/work smplx==0.1.26 tensorboardX==2.1 toml @ file:///tmp/build/80754af9/toml_1592853716807/work -torch==1.7.1 +torch==1.7.1+cu110 torch-utils==0.1.2 torchgeometry==0.1.2 torchvision==0.8.2+cu110 diff --git a/train_pose.py b/train_pose.py index 3d15229..404cd9b 100644 --- 
a/train_pose.py +++ b/train_pose.py @@ -1,3 +1,4 @@ +from modules.intersect import IntersectLoss from modules.body_prior import BodyPrior from modules.angle_sum import AngleSumLoss from camera_estimation import TorchCameraEstimate @@ -60,7 +61,8 @@ def train_pose( extra_loss_layers=[], - use_progress_bar=True + use_progress_bar=True, + loss_analysis=True ): if use_progress_bar: print("[pose] starting training") @@ -127,12 +129,15 @@ def train_pose( points = filter_layer(points) # compute loss between 2D joint projection and OpenPose keypoints - loss = loss_layer(points, keypoints) * 100 + loss = loss_layer(points, keypoints) # * 100 # apply extra losses for l in extra_loss_layers: - loss = loss + l(cur_pose, body_joints, points, - keypoints) + cur_loss = l(cur_pose, body_joints, points, + keypoints, pose_layer.cur_out) + if loss_analysis: + print(l.__class__.__name__, ":loss ->", cur_loss) + loss = loss + cur_loss return loss def optim_closure(): @@ -190,7 +195,7 @@ def train_pose( return best_output, loss_history, offscreen_step_output -def get_loss_layers(config, device, dtype): +def get_loss_layers(config, model: smplx.SMPL, device, dtype): """ Utility method to create loss layers based on a config file Args: @@ -227,6 +232,16 @@ def get_loss_layers(config, device, dtype): dtype=dtype, weight=config['pose']['angleLimitLoss']['weight'])) + if config['pose']['intersectLoss']['enabled']: + extra_loss_layers.append(IntersectLoss( + model=model, + device=device, + dtype=dtype, + weight=config['pose']['intersectLoss']['weight'], + sigma=config['pose']['intersectLoss']['sigma'], + max_collisions=config['pose']['intersectLoss']['maxCollisions'] + )) + return extra_loss_layers @@ -263,7 +278,7 @@ def train_pose_with_conf( if renderer is not None: renderer.set_group_pose("body", cam_trans.cpu().numpy()) - loss_layers = get_loss_layers(config, device, dtype) + loss_layers = get_loss_layers(config, model, device, dtype) if print_loss_layers: print(loss_layers)