Skip to content

Commit

Permalink
add tests of new naming behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Oct 2, 2024
1 parent 56cdded commit f06470c
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 17 deletions.
46 changes: 37 additions & 9 deletions hermes/quiver/exporters/torchscript.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import warnings
from collections import OrderedDict
from io import BytesIO

Expand Down Expand Up @@ -32,12 +33,26 @@ class TorchScript(Exporter, metaclass=TorchScriptMeta):
def __call__(
self, model_fn, version, input_shapes, output_names=None
) -> None:
if output_names is not None:
raise ValueError(
"Cannot specify output_names for TorchScript exporter"
# if a dictionary is passed
# (i.e. user specified names for tensors)
# warn the user about specific naming conventions
# for tensor input names from triton
if isinstance(input_shapes, dict):
warnings.warn(
"Triton expects specific naming conventions and "
"ordering for tensor input names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)

input_shapes = {f"INPUT__{i}": j for i, j in enumerate(input_shapes)}
return super().__call__(
model_fn, version, input_shapes, output_names
)
# otherwise, user passed a sequence of shapes:
# use tritons recommended naming conventions
# by inferring the names from the model_fn
parameters = get_input_names_from_torch_object(model_fn)
input_shapes = {p: s for p, s in zip(parameters, input_shapes)}
super().__call__(model_fn, version, input_shapes, output_names)

def _get_tensor(self, shape):
Expand All @@ -56,7 +71,7 @@ def _get_tensor(self, shape):
# and pass them along if they were provided?
return torch.randn(*tensor_shape)

def _get_output_shapes(self, model_fn, _):
def _get_output_shapes(self, model_fn, output_names=None):
# now that we know we have inputs added to our
# model config, use that config to generate
# framework tensors that we'll feed through
Expand Down Expand Up @@ -120,9 +135,22 @@ def _get_output_shapes(self, model_fn, _):
if any([x.dims[0] == -1 for x in self.config.input]):
shapes = [(None,) + s[1:] for s in shapes]

# if we provided names for the outputs, return them
# as a dict for validation against the config
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
# if we didn't provide names for the outputs,
# use the "OUTPUT__{i}" format
if output_names is None:
shapes = {f"OUTPUT__{i}": j for i, j in enumerate(shapes)}
else:
warnings.warn(
"Triton expects specific naming conventions "
"and ordering for tensor output names. Be careful. See "
"https://docs.nvidia.com/deeplearning/triton-inference-server/"
"user-guide/docs/user_guide/model_configuration.html"
"#special-conventions-for-pytorch-backend"
)
shapes = {
name: shape
for i, (name, shape) in enumerate(zip(output_names, shapes))
}
return shapes

def export(self, model_fn, export_path, verbose=0, **kwargs):
Expand Down
16 changes: 16 additions & 0 deletions tests/quiver/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,19 @@ def forward(self, x):
return torch.matmul(x, self.W)

return Model(dim)


@pytest.fixture
def torch_model_2(dim):
import torch

class Model(torch.nn.Module):
def __init__(self, size: int = 10):
super().__init__()
self.size = size
self.W = torch.eye(size)

def forward(self, x, y):
return torch.matmul(x, self.W), y

return Model(dim)
75 changes: 67 additions & 8 deletions tests/quiver/exporters/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@


@pytest.mark.torch
def test_torchscript_exporter(temp_local_repo, torch_model):
def test_torchscript_exporter(temp_local_repo, torch_model, torch_model_2):
model_fn = torch_model

model = Model("identity", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

input_shapes = {"x": (None, 10)}
input_shapes = {"input0": (None, 10)}
exporter._check_exposed_tensors("input", input_shapes)
assert len(model.config.input) == 1
assert model.config.input[0].name == "x"
assert model.config.input[0].name == "input0"
assert model.config.input[0].dims[0] == -1

bad_input_shapes = {"x": (None, 12)}
with pytest.raises(ValueError):
exporter._check_exposed_tensors("input", bad_input_shapes)

output_shapes = exporter._get_output_shapes(model_fn, "y")
assert output_shapes["OUTPUT__0"] == (None, 10)
assert output_shapes["y"] == (None, 10)

exporter._check_exposed_tensors("output", output_shapes)
assert len(model.config.output) == 1
assert model.config.output[0].name == "OUTPUT__0"
assert model.config.output[0].name == "y"
assert model.config.output[0].dims[0] == -1

version_path = temp_local_repo.fs.join("identity", "1")
Expand All @@ -39,8 +39,67 @@ def test_torchscript_exporter(temp_local_repo, torch_model):
exporter = TorchScript(model2.config, model2.fs)

model_path = temp_local_repo.fs.root
with pytest.raises(ValueError):
exporter(model_fn, model_path, [(None, 10)], ["BAD_NAME"])

# if a dictionary if input_shapes is not passed
# (i.e. no user specified name) it should default
# to the name of the input tensor in the forward call of the model
# in this case, "x"
exporter(model_fn, model_path, [(None, 10)])
assert "INPUT__0" in model2.inputs
assert "x" in model2.inputs
assert "OUTPUT__0" in model2.outputs

# if a dictionary of input_shapes is passed, it should
# use the user specified names
model3 = Model("identity3", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model3.config, model3.fs)
model_path = temp_local_repo.fs.root

exporter(model_fn, model_path, input_shapes={"my_name": (None, 10)})
assert "my_name" in model3.inputs
assert "OUTPUT__0" in model3.outputs

# now check using non-default output names
model4 = Model("identity4", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model4.config, model4.fs)
model_path = temp_local_repo.fs.root

exporter(
model_fn,
model_path,
input_shapes={"my_name": (None, 10)},
output_names=["my_output"],
)
assert "my_name" in model4.inputs
assert "my_output" in model4.outputs

# test a model with multiple inputs and outputs
model_fn = torch_model_2

model = Model("identity", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model.config, model.fs)

input_shapes = {"input0": (None, 10), "input1": (None, 10)}
exporter._check_exposed_tensors("input", input_shapes)
assert len(model.config.input) == 2
assert model.config.input[0].name == "input0"
assert model.config.input[1].name == "input1"
assert model.config.input[0].dims[0] == -1
assert model.config.input[1].dims[0] == -1

bad_input_shapes = {"x": (None, 12)}
with pytest.raises(ValueError):
exporter._check_exposed_tensors("input", bad_input_shapes)

output_shapes = exporter._get_output_shapes(
model_fn, ["output1", "output2"]
)
assert output_shapes["output1"] == (None, 10)
assert output_shapes["output2"] == (None, 10)

model4 = Model("identity4", temp_local_repo, Platform.TORCHSCRIPT)
exporter = TorchScript(model4.config, model4.fs)
model_path = temp_local_repo.fs.root

exporter(model_fn, model_path, input_shapes=[(None, 10), (None, 10)])
assert "x" in model4.inputs and "y" in model4.inputs
assert "OUTPUT__0" in model4.outputs and "OUTPUT__1" in model4.outputs

0 comments on commit f06470c

Please sign in to comment.