add intersection loss

This commit is contained in:
Wlad 2021-02-17 13:18:15 +01:00
parent 787770244c
commit 3fcfc6cd45
8 changed files with 122 additions and 19 deletions

View File

@ -19,26 +19,31 @@ camera:
patience: 10
optimizer: Adam
pose:
# device: cuda
lr: 0.2
optimizer: LBFGS # currently supported Adam, LBFGS
iterations: 30
device: cuda
lr: 0.01
optimizer: Adam # currently supported Adam, LBFGS
iterations: 100
useCameraIntrinsics: true
bodyMeanLoss:
enabled: false
weight: 0.1
bodyPrior:
enabled: true
weight: 1.0
weight: 0.1
anglePrior:
enabled: true
weight: 0.001
weight: 0.5
angleLimitLoss:
enabled: true
weight: 0.5
weight: 0.01
angleSumLoss:
enabled: true
weight: 0.1
weight: 0.01
intersectLoss:
enabled: true
weight: 2.0
maxCollisions: 8
sigma: 0.5
confWeights:
enabled: false
vposerPath: "./vposer_v1_0"

View File

@ -34,7 +34,7 @@ class AngleClipper(nn.Module):
torch.tensor(weight, dtype=dtype).to(device=device)
)
def forward(self, pose, joints, points, keypoints):
def forward(self, pose, joints, points, keypoints, raw_output):
angles = pose[:, self.angle_idx]

View File

@ -40,7 +40,7 @@ class AnglePriorsLoss(nn.Module):
torch.tensor(global_weight, dtype=dtype).to(device=device)
)
def forward(self, pose, joints, points, keypoints):
def forward(self, pose, joints, points, keypoints, raw_output):
# compute direction deviation from expected joint rotation directions,
# e.g. don't rotate the knee joint forwards. Broken knees are not fun.

View File

@ -21,6 +21,6 @@ class AngleSumLoss(nn.Module):
torch.tensor(weight).to(device=device, dtype=dtype)
)
def forward(self, pose, joints, points, keypoints):
def forward(self, pose, joints, points, keypoints, raw_output):
# get relevant angles
return pose.pow(2).sum() * self.weight

View File

@ -27,7 +27,7 @@ class BodyPrior(nn.Module):
torch.tensor(weight, dtype=dtype).to(device=device)
)
def forward(self, pose, joints, points, keypoints):
def forward(self, pose, joints, points, keypoints, raw_output):
# get relevant angles
return self.latent_pose.pow(
2).sum() * self.weight

82
modules/intersect.py Normal file
View File

@ -0,0 +1,82 @@
import smplx
from model import VPoserModel
import torch
import torch.nn as nn
import numpy as np
from mesh_intersection.bvh_search_tree import BVH
import mesh_intersection.loss as collisions_loss
class IntersectLoss(nn.Module):
def __init__(
self,
model: smplx.SMPL,
device=torch.device('cpu'),
dtype=torch.float32,
batch_size=1,
weight=1,
sigma=0.5,
max_collisions=8,
point2plane=True
):
"""Intersections loss layer.
Args:
device ([type], optional): [description]. Defaults to torch.device('cpu').
dtype ([type], optional): [description]. Defaults to torch.float32.
weight (int, optional): Weight factor of the loss. Defaults to 1.
sigma (float, optional): The height of the cone used to calculate the distance field loss. Defaults to 0.5.
max_collisions (int, optional): The maximum number of bounding box collisions. Defaults to 8.
"""
super(IntersectLoss, self).__init__()
self.has_parameters = False
with torch.no_grad():
output = model(get_skin=True)
verts = output.vertices
face_tensor = torch.tensor(
model.faces.astype(np.int64),
dtype=torch.long,
device=device) \
.unsqueeze_(0) \
.repeat(
[batch_size,
1, 1])
bs, nv = verts.shape[:2]
bs, nf = face_tensor.shape[:2]
faces_idx = face_tensor + \
(torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None]
self.register_buffer("faces_idx", faces_idx)
# Create the search tree
self.search_tree = BVH(max_collisions=max_collisions)
self.pen_distance = \
collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
point2plane=point2plane,
vectorized=True)
# create buffer for weights
self.register_buffer(
"weight",
torch.tensor(weight, dtype=dtype).to(device=device)
)
def forward(self, pose, joints, points, keypoints, raw_output):
verts = raw_output.vertices
polygons = verts.view([-1, 3])[self.faces_idx]
# find collision idx
with torch.no_grad():
collision_idxs = self.search_tree(polygons)
# compute penetration loss
return self.pen_distance(polygons, collision_idxs) * self.weight

View File

@ -18,6 +18,7 @@ kiwisolver==1.3.1
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707326711/work
matplotlib==3.3.3
mccabe==0.6.1
mesh-intersection @ git+git://github.com/gosticks/torch-mesh-isect.git@cd0dcfe1e8845de4e8ab2b241f2d32787debcf85
moviepy==1.0.3
networkx==2.5
numpy==1.19.5
@ -39,7 +40,7 @@ six @ file:///tmp/build/80754af9/six_1605205335545/work
smplx==0.1.26
tensorboardX==2.1
toml @ file:///tmp/build/80754af9/toml_1592853716807/work
torch==1.7.1
torch==1.7.1+cu110
torch-utils==0.1.2
torchgeometry==0.1.2
torchvision==0.8.2+cu110

View File

@ -1,3 +1,4 @@
from modules.intersect import IntersectLoss
from modules.body_prior import BodyPrior
from modules.angle_sum import AngleSumLoss
from camera_estimation import TorchCameraEstimate
@ -60,7 +61,8 @@ def train_pose(
extra_loss_layers=[],
use_progress_bar=True
use_progress_bar=True,
loss_analysis=True
):
if use_progress_bar:
print("[pose] starting training")
@ -127,12 +129,15 @@ def train_pose(
points = filter_layer(points)
# compute loss between 2D joint projection and OpenPose keypoints
loss = loss_layer(points, keypoints) * 100
loss = loss_layer(points, keypoints) # * 100
# apply extra losses
for l in extra_loss_layers:
loss = loss + l(cur_pose, body_joints, points,
keypoints)
cur_loss = l(cur_pose, body_joints, points,
keypoints, pose_layer.cur_out)
if loss_analysis:
print(l.__class__.__name__, ":loss ->", cur_loss)
loss = loss + cur_loss
return loss
def optim_closure():
@ -190,7 +195,7 @@ def train_pose(
return best_output, loss_history, offscreen_step_output
def get_loss_layers(config, device, dtype):
def get_loss_layers(config, model: smplx.SMPL, device, dtype):
""" Utility method to create loss layers based on a config file
Args:
@ -227,6 +232,16 @@ def get_loss_layers(config, device, dtype):
dtype=dtype,
weight=config['pose']['angleLimitLoss']['weight']))
if config['pose']['intersectLoss']['enabled']:
extra_loss_layers.append(IntersectLoss(
model=model,
device=device,
dtype=dtype,
weight=config['pose']['intersectLoss']['weight'],
sigma=config['pose']['intersectLoss']['sigma'],
max_collisions=config['pose']['intersectLoss']['maxCollisions']
))
return extra_loss_layers
@ -263,7 +278,7 @@ def train_pose_with_conf(
if renderer is not None:
renderer.set_group_pose("body", cam_trans.cpu().numpy())
loss_layers = get_loss_layers(config, device, dtype)
loss_layers = get_loss_layers(config, model, device, dtype)
if print_loss_layers:
print(loss_layers)