Skip to content

Commit

Permalink
SUPR model support and load_STAR example
Browse files Browse the repository at this point in the history
  • Loading branch information
ramenguy99 committed Aug 5, 2023
1 parent ebd1807 commit d872817
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 0 deletions.
1 change: 1 addition & 0 deletions aitviewer/aitvconfig.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Access SMPL models.
smplx_models: "../data/smplx_models"
star_models: "../data/star_models"
supr_models: "../data/supr_models"

# Access to datasets.
datasets:
Expand Down
143 changes: 143 additions & 0 deletions aitviewer/models/supr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import collections
import os

import torch
import torch.nn.functional as F

try:
from supr.config import cfg
from supr.pytorch.supr import SUPR
except Exception as e:
raise ImportError(f"Cannot import SUPR. Please run `pip install git+https://github.com/ahmedosman/SUPR.git`\n{e}")

from aitviewer.configuration import CONFIG as C
from aitviewer.utils.so3 import aa2rot_torch as aa2rot
from aitviewer.utils.so3 import rot2aa_torch as rot2aa


class SUPRLayer(SUPR):
"""Wraps the publicly available SUPR model to match SMPLX model interface"""

def __init__(self, gender="male", num_betas=10, constrained=False, device=None, dtype=None):
"""
Initializer.
:param gender: Which gender to load.
:param num_betas: Number of shape components.
:param device: CPU or GPU.
:param dtype: The pytorch floating point data type.
"""
# Configure SUPR model before initializing
cfg.data_type = dtype if dtype is not None else C.f_precision
path_model = os.path.join(C.supr_models, f'supr_{gender}{"_constrained" if constrained else ""}.npy')
super(SUPRLayer, self).__init__(path_model, num_betas=num_betas)

self.device = device if device is not None else C.device
self.model_type = "supr"
self._parents = None
self._children = None

@property
def parents(self):
"""Return how the joints are connected in the kinematic chain where parents[i, 0] is the parent of
joint parents[i, 1]."""
if self._parents is None:
self._parents = self.kintree_table.transpose(0, 1).cpu().numpy()
return self._parents

@property
def joint_children(self):
"""Return the children of each joint in the kinematic chain."""
if self._children is None:
self._children = collections.defaultdict(list)
for bone in self.parents:
if bone[0] != -1:
self._children[bone[0]].append(bone[1])
return self._children

def skeletons(self):
"""Return how the joints are connected in the kinematic chain where skeleton[0, i] is the parent of
joint skeleton[1, i]."""
kintree_table = self.kintree_table
kintree_table[:, 0] = -1
return {
"all": kintree_table,
"body": kintree_table[:, : self.n_joints_body + 1],
}

@property
def n_joints_body(self):
return self.parent.shape[0]

@property
def n_joints_total(self):
return self.n_joints_body + 1

def forward(self, poses_body, betas=None, poses_root=None, trans=None, normalize_root=False):
"""
forwards the model
:param poses_body: Pose parameters.
:param poses_root: Pose parameters for the root joint.
:param beta: Beta parameters.
:param trans: Root translation.
:param normalize_root: Makes poses relative to the root joint (useful for globally rotated captures).
:return: Deformed surface vertices, transformed joints
"""
poses, betas, trans = self.preprocess(poses_body, betas, poses_root, trans, normalize_root)

# SUPR repo currently hardcodes floats.
v = super().forward(pose=poses.float(), betas=betas.float(), trans=trans.float())
J = v.J_transformed
return v, J

def preprocess(self, poses_body, betas=None, poses_root=None, trans=None, normalize_root=False):
batch_size = poses_body.shape[0]

if poses_root is None:
poses_root = torch.zeros([batch_size, 3]).to(dtype=poses_body.dtype, device=self.device)
if trans is None:
# If we don't supply the root translation explicitly, it falls back to using self.bm.trans
# which might not be zero since it is a trainable param that can get updated.
trans = torch.zeros([batch_size, 3]).to(dtype=poses_body.dtype, device=self.device)

if normalize_root:
# Make everything relative to the first root orientation.
root_ori = aa2rot(poses_root)
first_root_ori = torch.inverse(root_ori[0:1])
root_ori = torch.matmul(first_root_ori, root_ori)
poses_root = rot2aa(root_ori)
trans = torch.matmul(first_root_ori.unsqueeze(0), trans.unsqueeze(-1)).squeeze()
trans = trans - trans[0:1]

poses = torch.cat((poses_root, poses_body), dim=1)

if betas is None:
betas = torch.zeros([batch_size, self.num_betas]).to(dtype=poses_body.dtype, device=self.device)

# Batch shapes if they don't match batch dimension.
if betas.shape[0] != batch_size:
betas = betas.repeat(batch_size, 1)

# Lower bound betas
if betas.shape[1] < self.num_betas:
betas = torch.nn.functional.pad(betas, [0, self.num_betas - betas.shape[1]])

# Upper bound betas
betas = betas[:, : self.num_betas]

return poses, betas, trans
143 changes: 143 additions & 0 deletions aitviewer/renderables/supr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
Copyright (C) 2022 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import numpy as np

from aitviewer.configuration import CONFIG as C
from aitviewer.models.supr import SUPRLayer
from aitviewer.renderables.smpl import SMPLSequence
from aitviewer.utils import to_numpy as c2c


class SUPRSequence(SMPLSequence):
"""
Represents a temporal sequence of SMPL poses using the SUPR model.
"""

def __init__(
self,
poses_body,
smpl_layer,
poses_root,
betas=None,
trans=None,
device=None,
include_root=True,
normalize_root=False,
is_rigged=True,
show_joint_angles=False,
z_up=False,
post_fk_func=None,
**kwargs,
):
super(SUPRSequence, self).__init__(
poses_body,
smpl_layer,
poses_root,
betas,
trans,
device=device,
include_root=include_root,
normalize_root=normalize_root,
is_rigged=is_rigged,
show_joint_angles=show_joint_angles,
z_up=z_up,
post_fk_func=post_fk_func,
**kwargs,
)

def fk(self, current_frame_only=False):
"""Get joints and/or vertices from the poses."""
if current_frame_only:
# Use current frame data.
if self._edit_mode:
poses_root = self._edit_pose[:3][None, :]
poses_body = self._edit_pose[3:][None, :]
else:
poses_body = self.poses_body[self.current_frame_id][None, :]
poses_root = self.poses_root[self.current_frame_id][None, :]

trans = self.trans[self.current_frame_id][None, :]

if self.betas.shape[0] == self.n_frames:
betas = self.betas[self.current_frame_id][None, :]
else:
betas = self.betas
else:
# Use the whole sequence.
if self._edit_mode:
poses_root = self.poses_root.clone()
poses_body = self.poses_body.clone()
poses_root[self.current_frame_id] = self._edit_pose[:3]
poses_body[self.current_frame_id] = self._edit_pose[3:]
else:
poses_body = self.poses_body
poses_root = self.poses_root
trans = self.trans
betas = self.betas

verts, joints = self.smpl_layer(
poses_root=poses_root,
poses_body=poses_body,
betas=betas,
trans=trans,
normalize_root=self._normalize_root,
)

skeleton = self.smpl_layer.skeletons()["body"].T
faces = self.smpl_layer.faces
joints = joints[:, : skeleton.shape[0]]

if current_frame_only:
return c2c(verts)[0], c2c(joints)[0], c2c(faces), c2c(skeleton)
else:
return c2c(verts), c2c(joints), c2c(faces), c2c(skeleton)

@classmethod
def from_amass(
cls,
npz_data_path,
start_frame=None,
end_frame=None,
sub_frames=None,
log=True,
fps_out=None,
load_betas=False,
z_up=True,
**kwargs,
):
raise ValueError("SUPR does not support loading from 3DPW.")

@classmethod
def from_3dpw(cls, **kwargs):
raise ValueError("SUPR does not support loading from 3DPW.")

@classmethod
def t_pose(cls, model=None, betas=None, frames=1, **kwargs):
"""Creates a SMPL sequence whose single frame is a SMPL mesh in T-Pose."""

if model is None:
model = SUPRLayer(device=C.device)

poses_body = np.zeros([frames, model.n_joints_body * 3])
poses_root = np.zeros([frames, 3])
return cls(
poses_body=poses_body,
smpl_layer=model,
poses_root=poses_root,
betas=betas,
**kwargs,
)
40 changes: 40 additions & 0 deletions examples/load_SUPR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from aitviewer.renderables.supr import SUPRLayer, SUPRSequence
from aitviewer.viewer import Viewer

# Instantiate a SUPR layer. This requires that the respective repo has been installed via
# pip install git+https://github.com/ahmedosman/SUPR.git and that the model files are available on the path
# specified in `C.supr_models`.
#
# The directory structure should be:
#
# - models
# |- supr_female.npy
# |- supr_female_constrained.npy
# |- supr_male.npy
# |- supr_male_constrained.npy
# |- supr_neutral.npy
# |- supr_neutral_constrained.npy
model = SUPRLayer(constrained=False)

# Create a male SUPR T Pose.
template = SUPRSequence.t_pose(model, color=(0.62, 0.62, 0.62, 0.8))

v = Viewer()
v.scene.add(template)
v.run()

0 comments on commit d872817

Please sign in to comment.