Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More intelligent and flexible input and output tensor naming for TorchScript Exporter #62

Merged
merged 5 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hermes/quiver/exporters/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def handles(self):
"Platform metaclass has no `handles` property"
)

@abc.abstractproperty
@property
@abc.abstractmethod
def platform(self) -> "Platform":
try:
return type(self).platform
Expand Down
23 changes: 5 additions & 18 deletions hermes/quiver/exporters/torch_onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import inspect
from collections import OrderedDict
from io import BytesIO

Expand All @@ -12,14 +11,7 @@

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter


def get_input_names_from_script_module(m):
graph = m.graph
input_names = [node.debugName().split(".")[0] for node in graph.inputs()]
if "self" in input_names:
input_names.remove("self")
return OrderedDict({name: name for name in input_names})
from hermes.quiver.exporters.utils import get_input_names_from_torch_object


class TorchOnnxMeta(abc.ABCMeta):
Expand All @@ -29,7 +21,7 @@ def handles(self):
raise ImportError(
"Must have torch installed to use TorchOnnx export platform"
)
return torch.nn.Module
return (torch.nn.Module, torch.jit.ScriptModule)

@property
def platform(self):
Expand Down Expand Up @@ -63,14 +55,9 @@ def _get_output_shapes(self, model_fn, output_names):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# parse script module to figure out in which order
# to pass input tensors to the model_fn
if isinstance(model_fn, torch.jit.ScriptModule):
parameters = get_input_names_from_script_module(model_fn)
# otherwise use function signature from module.forward
else:
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse either a `ScriptModule` or `torch.nn.Module`
# to figure out in which order to pass input tensors to the model_fn
parameters = get_input_names_from_torch_object(model_fn)

# make sure the number of inputs to
# the model_fn matches the number of
Expand Down
58 changes: 42 additions & 16 deletions hermes/quiver/exporters/torchscript.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
import inspect
import warnings
from collections import OrderedDict
from io import BytesIO

Expand All @@ -12,6 +12,7 @@

from hermes.quiver import Platform
from hermes.quiver.exporters import Exporter
from hermes.quiver.exporters.utils import get_input_names_from_torch_object


class TorchScriptMeta(abc.ABCMeta):
Expand All @@ -21,7 +22,7 @@ def handles(self):
raise ImportError(
"Must have torch installed to use TorchScript export platform"
)
return torch.nn.Module
return (torch.nn.Module, torch.jit.ScriptModule)

@property
def platform(self):
Expand All @@ -32,12 +33,26 @@ class TorchScript(Exporter, metaclass=TorchScriptMeta):
def __call__(
self, model_fn, version, input_shapes, output_names=None
) -> None:
if output_names is not None:
raise ValueError(
"Cannot specify output_names for TorchScript exporter"
# if a dictionary is passed
# (i.e. user specified names for tensors)
# warn the user about specific naming conventions
# for tensor input names from triton
if isinstance(input_shapes, dict):
warnings.warn(
"Triton expects specific naming conventions and "
"ordering for tensor input names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)

input_shapes = {f"INPUT__{i}": j for i, j in enumerate(input_shapes)}
return super().__call__(
model_fn, version, input_shapes, output_names
)
# otherwise, user passed a sequence of shapes:
# use tritons recommended naming conventions
# by inferring the names from the model_fn
parameters = get_input_names_from_torch_object(model_fn)
input_shapes = {p: s for p, s in zip(parameters, input_shapes)}
super().__call__(model_fn, version, input_shapes, output_names)

def _get_tensor(self, shape):
Expand All @@ -56,7 +71,7 @@ def _get_tensor(self, shape):
# and pass them along if they were provided?
return torch.randn(*tensor_shape)

def _get_output_shapes(self, model_fn, _):
def _get_output_shapes(self, model_fn, output_names=None):
# now that we know we have inputs added to our
# model config, use that config to generate
# framework tensors that we'll feed through
Expand All @@ -66,11 +81,9 @@ def _get_output_shapes(self, model_fn, _):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# use function signature from module.forward
# to figure out in which order to pass input
# tensors to the model_fn
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse either a `ScriptModule` or `torch.nn.Module`
# to figure out in which order to pass input tensors to the model_fn
parameters = get_input_names_from_torch_object(model_fn)

# make sure the number of inputs to
# the model_fn matches the number of
Expand Down Expand Up @@ -122,9 +135,22 @@ def _get_output_shapes(self, model_fn, _):
if any([x.dims[0] == -1 for x in self.config.input]):
shapes = [(None,) + s[1:] for s in shapes]

# if we provided names for the outputs, return them
# as a dict for validation against the config
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
# if we didn't provide names for the outputs,
# use the "OUTPUT__{i}" format
if output_names is None:
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
else:
warnings.warn(
"Triton expects specific naming conventions "
"and ordering for tensor output names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)
shapes = {
name: shape
for i, (name, shape) in enumerate(zip(output_names, shapes))
}
return shapes

def export(self, model_fn, export_path, verbose=0, **kwargs):
Expand Down
31 changes: 31 additions & 0 deletions hermes/quiver/exporters/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import inspect
from collections import OrderedDict
from typing import TYPE_CHECKING, Callable, Union

from .exporter import Exporter

try:
import torch

_has_torch = True
except ImportError:
_has_torch = False

if TYPE_CHECKING:
from hermes.quiver.model import Model

Expand All @@ -19,6 +28,9 @@ class to find the first for which `model_fn` is an instance of
model_fn:
The framework-specific function which performs the
neural network's input/output mapping
model:
The `Model` object which specifies the desired export platform
and configuration information for the model
Returns:
An exporter to export `model_fn` to the format specified
by this models' inference platform
Expand Down Expand Up @@ -53,3 +65,22 @@ def _get_all_subclasses(cls):
type(model_fn), model.platform
)
)


def get_input_names_from_torch_object(
model_fn: Union["torch.nn.Module", "torch.jit.ScriptModule"]
):
"""
Parse either a torch.nn.Module or torch.ScriptModule for input names
"""

if isinstance(model_fn, torch.jit.ScriptModule):
graph = model_fn.graph
input_names = [
node.debugName().split(".")[0] for node in graph.inputs()
]
if "self" in input_names:
input_names.remove("self")
return OrderedDict({name: name for name in input_names})
signature = inspect.signature(model_fn.forward)
return OrderedDict(signature.parameters)
7 changes: 2 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/quiver/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,19 @@ def forward(self, x):
return torch.matmul(x, self.W)

return Model(dim)


@pytest.fixture
def torch_model_2(dim):
import torch

class Model(torch.nn.Module):
def __init__(self, size: int = 10):
super().__init__()
self.size = size
self.W = torch.eye(size)

def forward(self, x, y):
return torch.matmul(x, self.W), y

return Model(dim)
16 changes: 15 additions & 1 deletion tests/quiver/exporters/test_exporter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import pytest

from hermes.quiver.exporters import KerasSavedModel, TorchOnnx, utils
from hermes.quiver.exporters import (
KerasSavedModel,
TorchOnnx,
TorchScript,
utils,
)
from hermes.quiver.platform import Platform


Expand Down Expand Up @@ -33,3 +38,12 @@ def test_find_torch_exporter(torch_model):
good_model = make_model(Platform.ONNX)
bad_model = make_model(Platform.SAVEDMODEL)
_test_find_exporter(torch_model, good_model, bad_model, TorchOnnx)

good_model = make_model(Platform.TORCHSCRIPT)
_test_find_exporter(torch_model, good_model, bad_model, TorchScript)

import torch

script_module = torch.jit.trace(torch_model, torch.randn(1, 10))
good_model = make_model(Platform.TORCHSCRIPT)
_test_find_exporter(script_module, good_model, bad_model, TorchScript)
75 changes: 67 additions & 8 deletions tests/quiver/exporters/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@


@pytest.mark.torch
def test_torchscript_exporter(temp_local_repo, torch_model):
def test_torchscript_exporter(temp_local_repo, torch_model, torch_model_2):
model_fn = torch_model

model = Model("identity", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

input_shapes = {"x": (None, 10)}
input_shapes = {"input0": (None, 10)}
exporter._check_exposed_tensors("input", input_shapes)
assert len(model.config.input) == 1
assert model.config.input[0].name == "x"
assert model.config.input[0].name == "input0"
assert model.config.input[0].dims[0] == -1

bad_input_shapes = {"x": (None, 12)}
with pytest.raises(ValueError):
exporter._check_exposed_tensors("input", bad_input_shapes)

output_shapes = exporter._get_output_shapes(model_fn, "y")
assert output_shapes["OUTPUT__0"] == (None, 10)
assert output_shapes["y"] == (None, 10)

exporter._check_exposed_tensors("output", output_shapes)
assert len(model.config.output) == 1
assert model.config.output[0].name == "OUTPUT__0"
assert model.config.output[0].name == "y"
assert model.config.output[0].dims[0] == -1

version_path = temp_local_repo.fs.join("identity", "1")
Expand All @@ -39,8 +39,67 @@ def test_torchscript_exporter(temp_local_repo, torch_model):
exporter = TorchScript(model2.config, model2.fs)

model_path = temp_local_repo.fs.root
with pytest.raises(ValueError):
exporter(model_fn, model_path, [(None, 10)], ["BAD_NAME"])

# if a dictionary if input_shapes is not passed
# (i.e. no user specified name) it should default
# to the name of the input tensor in the forward call of the model
# in this case, "x"
exporter(model_fn, model_path, [(None, 10)])
assert "INPUT__0" in model2.inputs
assert "x" in model2.inputs
assert "OUTPUT__0" in model2.outputs

# if a dictionary of input_shapes is passed, it should
# use the user specified names
model3 = Model("identity3", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model3.config, model3.fs)
model_path = temp_local_repo.fs.root

exporter(model_fn, model_path, input_shapes={"my_name": (None, 10)})
assert "my_name" in model3.inputs
assert "OUTPUT__0" in model3.outputs

# now check using non-default output names
model4 = Model("identity4", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model4.config, model4.fs)
model_path = temp_local_repo.fs.root

exporter(
model_fn,
model_path,
input_shapes={"my_name": (None, 10)},
output_names=["my_output"],
)
assert "my_name" in model4.inputs
assert "my_output" in model4.outputs

# test a model with multiple inputs and outputs
model_fn = torch_model_2

model = Model("identity", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

input_shapes = {"input0": (None, 10), "input1": (None, 10)}
exporter._check_exposed_tensors("input", input_shapes)
assert len(model.config.input) == 2
assert model.config.input[0].name == "input0"
assert model.config.input[1].name == "input1"
assert model.config.input[0].dims[0] == -1
assert model.config.input[1].dims[0] == -1

bad_input_shapes = {"x": (None, 12)}
with pytest.raises(ValueError):
exporter._check_exposed_tensors("input", bad_input_shapes)

output_shapes = exporter._get_output_shapes(
model_fn, ["output1", "output2"]
)
assert output_shapes["output1"] == (None, 10)
assert output_shapes["output2"] == (None, 10)

model4 = Model("identity4", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model4.config, model4.fs)
model_path = temp_local_repo.fs.root

exporter(model_fn, model_path, input_shapes=[(None, 10), (None, 10)])
assert "x" in model4.inputs and "y" in model4.inputs
assert "OUTPUT__0" in model4.outputs and "OUTPUT__1" in model4.outputs
Loading