diff --git a/model.py b/model.py index e2247ff..afb9df0 100644 --- a/model.py +++ b/model.py @@ -1,7 +1,9 @@ import matplotlib.pyplot as plt +import torch import numpy as np import smplx from human_body_prior.body_model.body_model_vposer import BodyModelWithPoser +from human_body_prior.tools.model_loader import load_vposer class VPoserModel(): @@ -17,7 +19,8 @@ class VPoserModel(): sample_shape=False, sample_expression=False, num_expression_coeffs=10, - use_face_contour=False + use_face_contour=False, + use_vposer=True ): self.vposer_model_path = vposer_model_path self.model_type = model_type @@ -29,10 +32,12 @@ class VPoserModel(): self.sample_expression = sample_expression self.num_expression_coeffs = num_expression_coeffs self.create_body_pose = create_body_pose + self.use_vposer = use_vposer self.create_model() def create_model(self): + self.model = BodyModelWithPoser( bm_path="./models/smplx/SMPLX_MALE.npz", batch_size=1, @@ -62,7 +67,9 @@ class SMPLyModel(): sample_expression=True, num_expression_coeffs=10, plotting_module='pyrender', - use_face_contour=False + use_face_contour=False, + use_vposer_init=False, + device=torch.device('cpu') ): self.model_folder = model_folder self.model_type = model_type @@ -76,12 +83,28 @@ class SMPLyModel(): self.plotting_module = plotting_module self.use_face_contour = use_face_contour self.create_body_pose = create_body_pose + self.use_vposer_init = use_vposer_init + self.device = device def create_model(self): + initial_pose = None + if self.use_vposer_init: + # sample a valid human shape via vposer + vp, ps = load_vposer("./vposer_v1_0") + vp = vp.to(device=self.device) + self.vp = vp + self.vposer_sample = torch.from_numpy( + np.random.randn(1, 32).astype(np.float32)).to(device=self.device) + + # sample SMPL body pose from vposer + initial_pose = self.vp.decode( + self.vposer_sample, output_type='aa').view(-1, 63) + self.model = smplx.create( self.model_folder, model_type=self.model_type, gender=self.gender, + body_pose=initial_pose, use_face_contour=self.use_face_contour, create_body_pose=self.create_body_pose, num_betas=self.num_betas,