mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
add intersection loss
This commit is contained in:
parent
787770244c
commit
3fcfc6cd45
21
config.yaml
21
config.yaml
@ -19,26 +19,31 @@ camera:
|
|||||||
patience: 10
|
patience: 10
|
||||||
optimizer: Adam
|
optimizer: Adam
|
||||||
pose:
|
pose:
|
||||||
# device: cuda
|
device: cuda
|
||||||
lr: 0.2
|
lr: 0.01
|
||||||
optimizer: LBFGS # currently supported Adam, LBFGS
|
optimizer: Adam # currently supported Adam, LBFGS
|
||||||
iterations: 30
|
iterations: 100
|
||||||
useCameraIntrinsics: true
|
useCameraIntrinsics: true
|
||||||
bodyMeanLoss:
|
bodyMeanLoss:
|
||||||
enabled: false
|
enabled: false
|
||||||
weight: 0.1
|
weight: 0.1
|
||||||
bodyPrior:
|
bodyPrior:
|
||||||
enabled: true
|
enabled: true
|
||||||
weight: 1.0
|
weight: 0.1
|
||||||
anglePrior:
|
anglePrior:
|
||||||
enabled: true
|
enabled: true
|
||||||
weight: 0.001
|
weight: 0.5
|
||||||
angleLimitLoss:
|
angleLimitLoss:
|
||||||
enabled: true
|
enabled: true
|
||||||
weight: 0.5
|
weight: 0.01
|
||||||
angleSumLoss:
|
angleSumLoss:
|
||||||
enabled: true
|
enabled: true
|
||||||
weight: 0.1
|
weight: 0.01
|
||||||
|
intersectLoss:
|
||||||
|
enabled: true
|
||||||
|
weight: 2.0
|
||||||
|
maxCollisions: 8
|
||||||
|
sigma: 0.5
|
||||||
confWeights:
|
confWeights:
|
||||||
enabled: false
|
enabled: false
|
||||||
vposerPath: "./vposer_v1_0"
|
vposerPath: "./vposer_v1_0"
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class AngleClipper(nn.Module):
|
|||||||
torch.tensor(weight, dtype=dtype).to(device=device)
|
torch.tensor(weight, dtype=dtype).to(device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, pose, joints, points, keypoints):
|
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||||
|
|
||||||
angles = pose[:, self.angle_idx]
|
angles = pose[:, self.angle_idx]
|
||||||
|
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class AnglePriorsLoss(nn.Module):
|
|||||||
torch.tensor(global_weight, dtype=dtype).to(device=device)
|
torch.tensor(global_weight, dtype=dtype).to(device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, pose, joints, points, keypoints):
|
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||||
# compute direction deviation from expected joint rotation directions,
|
# compute direction deviation from expected joint rotation directions,
|
||||||
# e.g. don't rotate the knee joint forwards. Broken knees are not fun.
|
# e.g. don't rotate the knee joint forwards. Broken knees are not fun.
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,6 @@ class AngleSumLoss(nn.Module):
|
|||||||
torch.tensor(weight).to(device=device, dtype=dtype)
|
torch.tensor(weight).to(device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, pose, joints, points, keypoints):
|
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||||
# get relevant angles
|
# get relevant angles
|
||||||
return pose.pow(2).sum() * self.weight
|
return pose.pow(2).sum() * self.weight
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class BodyPrior(nn.Module):
|
|||||||
torch.tensor(weight, dtype=dtype).to(device=device)
|
torch.tensor(weight, dtype=dtype).to(device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, pose, joints, points, keypoints):
|
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||||
# get relevant angles
|
# get relevant angles
|
||||||
return self.latent_pose.pow(
|
return self.latent_pose.pow(
|
||||||
2).sum() * self.weight
|
2).sum() * self.weight
|
||||||
|
|||||||
82
modules/intersect.py
Normal file
82
modules/intersect.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import smplx
|
||||||
|
from model import VPoserModel
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from mesh_intersection.bvh_search_tree import BVH
|
||||||
|
import mesh_intersection.loss as collisions_loss
|
||||||
|
|
||||||
|
|
||||||
|
class IntersectLoss(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: smplx.SMPL,
|
||||||
|
device=torch.device('cpu'),
|
||||||
|
dtype=torch.float32,
|
||||||
|
batch_size=1,
|
||||||
|
weight=1,
|
||||||
|
sigma=0.5,
|
||||||
|
max_collisions=8,
|
||||||
|
point2plane=True
|
||||||
|
):
|
||||||
|
"""Intersections loss layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device ([type], optional): [description]. Defaults to torch.device('cpu').
|
||||||
|
dtype ([type], optional): [description]. Defaults to torch.float32.
|
||||||
|
weight (int, optional): Weight factor of the loss. Defaults to 1.
|
||||||
|
sigma (float, optional): The height of the cone used to calculate the distance field loss. Defaults to 0.5.
|
||||||
|
max_collisions (int, optional): The maximum number of bounding box collisions. Defaults to 8.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(IntersectLoss, self).__init__()
|
||||||
|
|
||||||
|
self.has_parameters = False
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(get_skin=True)
|
||||||
|
verts = output.vertices
|
||||||
|
|
||||||
|
face_tensor = torch.tensor(
|
||||||
|
model.faces.astype(np.int64),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device) \
|
||||||
|
.unsqueeze_(0) \
|
||||||
|
.repeat(
|
||||||
|
[batch_size,
|
||||||
|
1, 1])
|
||||||
|
|
||||||
|
bs, nv = verts.shape[:2]
|
||||||
|
bs, nf = face_tensor.shape[:2]
|
||||||
|
|
||||||
|
faces_idx = face_tensor + \
|
||||||
|
(torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None]
|
||||||
|
|
||||||
|
self.register_buffer("faces_idx", faces_idx)
|
||||||
|
|
||||||
|
# Create the search tree
|
||||||
|
self.search_tree = BVH(max_collisions=max_collisions)
|
||||||
|
|
||||||
|
self.pen_distance = \
|
||||||
|
collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
|
||||||
|
point2plane=point2plane,
|
||||||
|
vectorized=True)
|
||||||
|
|
||||||
|
# create buffer for weights
|
||||||
|
self.register_buffer(
|
||||||
|
"weight",
|
||||||
|
torch.tensor(weight, dtype=dtype).to(device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, pose, joints, points, keypoints, raw_output):
|
||||||
|
verts = raw_output.vertices
|
||||||
|
polygons = verts.view([-1, 3])[self.faces_idx]
|
||||||
|
|
||||||
|
# find collision idx
|
||||||
|
with torch.no_grad():
|
||||||
|
collision_idxs = self.search_tree(polygons)
|
||||||
|
|
||||||
|
# compute penetration loss
|
||||||
|
return self.pen_distance(polygons, collision_idxs) * self.weight
|
||||||
@ -18,6 +18,7 @@ kiwisolver==1.3.1
|
|||||||
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707326711/work
|
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1607707326711/work
|
||||||
matplotlib==3.3.3
|
matplotlib==3.3.3
|
||||||
mccabe==0.6.1
|
mccabe==0.6.1
|
||||||
|
mesh-intersection @ git+git://github.com/gosticks/torch-mesh-isect.git@cd0dcfe1e8845de4e8ab2b241f2d32787debcf85
|
||||||
moviepy==1.0.3
|
moviepy==1.0.3
|
||||||
networkx==2.5
|
networkx==2.5
|
||||||
numpy==1.19.5
|
numpy==1.19.5
|
||||||
@ -39,7 +40,7 @@ six @ file:///tmp/build/80754af9/six_1605205335545/work
|
|||||||
smplx==0.1.26
|
smplx==0.1.26
|
||||||
tensorboardX==2.1
|
tensorboardX==2.1
|
||||||
toml @ file:///tmp/build/80754af9/toml_1592853716807/work
|
toml @ file:///tmp/build/80754af9/toml_1592853716807/work
|
||||||
torch==1.7.1
|
torch==1.7.1+cu110
|
||||||
torch-utils==0.1.2
|
torch-utils==0.1.2
|
||||||
torchgeometry==0.1.2
|
torchgeometry==0.1.2
|
||||||
torchvision==0.8.2+cu110
|
torchvision==0.8.2+cu110
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from modules.intersect import IntersectLoss
|
||||||
from modules.body_prior import BodyPrior
|
from modules.body_prior import BodyPrior
|
||||||
from modules.angle_sum import AngleSumLoss
|
from modules.angle_sum import AngleSumLoss
|
||||||
from camera_estimation import TorchCameraEstimate
|
from camera_estimation import TorchCameraEstimate
|
||||||
@ -60,7 +61,8 @@ def train_pose(
|
|||||||
|
|
||||||
extra_loss_layers=[],
|
extra_loss_layers=[],
|
||||||
|
|
||||||
use_progress_bar=True
|
use_progress_bar=True,
|
||||||
|
loss_analysis=True
|
||||||
):
|
):
|
||||||
if use_progress_bar:
|
if use_progress_bar:
|
||||||
print("[pose] starting training")
|
print("[pose] starting training")
|
||||||
@ -127,12 +129,15 @@ def train_pose(
|
|||||||
points = filter_layer(points)
|
points = filter_layer(points)
|
||||||
|
|
||||||
# compute loss between 2D joint projection and OpenPose keypoints
|
# compute loss between 2D joint projection and OpenPose keypoints
|
||||||
loss = loss_layer(points, keypoints) * 100
|
loss = loss_layer(points, keypoints) # * 100
|
||||||
|
|
||||||
# apply extra losses
|
# apply extra losses
|
||||||
for l in extra_loss_layers:
|
for l in extra_loss_layers:
|
||||||
loss = loss + l(cur_pose, body_joints, points,
|
cur_loss = l(cur_pose, body_joints, points,
|
||||||
keypoints)
|
keypoints, pose_layer.cur_out)
|
||||||
|
if loss_analysis:
|
||||||
|
print(l.__class__.__name__, ":loss ->", cur_loss)
|
||||||
|
loss = loss + cur_loss
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def optim_closure():
|
def optim_closure():
|
||||||
@ -190,7 +195,7 @@ def train_pose(
|
|||||||
return best_output, loss_history, offscreen_step_output
|
return best_output, loss_history, offscreen_step_output
|
||||||
|
|
||||||
|
|
||||||
def get_loss_layers(config, device, dtype):
|
def get_loss_layers(config, model: smplx.SMPL, device, dtype):
|
||||||
""" Utility method to create loss layers based on a config file
|
""" Utility method to create loss layers based on a config file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -227,6 +232,16 @@ def get_loss_layers(config, device, dtype):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
weight=config['pose']['angleLimitLoss']['weight']))
|
weight=config['pose']['angleLimitLoss']['weight']))
|
||||||
|
|
||||||
|
if config['pose']['intersectLoss']['enabled']:
|
||||||
|
extra_loss_layers.append(IntersectLoss(
|
||||||
|
model=model,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
weight=config['pose']['intersectLoss']['weight'],
|
||||||
|
sigma=config['pose']['intersectLoss']['sigma'],
|
||||||
|
max_collisions=config['pose']['intersectLoss']['maxCollisions']
|
||||||
|
))
|
||||||
|
|
||||||
return extra_loss_layers
|
return extra_loss_layers
|
||||||
|
|
||||||
|
|
||||||
@ -263,7 +278,7 @@ def train_pose_with_conf(
|
|||||||
if renderer is not None:
|
if renderer is not None:
|
||||||
renderer.set_group_pose("body", cam_trans.cpu().numpy())
|
renderer.set_group_pose("body", cam_trans.cpu().numpy())
|
||||||
|
|
||||||
loss_layers = get_loss_layers(config, device, dtype)
|
loss_layers = get_loss_layers(config, model, device, dtype)
|
||||||
|
|
||||||
if print_loss_layers:
|
if print_loss_layers:
|
||||||
print(loss_layers)
|
print(loss_layers)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user