Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge dev into main #63

Merged
merged 14 commits into from
Oct 3, 2024
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v2

Expand Down
7 changes: 6 additions & 1 deletion hermes/aeriel/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,13 @@ def infer(
except KeyError:
raise ValueError(f"Missing state {name}")

# sometimes we can have a batched state, in which
# case don't append a batch dimension
if shape[0] == 1 and value.ndim < len(shape):
value = value[None]

# add the update to our running list of updates
state_values.append(value[None])
state_values.append(value)

# if we have more than one state, combine them
# into a single tensor along the channel axis
Expand Down
3 changes: 2 additions & 1 deletion hermes/quiver/exporters/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def handles(self):
"Platform metaclass has no `handles` property"
)

@abc.abstractproperty
@property
@abc.abstractmethod
def platform(self) -> "Platform":
try:
return type(self).platform
Expand Down
8 changes: 5 additions & 3 deletions hermes/quiver/exporters/keras_savedmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import os
import tempfile

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter

os.environ["TF_USE_LEGACY_KERAS"] = "1"

try:
import tensorflow as tf

_has_tf = True
except ImportError:
_has_tf = False

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter


class KerasSavedModelMeta(abc.ABCMeta):
@property
Expand Down
12 changes: 5 additions & 7 deletions hermes/quiver/exporters/torch_onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import inspect
from collections import OrderedDict
from io import BytesIO

Expand All @@ -12,6 +11,7 @@

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter
from hermes.quiver.exporters.utils import get_input_names_from_torch_object


class TorchOnnxMeta(abc.ABCMeta):
Expand All @@ -21,7 +21,7 @@ def handles(self):
raise ImportError(
"Must have torch installed to use TorchOnnx export platform"
)
return torch.nn.Module
return (torch.nn.Module, torch.jit.ScriptModule)

@property
def platform(self):
Expand Down Expand Up @@ -55,11 +55,9 @@ def _get_output_shapes(self, model_fn, output_names):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# use function signature from module.forward
# to figure out in which order to pass input
# tensors to the model_fn
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse either a `ScriptModule` or `torch.nn.Module`
# to figure out in which order to pass input tensors to the model_fn
parameters = get_input_names_from_torch_object(model_fn)

# make sure the number of inputs to
# the model_fn matches the number of
Expand Down
58 changes: 42 additions & 16 deletions hermes/quiver/exporters/torchscript.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
import inspect
import warnings
from collections import OrderedDict
from io import BytesIO

Expand All @@ -12,6 +12,7 @@

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter
from hermes.quiver.exporters.utils import get_input_names_from_torch_object


class TorchScriptMeta(abc.ABCMeta):
Expand All @@ -21,7 +22,7 @@ def handles(self):
raise ImportError(
"Must have torch installed to use TorchScript export platform"
)
return torch.nn.Module
return (torch.nn.Module, torch.jit.ScriptModule)

@property
def platform(self):
Expand All @@ -32,12 +33,26 @@ class TorchScript(Exporter, metaclass=TorchScriptMeta):
def __call__(
self, model_fn, version, input_shapes, output_names=None
) -> None:
if output_names is not None:
raise ValueError(
"Cannot specify output_names for TorchScript exporter"
# if a dictionary is passed
# (i.e. user specified names for tensors)
# warn the user about specific naming conventions
# for tensor input names from triton
if isinstance(input_shapes, dict):
warnings.warn(
"Triton expects specific naming conventions and "
"ordering for tensor input names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)

input_shapes = {f"INPUT__{i}": j for i, j in enumerate(input_shapes)}
return super().__call__(
model_fn, version, input_shapes, output_names
)
# otherwise, user passed a sequence of shapes:
# use tritons recommended naming conventions
# by inferring the names from the model_fn
parameters = get_input_names_from_torch_object(model_fn)
input_shapes = {p: s for p, s in zip(parameters, input_shapes)}
super().__call__(model_fn, version, input_shapes, output_names)

def _get_tensor(self, shape):
Expand All @@ -56,7 +71,7 @@ def _get_tensor(self, shape):
# and pass them along if they were provided?
return torch.randn(*tensor_shape)

def _get_output_shapes(self, model_fn, _):
def _get_output_shapes(self, model_fn, output_names=None):
# now that we know we have inputs added to our
# model config, use that config to generate
# framework tensors that we'll feed through
Expand All @@ -66,11 +81,9 @@ def _get_output_shapes(self, model_fn, _):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# use function signature from module.forward
# to figure out in which order to pass input
# tensors to the model_fn
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse either a `ScriptModule` or `torch.nn.Module`
# to figure out in which order to pass input tensors to the model_fn
parameters = get_input_names_from_torch_object(model_fn)

# make sure the number of inputs to
# the model_fn matches the number of
Expand Down Expand Up @@ -122,9 +135,22 @@ def _get_output_shapes(self, model_fn, _):
if any([x.dims[0] == -1 for x in self.config.input]):
shapes = [(None,) + s[1:] for s in shapes]

# if we provided names for the outputs, return them
# as a dict for validation against the config
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
# if we didn't provide names for the outputs,
# use the "OUTPUT__{i}" format
if output_names is None:
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
else:
warnings.warn(
"Triton expects specific naming conventions "
"and ordering for tensor output names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)
shapes = {
name: shape
for i, (name, shape) in enumerate(zip(output_names, shapes))
}
return shapes

def export(self, model_fn, export_path, verbose=0, **kwargs):
Expand Down
31 changes: 31 additions & 0 deletions hermes/quiver/exporters/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import inspect
from collections import OrderedDict
from typing import TYPE_CHECKING, Callable, Union

from .exporter import Exporter

try:
import torch

_has_torch = True
except ImportError:
_has_torch = False

if TYPE_CHECKING:
from hermes.quiver.model import Model

Expand All @@ -19,6 +28,9 @@ class to find the first for which `model_fn` is an instance of
model_fn:
The framework-specific function which performs the
neural network's input/output mapping
model:
The `Model` object which specifies the desired export platform
and configuration information for the model
Returns:
An exporter to export `model_fn` to the format specified
by this models' inference platform
Expand Down Expand Up @@ -53,3 +65,22 @@ def _get_all_subclasses(cls):
type(model_fn), model.platform
)
)


def get_input_names_from_torch_object(
    model_fn: Union["torch.nn.Module", "torch.jit.ScriptModule"]
) -> OrderedDict:
    """
    Parse either a `torch.nn.Module` or a `torch.jit.ScriptModule`
    for the names of its input tensors, in the order the model
    expects them to be passed.

    Args:
        model_fn:
            The torch module to inspect.
    Returns:
        An `OrderedDict` whose keys are the input names in call
        order. For a `ScriptModule` the values are the names
        themselves; for an eager `nn.Module` they are the
        `inspect.Parameter` objects from `forward`'s signature —
        callers should rely only on the keys and their ordering.
    Raises:
        ImportError: if torch is not installed.
    """
    try:
        scripted = isinstance(model_fn, torch.jit.ScriptModule)
    except NameError:
        # the module-level `import torch` failed, leaving the name
        # `torch` unbound; raise a clear ImportError instead of the
        # confusing NameError the caller would otherwise see
        raise ImportError(
            "Must have torch installed to parse input names "
            "from a torch object"
        ) from None

    if scripted:
        # graph inputs carry debug names like "x.1"; strip the
        # numeric suffix and drop the implicit `self` argument
        names = [
            node.debugName().split(".")[0]
            for node in model_fn.graph.inputs()
        ]
        if "self" in names:
            names.remove("self")
        return OrderedDict((name, name) for name in names)

    # eager module: the forward signature gives the input order
    signature = inspect.signature(model_fn.forward)
    return OrderedDict(signature.parameters)
Loading
Loading