torch-mesh-isect/examples/batch_smpl_untangle.py
2019-06-07 10:18:19 +02:00

363 lines
14 KiB
Python

# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems and the Max Planck Institute for Biological
# Cybernetics. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import sys
import os
import time
import pickle
import argparse
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from mesh_intersection.filter_faces import FilterFaces
from mesh_intersection.bvh_search_tree import BVH
import mesh_intersection.loss as collisions_loss
from smplx import create
def main():
    """Run the batched SMPL self-intersection untangling example.

    The function parses the command line, loads one pickle file of model
    parameters per batch element, creates the body model on the CUDA device
    and then runs SGD on ``body.body_pose`` to minimise a weighted
    penetration loss, optionally combined with pose and shape regularisers.

    With ``--interactive`` enabled the meshes are shown in a pyrender viewer
    and refreshed after every iteration; the optimisation stops as soon as
    the viewer is no longer active.

    NOTE: without ``--interactive`` the loop has no stopping criterion and
    runs until the process is interrupted.

    Raises:
        SystemExit: raised by argparse on invalid command line arguments.
        AssertionError: if a parameter file lacks the 'betas', 'global_pose'
            or 'pose' key.
    """

    def str2bool(arg):
        """Interpret 'true' or '1' (case-insensitive) as True."""
        return arg.lower() in ['true', '1']

    description = 'Example script for untangling SMPL self intersections'
    parser = argparse.ArgumentParser(description=description,
                                     prog='Batch SMPL-Untangle')
    parser.add_argument('--param_fn', type=str,
                        nargs='*',
                        required=True,
                        help='The pickle file with the model parameters')
    parser.add_argument('--interactive', default=True,
                        type=str2bool,
                        help='Display the mesh during the optimization' +
                        ' process')
    parser.add_argument('--delay', type=int, default=50,
                        help='The delay for the animation callback in ms')
    parser.add_argument('--model_folder', type=str,
                        default='models',
                        help='The path to the LBS model')
    parser.add_argument('--model_type', type=str,
                        default='smpl', choices=['smpl', 'smplx', 'smplh'],
                        help='The type of model to create')
    parser.add_argument('--point2plane', default=False,
                        type=str2bool,
                        help='Use point to distance')
    # NOTE: --optimize_pose and --optimize_shape are accepted for
    # compatibility but are currently ignored: the optimizer below always
    # updates the body pose only.
    parser.add_argument('--optimize_pose', default=True,
                        type=str2bool,
                        help='Enable optimization over the joint pose')
    parser.add_argument('--optimize_shape', default=False,
                        type=str2bool,
                        help='Enable optimization over the shape of the model')
    parser.add_argument('--sigma', default=0.5, type=float,
                        help='The height of the cone used to calculate the' +
                        ' distance field loss')
    parser.add_argument('--lr', default=1, type=float,
                        help='The learning rate for SGD')
    parser.add_argument('--coll_loss_weight', default=1e-4, type=float,
                        help='The weight for the collision loss')
    parser.add_argument('--pose_reg_weight', default=0, type=float,
                        help='The weight for the pose regularizer')
    parser.add_argument('--shape_reg_weight', default=0, type=float,
                        help='The weight for the shape regularizer')
    parser.add_argument('--max_collisions', default=8, type=int,
                        help='The maximum number of bounding box collisions')
    parser.add_argument('--part_segm_fn', default='', type=str,
                        help='The file with the part segmentation for the' +
                        ' faces of the model')
    parser.add_argument('--print_timings', default=False,
                        type=str2bool,
                        help='Print timings for all the operations')

    args = parser.parse_args()

    model_folder = args.model_folder
    model_type = args.model_type
    param_fn = args.param_fn
    interactive = args.interactive
    delay = args.delay
    point2plane = args.point2plane
    lr = args.lr
    coll_loss_weight = args.coll_loss_weight
    pose_reg_weight = args.pose_reg_weight
    shape_reg_weight = args.shape_reg_weight
    max_collisions = args.max_collisions
    sigma = args.sigma
    part_segm_fn = args.part_segm_fn
    print_timings = args.print_timings

    if interactive:
        # The visualisation packages are only needed for the viewer, so
        # they are imported lazily.
        import trimesh
        import pyrender

    device = torch.device('cuda')
    # One parameter file per element of the batch
    batch_size = len(param_fn)

    # Gather the values of every key across all the parameter files
    params_dict = defaultdict(list)
    for fn in param_fn:
        # SECURITY: pickle.load can execute arbitrary code; only pass
        # parameter files that come from a trusted source.
        with open(fn, 'rb') as param_file:
            data = pickle.load(param_file, encoding='latin1')

        assert 'betas' in data, \
            'No key for shape parameter in provided pickle file'
        assert 'global_pose' in data, \
            'No key for the global pose in the given pickle file'
        assert 'pose' in data, \
            'No key for the pose of the joints in the given pickle file'

        for key, val in data.items():
            params_dict[key].append(val)

    # Stack the per-file values along a new leading (batch) axis
    params = {}
    for key in params_dict:
        params[key] = np.stack(params_dict[key], axis=0).astype(np.float32)
        if len(params[key].shape) < 2:
            params[key] = params[key][np.newaxis]
    # Expose the pickled pose entries under the names used for the body
    # model below
    if 'global_pose' in params:
        params['global_orient'] = params['global_pose']
    if 'pose' in params:
        params['body_pose'] = params['pose']

    if part_segm_fn:
        # Read the part segmentation
        # SECURITY: same caveat as above, the file is unpickled as is.
        with open(part_segm_fn, 'rb') as faces_parents_file:
            data = pickle.load(faces_parents_file, encoding='latin1')
        faces_segm = data['segm']
        faces_parents = data['parents']
        # Create the module used to filter invalid collision pairs
        filter_faces = FilterFaces(faces_segm, faces_parents).to(device=device)

    # Create the body model
    body = create(model_folder, batch_size=batch_size,
                  model_type=model_type).to(device=device)
    body.reset_params(**params)

    # Clone the given pose to use it as a target for regularization
    init_pose = body.body_pose.clone().detach()

    # Create the search tree
    search_tree = BVH(max_collisions=max_collisions)

    pen_distance = \
        collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma,
                                                     point2plane=point2plane,
                                                     vectorized=True)

    mse_loss = nn.MSELoss(reduction='sum').to(device=device)

    # Replicate the faces of the model once per batch element
    face_tensor = torch.tensor(body.faces.astype(np.int64), dtype=torch.long,
                               device=device).unsqueeze_(0).repeat([batch_size,
                                                                    1, 1])
    with torch.no_grad():
        output = body(get_skin=True)
        verts = output.vertices

    bs, nv = verts.shape[:2]
    # Shift the vertex indices of model i by i * nv so that the faces can
    # index the vertices of the whole batch flattened to a (bs * nv, 3) tensor
    faces_idx = face_tensor + \
        (torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None]

    optimizer = torch.optim.SGD([body.body_pose], lr=lr)

    if interactive:
        # Plot the initial mesh
        with torch.no_grad():
            output = body(get_skin=True)
            verts = output.vertices

            np_verts = verts.detach().cpu().numpy()

        def create_mesh(vertices, faces, color=(0.3, 0.3, 0.3, 1.0),
                        wireframe=False):
            """Build a pyrender mesh from the given vertices and faces.

            The mesh is rotated by 180 degrees around the x axis and gets a
            non-metallic, alpha-blended material with the requested color.
            NOTE: the ``wireframe`` argument is accepted but not used.
            """
            tri_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
            rot = trimesh.transformations.rotation_matrix(np.radians(180),
                                                          [1, 0, 0])
            tri_mesh.apply_transform(rot)

            material = pyrender.MetallicRoughnessMaterial(
                metallicFactor=0.0,
                alphaMode='BLEND',
                baseColorFactor=color)
            return pyrender.Mesh.from_trimesh(
                tri_mesh,
                material=material)

        def model_pose(bidx):
            """Return a 4x4 pose that shifts model ``bidx`` along x."""
            pose = np.eye(4)
            # Place the models side by side, two units apart
            pose[0, 3] = bidx * 2
            return pose

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0],
                               ambient_light=(1.0, 1.0, 1.0))
        for bidx in range(np_verts.shape[0]):
            curr_verts = np_verts[bidx].copy()
            body_mesh = create_mesh(curr_verts, body.faces,
                                    color=(0.3, 0.3, 0.3, 0.99),
                                    wireframe=True)
            scene.add(body_mesh,
                      name='body_mesh_{:03d}'.format(bidx),
                      pose=model_pose(bidx))

        viewer = pyrender.Viewer(scene, use_raymond_lighting=True,
                                 viewport_size=(1200, 800),
                                 cull_faces=False,
                                 run_in_thread=True)

    # Name fragments of the scene nodes that are rebuilt on every iteration
    query_names = ['recv_mesh', 'intr_mesh', 'body_mesh']

    while True:
        optimizer.zero_grad()

        if print_timings:
            start = time.time()

        if print_timings:
            torch.cuda.synchronize()
        output = body(get_skin=True)
        verts = output.vertices

        if print_timings:
            torch.cuda.synchronize()
            print('Body model forward: {:5f}'.format(time.time() - start))

        if print_timings:
            torch.cuda.synchronize()
            start = time.time()
        # Gather the three vertices of every face of every model
        triangles = verts.view([-1, 3])[faces_idx]
        if print_timings:
            torch.cuda.synchronize()
            print('Triangle indexing: {:5f}'.format(time.time() - start))

        # The collision search itself is not differentiated
        with torch.no_grad():
            if print_timings:
                start = time.time()
            collision_idxs = search_tree(triangles)
            if print_timings:
                torch.cuda.synchronize()
                print('Collision Detection: {:5f}'.format(time.time() -
                                                          start))
            if part_segm_fn:
                if print_timings:
                    start = time.time()
                collision_idxs = filter_faces(collision_idxs)
                if print_timings:
                    torch.cuda.synchronize()
                    print('Collision filtering: {:5f}'.format(time.time() -
                                                              start))

        if print_timings:
            start = time.time()
        pen_loss = coll_loss_weight * \
            pen_distance(triangles, collision_idxs)
        if print_timings:
            torch.cuda.synchronize()
            print('Penetration loss: {:5f}'.format(time.time() - start))

        shape_reg_loss = torch.tensor(0, device=device,
                                      dtype=torch.float32)
        if shape_reg_weight > 0:
            shape_reg_loss = shape_reg_weight * torch.sum(body.betas ** 2)
        pose_reg_loss = torch.tensor(0, device=device,
                                     dtype=torch.float32)
        if pose_reg_weight > 0:
            # Penalise the deviation from the initial pose. ``init_pose`` was
            # cloned from ``body.body_pose``, so the same parameter has to be
            # compared here.
            pose_reg_loss = pose_reg_weight * \
                mse_loss(body.body_pose, init_pose)

        loss = pen_loss + pose_reg_loss + shape_reg_loss

        np_loss = loss.detach().cpu().squeeze().tolist()
        # tolist() yields a bare number when the loss has a single element
        if not isinstance(np_loss, list):
            np_loss = [np_loss]
        msg = '{:.5f} ' * len(np_loss)
        print('Loss per model:', msg.format(*np_loss))

        if print_timings:
            start = time.time()
        # The loss may hold one value per model, hence the explicit gradient
        loss.backward(torch.ones_like(loss))
        if print_timings:
            torch.cuda.synchronize()
            print('Backward pass: {:5f}'.format(time.time() - start))

        if interactive:
            with torch.no_grad():
                output = body(get_skin=True)
                verts = output.vertices

                np_verts = verts.detach().cpu().numpy()

            np_collision_idxs = collision_idxs.detach().cpu().numpy()
            # The last axis holds the (receiver, intruder) face index pair
            np_receivers = np_collision_idxs[:, :, 0]
            np_intruders = np_collision_idxs[:, :, 1]

            # Hold the render lock while the scene is modified; the context
            # manager releases it even if updating the scene fails.
            with viewer.render_lock:
                # Drop the meshes that were added in the previous iteration
                for node in scene.get_nodes():
                    if node.name is None:
                        continue
                    if any(query in node.name for query in query_names):
                        scene.remove_node(node)

                for bidx in range(batch_size):
                    # Only non-negative entries are used as face indices
                    recv_faces_idxs = \
                        np_receivers[bidx][np_receivers[bidx] >= 0]
                    intr_faces_idxs = \
                        np_intruders[bidx][np_intruders[bidx] >= 0]
                    recv_faces = body.faces[recv_faces_idxs]
                    intr_faces = body.faces[intr_faces_idxs]

                    curr_verts = np_verts[bidx].copy()
                    body_mesh = create_mesh(curr_verts, body.faces,
                                            color=(0.3, 0.3, 0.3, 0.99),
                                            wireframe=True)

                    pose = model_pose(bidx)

                    scene.add(body_mesh,
                              name='body_mesh_{:03d}'.format(bidx),
                              pose=pose)

                    # Intruding faces are drawn in red, receiving ones in
                    # green
                    if len(intr_faces) > 0:
                        intr_mesh = create_mesh(curr_verts, intr_faces,
                                                color=(0.9, 0.0, 0.0, 1.0))
                        scene.add(intr_mesh,
                                  name='intr_mesh_{:03d}'.format(bidx),
                                  pose=pose)

                    if len(recv_faces) > 0:
                        recv_mesh = create_mesh(curr_verts, recv_faces,
                                                color=(0.0, 0.9, 0.0, 1.0))
                        scene.add(recv_mesh,
                                  name='recv_mesh_{:03d}'.format(bidx),
                                  pose=pose)

            # Stop optimising once the viewer is no longer active
            if not viewer.is_active:
                break

            time.sleep(delay / 1000)

        optimizer.step()
# Script entry point: run the example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()