More intelligent and flexible input and output tensor naming for `TorchScript` Exporter (#62)

* support direct torchscript export from ScriptModule

* add tests of new naming behavior

* guard torch import

* add jit module to handles for onnx and torchscript exporters

* add script module test to exporter utils
EthanMarx committed Oct 3, 2024
1 parent 2bf99c7 commit a122342
Showing 8 changed files with 180 additions and 49 deletions.
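
In user-facing terms, the change lets a `torch.jit.ScriptModule` be passed straight to the `TorchScript` exporter, with tensor names inferred from the module rather than hard-coded as `INPUT__{i}`. A minimal sketch of the new call pattern, mirroring the tests below (`repo` and `MyModule` are illustrative stand-ins, not names from this commit):

import torch

from hermes.quiver.exporters import TorchScript
from hermes.quiver.model import Model
from hermes.quiver.platform import Platform

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x

# "repo" stands in for a hermes.quiver model repository object,
# like the temp_local_repo fixture used in the tests below
model = Model("identity", repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

# a traced ScriptModule can now be handed directly to the exporter
scripted = torch.jit.trace(MyModule(), torch.randn(1, 10))

# with a plain sequence of shapes, the input name is inferred
# from the module itself (here "x") instead of "INPUT__0"
exporter(scripted, repo.fs.root, [(None, 10)])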
3 changes: 2 additions & 1 deletion hermes/quiver/exporters/exporter.py
@@ -259,7 +259,8 @@ def handles(self):
"Platform metaclass has no `handles` property"
)

@abc.abstractproperty
@property
@abc.abstractmethod
def platform(self) -> "Platform":
try:
return type(self).platform
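The decorator swap above reflects that `abc.abstractproperty` has been deprecated since Python 3.3; stacking `@property` on top of `@abc.abstractmethod` is the supported modern spelling. A minimal sketch of the idiom (class names are illustrative):

import abc

class Base(abc.ABC):
    # equivalent to the deprecated `@abc.abstractproperty`
    @property
    @abc.abstractmethod
    def platform(self):
        ...

class Concrete(Base):
    @property
    def platform(self):
        return "torchscript"

# Base() raises TypeError; Concrete().platform == "torchscript"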
23 changes: 5 additions & 18 deletions hermes/quiver/exporters/torch_onnx.py
@@ -1,5 +1,4 @@
import abc
import inspect
from collections import OrderedDict
from io import BytesIO

@@ -12,14 +11,7 @@

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter


def get_input_names_from_script_module(m):
graph = m.graph
input_names = [node.debugName().split(".")[0] for node in graph.inputs()]
if "self" in input_names:
input_names.remove("self")
return OrderedDict({name: name for name in input_names})
from hermes.quiver.exporters.utils import get_input_names_from_torch_object


class TorchOnnxMeta(abc.ABCMeta):
@@ -29,7 +21,7 @@ def handles(self):
raise ImportError(
"Must have torch installed to use TorchOnnx export platform"
)
return torch.nn.Module
return (torch.nn.Module, torch.jit.ScriptModule)

@property
def platform(self):
@@ -63,14 +55,9 @@ def _get_output_shapes(self, model_fn, output_names):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# parse script module to figure out in which order
# to pass input tensors to the model_fn
if isinstance(model_fn, torch.jit.ScriptModule):
parameters = get_input_names_from_script_module(model_fn)
# otherwise use function signature from module.forward
else:
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse either a `ScriptModule` or `torch.nn.Module`
# to figure out in which order to pass input tensors to the model_fn
parameters = get_input_names_from_torch_object(model_fn)

# make sure the number of inputs to
# the model_fn matches the number of
58 changes: 42 additions & 16 deletions hermes/quiver/exporters/torchscript.py
@@ -1,5 +1,5 @@
import abc
import inspect
import warnings
from collections import OrderedDict
from io import BytesIO

@@ -12,6 +12,7 @@

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter
from hermes.quiver.exporters.utils import get_input_names_from_torch_object


class TorchScriptMeta(abc.ABCMeta):
@@ -21,7 +22,7 @@ def handles(self):
raise ImportError(
"Must have torch installed to use TorchScript export platform"
)
return torch.nn.Module
return (torch.nn.Module, torch.jit.ScriptModule)

@property
def platform(self):
@@ -32,12 +33,26 @@ class TorchScript(Exporter, metaclass=TorchScriptMeta):
def __call__(
self, model_fn, version, input_shapes, output_names=None
) -> None:
if output_names is not None:
raise ValueError(
"Cannot specify output_names for TorchScript exporter"
# if a dictionary is passed (i.e. the user specified
# names for the tensors), warn the user about Triton's
# naming conventions for tensor input names
if isinstance(input_shapes, dict):
warnings.warn(
"Triton expects specific naming conventions and "
"ordering for tensor input names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)

input_shapes = {f"INPUT__{i}": j for i, j in enumerate(input_shapes)}
return super().__call__(
model_fn, version, input_shapes, output_names
)
# otherwise, the user passed a sequence of shapes:
# follow Triton's recommended naming conventions
# by inferring the names from the model_fn
parameters = get_input_names_from_torch_object(model_fn)
input_shapes = {p: s for p, s in zip(parameters, input_shapes)}
super().__call__(model_fn, version, input_shapes, output_names)

def _get_tensor(self, shape):
@@ -56,7 +71,7 @@ def _get_tensor(self, shape):
# and pass them along if they were provided?
return torch.randn(*tensor_shape)

def _get_output_shapes(self, model_fn, _):
def _get_output_shapes(self, model_fn, output_names=None):
# now that we know we have inputs added to our
# model config, use that config to generate
# framework tensors that we'll feed through
@@ -66,11 +81,9 @@ def _get_output_shapes(self, model_fn, _):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# use function signature from module.forward
# to figure out in which order to pass input
# tensors to the model_fn
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse either a `ScriptModule` or `torch.nn.Module`
# to figure out in which order to pass input tensors to the model_fn
parameters = get_input_names_from_torch_object(model_fn)

# make sure the number of inputs to
# the model_fn matches the number of
@@ -122,9 +135,22 @@ def _get_output_shapes(self, model_fn, _):
if any([x.dims[0] == -1 for x in self.config.input]):
shapes = [(None,) + s[1:] for s in shapes]

# if we provided names for the outputs, return them
# as a dict for validation against the config
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
# if we didn't provide names for the outputs,
# use the "OUTPUT__{i}" format
if output_names is None:
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
else:
warnings.warn(
"Triton expects specific naming conventions "
"and ordering for tensor output names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)
shapes = {
name: shape for name, shape in zip(output_names, shapes)
}
return shapes

def export(self, model_fn, export_path, verbose=0, **kwargs):
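Taken together, the new `__call__` and `_get_output_shapes` resolve tensor names in three ways. A hedged summary, reusing the `exporter`, `model_fn`, and `model_path` setup from the tests below:

# 1. a sequence of shapes: input names are inferred from the model's
#    forward() signature or ScriptModule graph (e.g. "x", "y")
exporter(model_fn, model_path, [(None, 10), (None, 10)])

# 2. a dict of shapes: the given names are used verbatim, and a warning
#    points at Triton's PyTorch-backend naming conventions
exporter(model_fn, model_path, input_shapes={"my_name": (None, 10)})

# 3. outputs: "OUTPUT__{i}" by default, or user-supplied names
#    (again accompanied by the convention warning)
exporter(
    model_fn,
    model_path,
    input_shapes={"my_name": (None, 10)},
    output_names=["my_output"],
)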
31 changes: 31 additions & 0 deletions hermes/quiver/exporters/utils.py
@@ -1,7 +1,16 @@
import inspect
from collections import OrderedDict
from typing import TYPE_CHECKING, Callable, Union

from .exporter import Exporter

try:
import torch

_has_torch = True
except ImportError:
_has_torch = False

if TYPE_CHECKING:
from hermes.quiver.model import Model

@@ -19,6 +28,9 @@ class to find the first for which `model_fn` is an instance of
model_fn:
The framework-specific function which performs the
neural network's input/output mapping
model:
The `Model` object which specifies the desired export platform
and configuration information for the model
Returns:
An exporter to export `model_fn` to the format specified
by this model's inference platform
@@ -53,3 +65,22 @@ def _get_all_subclasses(cls):
type(model_fn), model.platform
)
)


def get_input_names_from_torch_object(
model_fn: Union["torch.nn.Module", "torch.jit.ScriptModule"]
):
"""
Parse either a torch.nn.Module or torch.jit.ScriptModule for input names
"""

# guard against torch being unavailable at call time
if not _has_torch:
raise ImportError(
"Must have torch installed to infer input names"
)

if isinstance(model_fn, torch.jit.ScriptModule):
graph = model_fn.graph
input_names = [
node.debugName().split(".")[0] for node in graph.inputs()
]
if "self" in input_names:
input_names.remove("self")
return OrderedDict({name: name for name in input_names})
signature = inspect.signature(model_fn.forward)
return OrderedDict(signature.parameters)
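
A short sketch of what the helper returns on each branch (the `Net` module is illustrative; the traced case assumes, as the tests below do, that tracing preserves the forward argument names as graph debug names):

import torch

from hermes.quiver.exporters.utils import get_input_names_from_torch_object

class Net(torch.nn.Module):
    def forward(self, x, y):
        return x + y

net = Net()

# plain nn.Module: names come from forward()'s signature
list(get_input_names_from_torch_object(net))  # ["x", "y"]

# ScriptModule: names come from the graph's debug names,
# with the implicit "self" input stripped
traced = torch.jit.trace(net, (torch.randn(1), torch.randn(1)))
list(get_input_names_from_torch_object(traced))  # ["x", "y"]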
7 changes: 2 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

16 changes: 16 additions & 0 deletions tests/quiver/conftest.py
@@ -97,3 +97,19 @@ def forward(self, x):
return torch.matmul(x, self.W)

return Model(dim)


@pytest.fixture
def torch_model_2(dim):
import torch

class Model(torch.nn.Module):
def __init__(self, size: int = 10):
super().__init__()
self.size = size
self.W = torch.eye(size)

def forward(self, x, y):
return torch.matmul(x, self.W), y

return Model(dim)
16 changes: 15 additions & 1 deletion tests/quiver/exporters/test_exporter_utils.py
@@ -2,7 +2,12 @@

import pytest

from hermes.quiver.exporters import KerasSavedModel, TorchOnnx, utils
from hermes.quiver.exporters import (
KerasSavedModel,
TorchOnnx,
TorchScript,
utils,
)
from hermes.quiver.platform import Platform


@@ -33,3 +38,12 @@ def test_find_torch_exporter(torch_model):
good_model = make_model(Platform.ONNX)
bad_model = make_model(Platform.SAVEDMODEL)
_test_find_exporter(torch_model, good_model, bad_model, TorchOnnx)

good_model = make_model(Platform.TORCHSCRIPT)
_test_find_exporter(torch_model, good_model, bad_model, TorchScript)

import torch

script_module = torch.jit.trace(torch_model, torch.randn(1, 10))
good_model = make_model(Platform.TORCHSCRIPT)
_test_find_exporter(script_module, good_model, bad_model, TorchScript)
75 changes: 67 additions & 8 deletions tests/quiver/exporters/test_torchscript.py
@@ -5,28 +5,28 @@


@pytest.mark.torch
def test_torchscript_exporter(temp_local_repo, torch_model):
def test_torchscript_exporter(temp_local_repo, torch_model, torch_model_2):
model_fn = torch_model

model = Model("identity", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

input_shapes = {"x": (None, 10)}
input_shapes = {"input0": (None, 10)}
exporter._check_exposed_tensors("input", input_shapes)
assert len(model.config.input) == 1
assert model.config.input[0].name == "x"
assert model.config.input[0].name == "input0"
assert model.config.input[0].dims[0] == -1

bad_input_shapes = {"x": (None, 12)}
with pytest.raises(ValueError):
exporter._check_exposed_tensors("input", bad_input_shapes)

output_shapes = exporter._get_output_shapes(model_fn, "y")
assert output_shapes["OUTPUT__0"] == (None, 10)
assert output_shapes["y"] == (None, 10)

exporter._check_exposed_tensors("output", output_shapes)
assert len(model.config.output) == 1
assert model.config.output[0].name == "OUTPUT__0"
assert model.config.output[0].name == "y"
assert model.config.output[0].dims[0] == -1

version_path = temp_local_repo.fs.join("identity", "1")
Expand All @@ -39,8 +39,67 @@ def test_torchscript_exporter(temp_local_repo, torch_model):
exporter = TorchScript(model2.config, model2.fs)

model_path = temp_local_repo.fs.root
with pytest.raises(ValueError):
exporter(model_fn, model_path, [(None, 10)], ["BAD_NAME"])

# if a dictionary of input_shapes is not passed
# (i.e. no user specified name) it should default
# to the name of the input tensor in the forward call of the model
# in this case, "x"
exporter(model_fn, model_path, [(None, 10)])
assert "INPUT__0" in model2.inputs
assert "x" in model2.inputs
assert "OUTPUT__0" in model2.outputs

# if a dictionary of input_shapes is passed, it should
# use the user specified names
model3 = Model("identity3", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model3.config, model3.fs)
model_path = temp_local_repo.fs.root

exporter(model_fn, model_path, input_shapes={"my_name": (None, 10)})
assert "my_name" in model3.inputs
assert "OUTPUT__0" in model3.outputs

# now check using non-default output names
model4 = Model("identity4", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model4.config, model4.fs)
model_path = temp_local_repo.fs.root

exporter(
model_fn,
model_path,
input_shapes={"my_name": (None, 10)},
output_names=["my_output"],
)
assert "my_name" in model4.inputs
assert "my_output" in model4.outputs

# test a model with multiple inputs and outputs
model_fn = torch_model_2

model = Model("identity", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

input_shapes = {"input0": (None, 10), "input1": (None, 10)}
exporter._check_exposed_tensors("input", input_shapes)
assert len(model.config.input) == 2
assert model.config.input[0].name == "input0"
assert model.config.input[1].name == "input1"
assert model.config.input[0].dims[0] == -1
assert model.config.input[1].dims[0] == -1

bad_input_shapes = {"x": (None, 12)}
with pytest.raises(ValueError):
exporter._check_exposed_tensors("input", bad_input_shapes)

output_shapes = exporter._get_output_shapes(
model_fn, ["output1", "output2"]
)
assert output_shapes["output1"] == (None, 10)
assert output_shapes["output2"] == (None, 10)

model4 = Model("identity4", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model4.config, model4.fs)
model_path = temp_local_repo.fs.root

exporter(model_fn, model_path, input_shapes=[(None, 10), (None, 10)])
assert "x" in model4.inputs and "y" in model4.inputs
assert "OUTPUT__0" in model4.outputs and "OUTPUT__1" in model4.outputs
