-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SUPR model support and load_STAR example
- Loading branch information
1 parent
ebd1807
commit d872817
Showing
4 changed files
with
327 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
""" | ||
Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
""" | ||
import collections | ||
import os | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
try: | ||
from supr.config import cfg | ||
from supr.pytorch.supr import SUPR | ||
except Exception as e: | ||
raise ImportError(f"Cannot import SUPR. Please run `pip install git+https://github.com/ahmedosman/SUPR.git`\n{e}") | ||
|
||
from aitviewer.configuration import CONFIG as C | ||
from aitviewer.utils.so3 import aa2rot_torch as aa2rot | ||
from aitviewer.utils.so3 import rot2aa_torch as rot2aa | ||
|
||
|
||
class SUPRLayer(SUPR): | ||
"""Wraps the publicly available SUPR model to match SMPLX model interface""" | ||
|
||
def __init__(self, gender="male", num_betas=10, constrained=False, device=None, dtype=None): | ||
""" | ||
Initializer. | ||
:param gender: Which gender to load. | ||
:param num_betas: Number of shape components. | ||
:param device: CPU or GPU. | ||
:param dtype: The pytorch floating point data type. | ||
""" | ||
# Configure SUPR model before initializing | ||
cfg.data_type = dtype if dtype is not None else C.f_precision | ||
path_model = os.path.join(C.supr_models, f'supr_{gender}{"_constrained" if constrained else ""}.npy') | ||
super(SUPRLayer, self).__init__(path_model, num_betas=num_betas) | ||
|
||
self.device = device if device is not None else C.device | ||
self.model_type = "supr" | ||
self._parents = None | ||
self._children = None | ||
|
||
@property | ||
def parents(self): | ||
"""Return how the joints are connected in the kinematic chain where parents[i, 0] is the parent of | ||
joint parents[i, 1].""" | ||
if self._parents is None: | ||
self._parents = self.kintree_table.transpose(0, 1).cpu().numpy() | ||
return self._parents | ||
|
||
@property | ||
def joint_children(self): | ||
"""Return the children of each joint in the kinematic chain.""" | ||
if self._children is None: | ||
self._children = collections.defaultdict(list) | ||
for bone in self.parents: | ||
if bone[0] != -1: | ||
self._children[bone[0]].append(bone[1]) | ||
return self._children | ||
|
||
def skeletons(self): | ||
"""Return how the joints are connected in the kinematic chain where skeleton[0, i] is the parent of | ||
joint skeleton[1, i].""" | ||
kintree_table = self.kintree_table | ||
kintree_table[:, 0] = -1 | ||
return { | ||
"all": kintree_table, | ||
"body": kintree_table[:, : self.n_joints_body + 1], | ||
} | ||
|
||
@property | ||
def n_joints_body(self): | ||
return self.parent.shape[0] | ||
|
||
@property | ||
def n_joints_total(self): | ||
return self.n_joints_body + 1 | ||
|
||
def forward(self, poses_body, betas=None, poses_root=None, trans=None, normalize_root=False): | ||
""" | ||
forwards the model | ||
:param poses_body: Pose parameters. | ||
:param poses_root: Pose parameters for the root joint. | ||
:param beta: Beta parameters. | ||
:param trans: Root translation. | ||
:param normalize_root: Makes poses relative to the root joint (useful for globally rotated captures). | ||
:return: Deformed surface vertices, transformed joints | ||
""" | ||
poses, betas, trans = self.preprocess(poses_body, betas, poses_root, trans, normalize_root) | ||
|
||
# SUPR repo currently hardcodes floats. | ||
v = super().forward(pose=poses.float(), betas=betas.float(), trans=trans.float()) | ||
J = v.J_transformed | ||
return v, J | ||
|
||
def preprocess(self, poses_body, betas=None, poses_root=None, trans=None, normalize_root=False): | ||
batch_size = poses_body.shape[0] | ||
|
||
if poses_root is None: | ||
poses_root = torch.zeros([batch_size, 3]).to(dtype=poses_body.dtype, device=self.device) | ||
if trans is None: | ||
# If we don't supply the root translation explicitly, it falls back to using self.bm.trans | ||
# which might not be zero since it is a trainable param that can get updated. | ||
trans = torch.zeros([batch_size, 3]).to(dtype=poses_body.dtype, device=self.device) | ||
|
||
if normalize_root: | ||
# Make everything relative to the first root orientation. | ||
root_ori = aa2rot(poses_root) | ||
first_root_ori = torch.inverse(root_ori[0:1]) | ||
root_ori = torch.matmul(first_root_ori, root_ori) | ||
poses_root = rot2aa(root_ori) | ||
trans = torch.matmul(first_root_ori.unsqueeze(0), trans.unsqueeze(-1)).squeeze() | ||
trans = trans - trans[0:1] | ||
|
||
poses = torch.cat((poses_root, poses_body), dim=1) | ||
|
||
if betas is None: | ||
betas = torch.zeros([batch_size, self.num_betas]).to(dtype=poses_body.dtype, device=self.device) | ||
|
||
# Batch shapes if they don't match batch dimension. | ||
if betas.shape[0] != batch_size: | ||
betas = betas.repeat(batch_size, 1) | ||
|
||
# Lower bound betas | ||
if betas.shape[1] < self.num_betas: | ||
betas = torch.nn.functional.pad(betas, [0, self.num_betas - betas.shape[1]]) | ||
|
||
# Upper bound betas | ||
betas = betas[:, : self.num_betas] | ||
|
||
return poses, betas, trans |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
""" | ||
Copyright (C) 2022 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
""" | ||
import numpy as np | ||
|
||
from aitviewer.configuration import CONFIG as C | ||
from aitviewer.models.supr import SUPRLayer | ||
from aitviewer.renderables.smpl import SMPLSequence | ||
from aitviewer.utils import to_numpy as c2c | ||
|
||
|
||
class SUPRSequence(SMPLSequence): | ||
""" | ||
Represents a temporal sequence of SMPL poses using the SUPR model. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
poses_body, | ||
smpl_layer, | ||
poses_root, | ||
betas=None, | ||
trans=None, | ||
device=None, | ||
include_root=True, | ||
normalize_root=False, | ||
is_rigged=True, | ||
show_joint_angles=False, | ||
z_up=False, | ||
post_fk_func=None, | ||
**kwargs, | ||
): | ||
super(SUPRSequence, self).__init__( | ||
poses_body, | ||
smpl_layer, | ||
poses_root, | ||
betas, | ||
trans, | ||
device=device, | ||
include_root=include_root, | ||
normalize_root=normalize_root, | ||
is_rigged=is_rigged, | ||
show_joint_angles=show_joint_angles, | ||
z_up=z_up, | ||
post_fk_func=post_fk_func, | ||
**kwargs, | ||
) | ||
|
||
def fk(self, current_frame_only=False): | ||
"""Get joints and/or vertices from the poses.""" | ||
if current_frame_only: | ||
# Use current frame data. | ||
if self._edit_mode: | ||
poses_root = self._edit_pose[:3][None, :] | ||
poses_body = self._edit_pose[3:][None, :] | ||
else: | ||
poses_body = self.poses_body[self.current_frame_id][None, :] | ||
poses_root = self.poses_root[self.current_frame_id][None, :] | ||
|
||
trans = self.trans[self.current_frame_id][None, :] | ||
|
||
if self.betas.shape[0] == self.n_frames: | ||
betas = self.betas[self.current_frame_id][None, :] | ||
else: | ||
betas = self.betas | ||
else: | ||
# Use the whole sequence. | ||
if self._edit_mode: | ||
poses_root = self.poses_root.clone() | ||
poses_body = self.poses_body.clone() | ||
poses_root[self.current_frame_id] = self._edit_pose[:3] | ||
poses_body[self.current_frame_id] = self._edit_pose[3:] | ||
else: | ||
poses_body = self.poses_body | ||
poses_root = self.poses_root | ||
trans = self.trans | ||
betas = self.betas | ||
|
||
verts, joints = self.smpl_layer( | ||
poses_root=poses_root, | ||
poses_body=poses_body, | ||
betas=betas, | ||
trans=trans, | ||
normalize_root=self._normalize_root, | ||
) | ||
|
||
skeleton = self.smpl_layer.skeletons()["body"].T | ||
faces = self.smpl_layer.faces | ||
joints = joints[:, : skeleton.shape[0]] | ||
|
||
if current_frame_only: | ||
return c2c(verts)[0], c2c(joints)[0], c2c(faces), c2c(skeleton) | ||
else: | ||
return c2c(verts), c2c(joints), c2c(faces), c2c(skeleton) | ||
|
||
@classmethod | ||
def from_amass( | ||
cls, | ||
npz_data_path, | ||
start_frame=None, | ||
end_frame=None, | ||
sub_frames=None, | ||
log=True, | ||
fps_out=None, | ||
load_betas=False, | ||
z_up=True, | ||
**kwargs, | ||
): | ||
raise ValueError("SUPR does not support loading from 3DPW.") | ||
|
||
@classmethod | ||
def from_3dpw(cls, **kwargs): | ||
raise ValueError("SUPR does not support loading from 3DPW.") | ||
|
||
@classmethod | ||
def t_pose(cls, model=None, betas=None, frames=1, **kwargs): | ||
"""Creates a SMPL sequence whose single frame is a SMPL mesh in T-Pose.""" | ||
|
||
if model is None: | ||
model = SUPRLayer(device=C.device) | ||
|
||
poses_body = np.zeros([frames, model.n_joints_body * 3]) | ||
poses_root = np.zeros([frames, 3]) | ||
return cls( | ||
poses_body=poses_body, | ||
smpl_layer=model, | ||
poses_root=poses_root, | ||
betas=betas, | ||
**kwargs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" | ||
Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
""" | ||
from aitviewer.renderables.supr import SUPRLayer, SUPRSequence | ||
from aitviewer.viewer import Viewer | ||
|
||
# Instantiate a SUPR layer. This requires that the respective repo has been installed via | ||
# pip install git+https://github.com/ahmedosman/SUPR.git and that the model files are available on the path | ||
# specified in `C.supr_models`. | ||
# | ||
# The directory structure should be: | ||
# | ||
# - models | ||
# |- supr_female.npy | ||
# |- supr_female_constrained.npy | ||
# |- supr_male.npy | ||
# |- supr_male_constrained.npy | ||
# |- supr_neutral.npy | ||
# |- supr_neutral_constrained.npy | ||
model = SUPRLayer(constrained=False) | ||
|
||
# Create a male SUPR T Pose. | ||
template = SUPRSequence.t_pose(model, color=(0.62, 0.62, 0.62, 0.8)) | ||
|
||
v = Viewer() | ||
v.scene.add(template) | ||
v.run() |