[torchlib] Register aten.linear and use matmul to simplify graph #2021

Merged
justinchuby merged 4 commits into main from justinchu/better-linear on Jan 18, 2025

Conversation

justinchuby (Collaborator) commented Jan 18, 2025

Use MatMul when the input is not rank 2, to avoid decomposing aten.linear into aten.addmm.
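
Why this works (a minimal numpy sketch of the shape semantics, not the torchlib source): ONNX MatMul follows numpy batch broadcasting, so an input of rank > 2 can multiply the transposed weight directly, with no flatten/addmm round trip, while a rank-2 input maps onto a single Gemm.

```
import numpy as np

weight = np.random.randn(10, 10).astype(np.float32)  # [out_features, in_features], as in nn.Linear
bias = np.random.randn(10).astype(np.float32)

# Rank 2: equivalent to one Gemm node (y = x @ W.T + b, i.e. Gemm with transB=1).
x2d = np.random.randn(1, 10).astype(np.float32)
print((x2d @ weight.T + bias).shape)  # (1, 10)

# Rank 4: MatMul broadcasts over the leading batch dimensions,
# so Transpose + MatMul + Add suffices without any reshaping.
x4d = np.random.randn(1, 12, 15, 10).astype(np.float32)
print((x4d @ weight.T + bias).shape)  # (1, 12, 15, 10)
```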

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Jan 18, 2025
@justinchuby justinchuby enabled auto-merge (squash) January 18, 2025 01:02
justinchuby (Collaborator, Author) commented Jan 18, 2025

```
import torch


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


model = TestModule()
# Rank-2 input: expect a single Gemm node in the exported graph.
ep = torch.onnx.export(model, (torch.randn(1, 10),), dynamo=True, verify=True)
print(ep)
# Rank-4 input: expect Transpose + MatMul + Add instead of an addmm decomposition.
ep = torch.onnx.export(model, (torch.randn(1, 12, 15, 10),), dynamo=True, verify=True)
print(ep)
```

```
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Check the ONNX model...
[torch.onnx] Check the ONNX model... ✅
[torch.onnx] Execute the model with ONNX Runtime...
[torch.onnx] Execute the model with ONNX Runtime... ✅
[torch.onnx] Verify output accuracy...
[torch.onnx] Verify output accuracy... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
            producer_name='pytorch',
            producer_version='2.7.0.dev20250115+cu124',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,10]>
            ),
            outputs=(
                %"linear"<FLOAT,[1,10]>
            ),
            initializers=(
                %"linear.weight"<FLOAT,[10,10]>,
                %"linear.bias"<FLOAT,[10]>
            ),
        ) {
            0 |  # node_Gemm_0
                 %"linear"<FLOAT,[1,10]> ⬅️ ::Gemm(%"x", %"linear.weight", %"linear.bias") {beta=1.0, transB=True, alpha=1.0, transA=0}
            return %"linear"<FLOAT,[1,10]>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::Rank(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"return_val"<?,?> ⬅️ ::Size(%"tmp")
            return %"return_val"<?,?>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::IsScalar(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
            2 |  # n2
                 %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
            3 |  # n3
                 %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
            return %"return_val"<?,?>
        }
    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_linear_weight: "f32[10, 10]", p_linear_bias: "f32[10]", x: "f32[1, 10]"):
                     # File: /home/justinchu/anaconda3/envs/onnx/lib/python3.13/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
                    linear: "f32[1, 10]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
                    return (linear,)
            
        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None)])
        Range constraints: {}

)
```

```
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Check the ONNX model...
[torch.onnx] Check the ONNX model... ✅
[torch.onnx] Execute the model with ONNX Runtime...
[torch.onnx] Execute the model with ONNX Runtime... ✅
[torch.onnx] Verify output accuracy...
[torch.onnx] Verify output accuracy... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
            producer_name='pytorch',
            producer_version='2.7.0.dev20250115+cu124',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,12,15,10]>
            ),
            outputs=(
                %"linear"<FLOAT,[1,12,15,10]>
            ),
            initializers=(
                %"linear.weight"<FLOAT,[10,10]>,
                %"linear.bias"<FLOAT,[10]>
            ),
        ) {
            0 |  # node_Transpose_0
                 %"val_0"<?,?> ⬅️ ::Transpose(%"linear.weight") {perm=[1, 0]}
            1 |  # node_MatMul_1
                 %"val_1"<?,?> ⬅️ ::MatMul(%"x", %"val_0")
            2 |  # node_Add_2
                 %"linear"<FLOAT,[1,12,15,10]> ⬅️ ::Add(%"val_1", %"linear.bias")
            return %"linear"<FLOAT,[1,12,15,10]>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::Rank(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"return_val"<?,?> ⬅️ ::Size(%"tmp")
            return %"return_val"<?,?>
        }

        <
            opset_imports={'': 18},
        >
        def pkg.onnxscript.torch_lib.common::IsScalar(
            inputs=(
                %"input"<?,?>
            ),
            outputs=(
                %"return_val"<?,?>
            ),
        ) {
            0 |  # n0
                 %"tmp"<?,?> ⬅️ ::Shape(%"input")
            1 |  # n1
                 %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
            2 |  # n2
                 %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
            3 |  # n3
                 %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
            return %"return_val"<?,?>
        }
    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_linear_weight: "f32[10, 10]", p_linear_bias: "f32[10]", x: "f32[1, 12, 15, 10]"):
                     # File: /home/justinchu/anaconda3/envs/onnx/lib/python3.13/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
                    linear: "f32[1, 12, 15, 10]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
                    return (linear,)
            
        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None)])
        Range constraints: {}

)
```

codecov bot commented Jan 18, 2025

❌ 51 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 11814 | 51 | 11763 | 2454 |
Top failed test (by shortest run time):
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0212_test_cast_STRING_to_FLOAT
Stack Traces | 0.003s run time
```
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_cast_STRING_to_FLOAT'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_cast_STRING_to_FLOAT' (e=No module named 'tests.onnx_backend_test_code.test_cast_STRING_to_FLOAT') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_cast_STRING_to_FLOAT.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_cast_STRING_to_FLOAT.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, STRING
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_cast_STRING_to_FLOAT(input: STRING[3,4]) -> (FLOAT[3,4]):
E       output = opset21.Cast(input, to=1)
E       return output
```
Full list of 2 ❄️ flaky tests:
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0634_test_max_int32

Flake rate in main: 5.00% (Passed 38 times, Failed 2 times)

Stack Traces | 0.003s run time
```
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_max_int32'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_max_int32' (e=No module named 'tests.onnx_backend_test_code.test_max_int32') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_max_int32.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_max_int32.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT32
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_max_int32(data_0: INT32[3], data_1: INT32[3]) -> (INT32[3]):
E       result = opset13.Max(data_0, data_1)
E       return result
```
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1104_test_shape_start_1_end_2

Flake rate in main: 12.50% (Passed 14 times, Failed 2 times)

Stack Traces | 0.003s run time
```
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_shape_start_1_end_2'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_shape_start_1_end_2' (e=No module named 'tests.onnx_backend_test_code.test_shape_start_1_end_2') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_shape_start_1_end_2.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_shape_start_1_end_2.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_shape_start_1_end_2(x: FLOAT[3,4,5]) -> (INT64[1]):
E       y = opset21.Shape(x, end=2, start=1)
E       return y
```


titaiwangms (Contributor) commented
https://github.com/microsoft/onnxscript/pull/1821/files also disabled linear_bias though, do we want that back as well?

justinchuby (Collaborator, Author) commented
Dynamic axes work too:

```
# Reusing TestModule from the snippet above.
model = TestModule()
# Mark the batch dimension dynamic; the rank-2 case still lowers to Gemm.
ep = torch.onnx.export(model, (torch.randn(2, 10),), dynamic_shapes=(
    {0: torch.export.Dim.DYNAMIC},
), dynamo=True, verify=True)
print(ep)
# Rank-4 with a dynamic batch dimension still lowers to Transpose + MatMul + Add.
ep = torch.onnx.export(model, (torch.randn(2, 12, 15, 10),), dynamo=True, dynamic_shapes=(
    {0: torch.export.Dim.DYNAMIC},
), verify=True)
print(ep)
```

justinchuby (Collaborator, Author) commented
> #1821 (files) also disabled linear_bias though, do we want that back as well?

I combined the implementations.
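
A minimal sketch of what a combined handler can look like (hypothetical names and structure, not the literal torchlib source): rank-2 inputs take the Gemm path, everything else takes the MatMul path, and the optional bias is handled in both branches so no separate linear_bias variant is needed.

```
# Hypothetical sketch of a combined aten.linear handler; `op` stands in
# for an opset-18 op builder. Not the literal torchlib registration.
def aten_linear_sketch(op, input, weight, bias=None):
    if len(input.shape) == 2:
        if bias is not None:
            # One Gemm node: y = input @ weight.T + bias.
            return op.Gemm(input, weight, bias, transB=1)
        return op.Gemm(input, weight, transB=1)
    # Rank != 2: MatMul broadcasts over leading batch dimensions,
    # avoiding a decomposition to aten.addmm.
    output = op.MatMul(input, op.Transpose(weight, perm=[1, 0]))
    if bias is not None:
        output = op.Add(output, bias)
    return output
```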

@justinchuby justinchuby merged commit e7d199e into main Jan 18, 2025
22 of 27 checks passed
@justinchuby justinchuby deleted the justinchu/better-linear branch January 18, 2025 01:19
kunal-vaishnavi added a commit to microsoft/onnxruntime that referenced this pull request Jan 31, 2025
### Description
This PR adds fusions for [Google's SigLIP
model](https://huggingface.co/google/siglip-base-patch16-224/) and
Microsoft's internal conformer-encoder model.

Here is an example of how to run the ORT transformer optimizer for the
SigLIP model.
```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1152 --use_external_data_format --opt_level 0 --disable_shape_inference
```

Here is an example of how to run the ORT transformer optimizer for the
conformer-encoder model.
```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type conformer --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0 --disable_shape_inference --convert_attribute
```
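
The optimizer can also be invoked from Python. A minimal sketch of the SigLIP invocation above, assuming the `optimize_model` entry point that `optimizer.py` wraps (paths are placeholders; the `--disable_shape_inference` and `--use_external_data_format` flags map onto keyword arguments only approximately):

```
# Sketch: rough Python equivalent of the SigLIP CLI call above.
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "/path/to/model.onnx",  # placeholder input path
    model_type="clip",
    num_heads=16,
    hidden_size=1152,
    opt_level=0,
)
opt_model.save_model_to_file(
    "/path/to/model_opt.onnx",  # placeholder output path
    use_external_data_format=True,
)
```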

### Motivation and Context
This PR helps optimize multi-modal models that use SigLIP for the vision
encoder and conformer-encoder for the speech encoder.

This PR uses changes from the following PRs:
- pytorch/pytorch#144801
- microsoft/onnxscript#2018
- microsoft/onnxscript#2019
- microsoft/onnxscript#2020
- microsoft/onnxscript#2021
- microsoft/onnxscript#2022
- microsoft/onnxscript#2024
- microsoft/onnxscript#2025
- microsoft/onnxscript#2029
- microsoft/onnxscript#2033

### Introduction of ONNX Script

This PR introduces [ONNX
Script](https://github.com/microsoft/onnxscript) into the ORT
transformer optimizer as an optional step via the
`fold_transpose_initializers()` method of the `DynamoOnnxHelper` class.
sfatimar pushed a commit to intel/onnxruntime that referenced this pull request Feb 5, 2025
sfatimar pushed a commit to intel/onnxruntime that referenced this pull request Feb 5, 2025