mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
add intersection loss
This commit is contained in:
parent
787770244c
commit
3fcfc6cd45
21
config.yaml
21
config.yaml
@ -19,26 +19,31 @@ camera:
|
||||
patience: 10
|
||||
optimizer: Adam
|
||||
pose:
|
||||
# device: cuda
|
||||
lr: 0.2
|
||||
optimizer: LBFGS # currently supported Adam, LBFGS
|
||||
iterations: 30
|
||||
device: cuda
|
||||
lr: 0.01
|
||||
optimizer: Adam # currently supported Adam, LBFGS
|
||||
iterations: 100
|
||||
useCameraIntrinsics: true
|
||||
bodyMeanLoss:
|
||||
enabled: false
|
||||
weight: 0.1
|
||||
bodyPrior:
|
||||
enabled: true
|
||||
weight: 1.0
|
||||
weight: 0.1
|
||||
anglePrior:
|
||||
enabled: true
|
||||
weight: 0.001
|
||||
weight: 0.5
|
||||
angleLimitLoss:
|
||||
enabled: true
|
||||
weight: 0.5
|
||||
weight: 0.01
|
||||
angleSumLoss:
|
||||
enabled: true
|
||||
weight: 0.1
|
||||
weight: 0.01
|
||||
intersectLoss:
|
||||
enabled: true
|
||||
weight: 2.0
|
||||
maxCollisions: 8
|
||||
sigma: 0.5
|
||||
confWeights:
|
||||
enabled: false
|
||||
vposerPath: "./vposer_v1_0"
|
||||
|
||||
@ -34,7 +34,7 @@ class AngleClipper(nn.Module):
|
||||
torch.tensor(weight, dtype=dtype).to(device=device)
|
||||
)
|
||||
|
||||
def forward(self, pose, joints, points, keypoints):
|
||||
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||
|
||||
angles = pose[:, self.angle_idx]
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ class AnglePriorsLoss(nn.Module):
|
||||
torch.tensor(global_weight, dtype=dtype).to(device=device)
|
||||
)
|
||||
|
||||
def forward(self, pose, joints, points, keypoints):
|
||||
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||
# compute direction deviation from expected joint rotation directions,
|
||||
# e.g. don't rotate the knee joint forwards. Broken knees are not fun.
|
||||
|
||||
|
||||
@ -21,6 +21,6 @@ class AngleSumLoss(nn.Module):
|
||||
torch.tensor(weight).to(device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(self, pose, joints, points, keypoints):
|
||||
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||
# get relevant angles
|
||||
return pose.pow(2).sum() * self.weight
|
||||
|
||||
@ -27,7 +27,7 @@ class BodyPrior(nn.Module):
|
||||
torch.tensor(weight, dtype=dtype).to(device=device)
|
||||
)
|
||||
|
||||
def forward(self, pose, joints, points, keypoints):
|
||||
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||
# get relevant angles
|
||||
return self.latent_pose.pow(
|
||||
2).sum() * self.weight
|
||||
|
||||
82
modules/intersect.py
Normal file
82
modules/intersect.py
Normal file
@ -0,0 +1,82 @@
|
||||
import smplx
|
||||
from model import VPoserModel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
from mesh_intersection.bvh_search_tree import BVH
|
||||
import mesh_intersection.loss as collisions_loss
|
||||
|
||||
|
||||
class IntersectLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model: smplx.SMPL,
|
||||
device=torch.device('cpu'),
|
||||
dtype=torch.float32,
|
||||
batch_size=1,
|
||||
weight=1,
|
||||
sigma=0.5,
|
||||
max_collisions=8,
|
||||
point2plane=True
|
||||
):
|
||||
"""Intersections loss layer.
|
||||
|
||||
Args:
|
||||
device ([type], optional): [description]. Defaults to torch.device('cpu').
|
||||
dtype ([type], optional): [description]. Defaults to torch.float32.
|
||||
weight (int, optional): Weight factor of the loss. Defaults to 1.
|
||||
sigma (float, optional): The height of the cone used to calculate the distance field loss. Defaults to 0.5.
|
||||
max_collisions (int, optional): The maximum number of bounding box collisions. Defaults to 8.
|
||||
"""
|
||||
|
||||
super(IntersectLoss, self).__init__()
|
||||
|
||||
self.has_parameters = False
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(get_skin=True)
|
||||
verts = output.vertices
|
||||
|
||||
face_tensor = torch.tensor(
|
||||
model.faces.astype(np.int64),
|
||||
dtype=torch.long,
|
||||
device=device) \
|
||||
.unsqueeze_(0) \
|
||||
.repeat(
|
||||
[batch_size,
|
||||
1, 1])
|
||||
|
||||
bs, nv = verts.shape[:2]
|
||||
bs, nf = face_tensor.shape[:2]
|
||||
|
||||
faces_idx = face_tensor + \
|
||||
(torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None]
|
||||
|
||||
self.register_buffer("faces_idx", faces_idx)
|
||||
|
||||
# Create the search tree
|
||||
self.search_tree = BVH(max_collisions=max_collisions)
|
||||
|
||||
self.pen_distance = \
|
||||
collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
|
||||
point2plane=point2plane,
|
||||
vectorized=True)
|
||||
|
||||
# create buffer for weights
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.tensor(weight, dtype=dtype).to(device=device)
|
||||
)
|
||||
|
||||
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||
verts = raw_output.vertices
|
||||
polygons = verts.view([-1, 3])[self.faces_idx]
|
||||
|
||||
# find collision idx
|
||||
with torch.no_grad():
|
||||
collision_idxs = self.search_tree(polygons)
|
||||
|
||||
# compute penetration loss
|
||||
return self.pen_distance(polygons, collision_idxs) * self.weight
|
||||
@ -18,6 +18,7 @@ kiwisolver==1.3.1
|
||||
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707326711/work
|
||||
matplotlib==3.3.3
|
||||
mccabe==0.6.1
|
||||
mesh-intersection @ git+git://github.com/gosticks/torch-mesh-isect.git@cd0dcfe1e8845de4e8ab2b241f2d32787debcf85
|
||||
moviepy==1.0.3
|
||||
networkx==2.5
|
||||
numpy==1.19.5
|
||||
@ -39,7 +40,7 @@ six @ file:///tmp/build/80754af9/six_1605205335545/work
|
||||
smplx==0.1.26
|
||||
tensorboardX==2.1
|
||||
toml @ file:///tmp/build/80754af9/toml_1592853716807/work
|
||||
torch==1.7.1
|
||||
torch==1.7.1+cu110
|
||||
torch-utils==0.1.2
|
||||
torchgeometry==0.1.2
|
||||
torchvision==0.8.2+cu110
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from modules.intersect import IntersectLoss
|
||||
from modules.body_prior import BodyPrior
|
||||
from modules.angle_sum import AngleSumLoss
|
||||
from camera_estimation import TorchCameraEstimate
|
||||
@ -60,7 +61,8 @@ def train_pose(
|
||||
|
||||
extra_loss_layers=[],
|
||||
|
||||
use_progress_bar=True
|
||||
use_progress_bar=True,
|
||||
loss_analysis=True
|
||||
):
|
||||
if use_progress_bar:
|
||||
print("[pose] starting training")
|
||||
@ -127,12 +129,15 @@ def train_pose(
|
||||
points = filter_layer(points)
|
||||
|
||||
# compute loss between 2D joint projection and OpenPose keypoints
|
||||
loss = loss_layer(points, keypoints) * 100
|
||||
loss = loss_layer(points, keypoints) # * 100
|
||||
|
||||
# apply extra losses
|
||||
for l in extra_loss_layers:
|
||||
loss = loss + l(cur_pose, body_joints, points,
|
||||
keypoints)
|
||||
cur_loss = l(cur_pose, body_joints, points,
|
||||
keypoints, pose_layer.cur_out)
|
||||
if loss_analysis:
|
||||
print(l.__class__.__name__, ":loss ->", cur_loss)
|
||||
loss = loss + cur_loss
|
||||
return loss
|
||||
|
||||
def optim_closure():
|
||||
@ -190,7 +195,7 @@ def train_pose(
|
||||
return best_output, loss_history, offscreen_step_output
|
||||
|
||||
|
||||
def get_loss_layers(config, device, dtype):
|
||||
def get_loss_layers(config, model: smplx.SMPL, device, dtype):
|
||||
""" Utility method to create loss layers based on a config file
|
||||
|
||||
Args:
|
||||
@ -227,6 +232,16 @@ def get_loss_layers(config, device, dtype):
|
||||
dtype=dtype,
|
||||
weight=config['pose']['angleLimitLoss']['weight']))
|
||||
|
||||
if config['pose']['intersectLoss']['enabled']:
|
||||
extra_loss_layers.append(IntersectLoss(
|
||||
model=model,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
weight=config['pose']['intersectLoss']['weight'],
|
||||
sigma=config['pose']['intersectLoss']['sigma'],
|
||||
max_collisions=config['pose']['intersectLoss']['maxCollisions']
|
||||
))
|
||||
|
||||
return extra_loss_layers
|
||||
|
||||
|
||||
@ -263,7 +278,7 @@ def train_pose_with_conf(
|
||||
if renderer is not None:
|
||||
renderer.set_group_pose("body", cam_trans.cpu().numpy())
|
||||
|
||||
loss_layers = get_loss_layers(config, device, dtype)
|
||||
loss_layers = get_loss_layers(config, model, device, dtype)
|
||||
|
||||
if print_loss_layers:
|
||||
print(loss_layers)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user