mirror of
https://github.com/gosticks/body-pose-animation.git
synced 2025-10-16 11:45:42 +00:00
34 lines
887 B
Python
34 lines
887 B
Python
from model import VPoserModel
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
|
|
class BodyPrior(nn.Module):
    """Loss term penalizing the squared L2 norm of a VPoser latent pose code.

    The latent code obtained from ``vmodel.get_vposer_latent()`` is registered
    as a parameter of this module, and ``forward`` returns
    ``weight * sum(latent_pose ** 2)``.
    """

    def __init__(
        self,
        vmodel: "VPoserModel",
        device=torch.device('cpu'),
        dtype=torch.float32,
        weight=1
    ):
        """Register ``vmodel``'s latent pose code as a parameter of this term.

        Args:
            vmodel: Object exposing a ``model`` attribute (an ``nn.Module``)
                and a ``get_vposer_latent()`` method returning the latent
                pose tensor.
            device: Device the wrapped model and the weight buffer are
                placed on.
            dtype: dtype for the wrapped model and the weight buffer.
            weight: Scalar multiplier applied to the squared-norm penalty.
        """
        super().__init__()

        # Flag read by callers; presumably marks this loss term as owning
        # trainable parameters -- verify against the caller that checks it.
        self.has_parameters = True

        # nn.Module.to() converts in place, so this also moves vmodel.model.
        # NOTE(review): registering it as a submodule also exposes its
        # weights through self.parameters(); confirm that optimizers built
        # from this module are meant to see them.
        self.model = vmodel.model.to(device=device, dtype=dtype)

        latent_pose = vmodel.get_vposer_latent()
        # register_parameter() raises TypeError for anything that is not an
        # nn.Parameter (or None), so wrap a plain tensor instead of failing.
        if latent_pose is not None and not isinstance(latent_pose, nn.Parameter):
            latent_pose = nn.Parameter(latent_pose)
        # NOTE(review): the latent is not moved to `device`/`dtype` here;
        # confirm vmodel already created it there.
        self.register_parameter("latent_pose", latent_pose)

        # Stored as a buffer so it follows the module across .to() calls and
        # is included in state_dict(); built directly on the target device.
        self.register_buffer(
            "weight",
            torch.tensor(weight, dtype=dtype, device=device)
        )

    def forward(self, pose, joints, points, keypoints, raw_output):
        """Return ``weight * sum(latent_pose ** 2)`` as a scalar tensor.

        All arguments are ignored; presumably they exist so this term shares
        a call signature with other loss terms -- verify against callers.
        """
        # Squared L2 norm of the latent code, scaled by the weight buffer.
        return self.latent_pose.pow(2).sum() * self.weight
|