
More intelligent and flexible input and output tensor naming for TorchScript Exporter #62

Merged: 5 commits into ML4GW:dev on Oct 3, 2024

Conversation

EthanMarx (Contributor)

Previously, the TorchScript exporter enforced that input and output tensor names follow the INPUT__{i}/OUTPUT__{i} convention. The main reason for this is that Triton expects certain naming conventions to compensate for the fact that TorchScript models don't carry the metadata Triton needs to match input tensors to the ordering the model expects. If your model has only one input and one output, this detail is irrelevant and any naming will work.

This convention is not the only way to handle this (see here): one can also name the tensors after the parameter names in the model's forward method, or use the <name>__<index> convention.
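As a concrete illustration (this toy module and its parameter names are hypothetical), consider a model whose forward takes two tensors; any of the three schemes below would let Triton match the tensor names in its config to the model's positional inputs:

```python
import torch


class Model(torch.nn.Module):
    # TorchScript carries no tensor-name metadata, so Triton relies on
    # one of the naming conventions below to map the tensor names in
    # its config to these positional arguments.
    def forward(self, strain: torch.Tensor, psd: torch.Tensor) -> torch.Tensor:
        return strain / psd


# Equivalent Triton-compatible naming schemes for the two inputs:
#   1. Positional convention:        INPUT__0, INPUT__1
#   2. Forward parameter names:      strain, psd
#   3. <name>__<index> convention:   strain__0, psd__1
```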

In any case, enforcing the INPUT__{i}/OUTPUT__{i} naming was causing difficulties when ensembling multiple TorchScript models, due to conflicting tensor names.

This PR makes tensor naming for TorchScript models more flexible by doing the following (see the usage sketch after this list):

  • Input tensors

    1. If the user passes a dictionary of {tensor_name: shape} for input_shapes, the user-supplied names are used. A warning is emitted telling users to be careful when supplying their own names.
    2. If the user passes a List[shape] to input_shapes, the names are inferred directly from the forward method of the TorchScript model - this is the behavior Triton recommends.
  • Output tensors

    1. If the user passes a List[name] to output_names, the user-supplied names are used.
    2. If output_names is left as None, the OUTPUT__{i} convention is used.
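
Below is a minimal sketch of the two input-naming paths, assuming a hermes.quiver-style workflow. The import path, repository setup, and export_version call are assumptions for illustration and may not match the exact API; only the input_shapes/output_names semantics come from this PR.

```python
import torch

# Assumed hermes.quiver-style API; exact names may differ.
from hermes.quiver import ModelRepository, Platform


class Whitener(torch.nn.Module):
    def forward(self, strain: torch.Tensor) -> torch.Tensor:
        return strain * 2.0


repo = ModelRepository("./model-repo")
model = repo.add("whitener", platform=Platform.TORCHSCRIPT)
module = torch.jit.script(Whitener())

# Path 1: dict of {tensor_name: shape}. The user-supplied name
# "strain" is used verbatim (with a warning to be careful), and the
# explicit output name is used as-is.
model.export_version(
    module,
    input_shapes={"strain": (1, 2048)},
    output_names=["whitened"],
)

# Path 2: list of shapes. The input name is inferred from the forward
# signature ("strain") -- the behavior Triton recommends -- and
# leaving output_names as None falls back to OUTPUT__0, OUTPUT__1, ...
model.export_version(
    module,
    input_shapes=[(1, 2048)],
    output_names=None,
)
```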

EthanMarx merged commit b9ea264 into ML4GW:dev on Oct 3, 2024
5 checks passed
EthanMarx deleted the torchscript-export branch on October 3, 2024 at 11:02
EthanMarx added a commit that referenced this pull request Oct 3, 2024
More intelligent and flexible input and output tensor naming for `TorchScript` Exporter (#62)

* support direct torchscript export from ScriptModule

* add tests of new naming behavior

* guard torch import

* add jit module to handles for onnx and torchscript exporters

* add script module test to exporter utils

EthanMarx added a commit that referenced this pull request Oct 3, 2024
* Allow parsing input tensors from torch `ScriptModules` (#57)

* add ability to parse inputs from script module

* update torch dep to ^2.0 for consistency with ml4gw

* add onnx dep

* remove both tests

* remove both tests

* Batched state updates (#59)

* allow for batched state updates

* update poetry lock

* restrict tf version

* revert whitespaces

* Allow python 3.12 and explicitly list PyPI as a source for future compatibility

* Add python 3.12 to unit tests

* Add py312 to tox file

* Missed one

* Attempt to update tensorflow

* Change tf specification

* Force usage of keras 2

* Fix formatting

* Remove py38 support and install tf-keras

* More intelligent and flexible input and output tensor naming for `TorchScript` Exporter (#62)

* support direct torchscript export from ScriptModule

* add tests of new naming behavior

* guard torch import

* add jit module to handles for onnx and torchscript exporters

* add script module test to exporter utils

* bump to version 0.2.0 (#64)

---------

Co-authored-by: William Benoit <[email protected]>