[torchlib] Register aten.linear and use matmul to simplify graph #2021
Conversation
```python
import torch


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


model = TestModule()
ep = torch.onnx.export(model, (torch.randn(1, 10),), dynamo=True, verify=True)
print(ep)
ep = torch.onnx.export(model, (torch.randn(1, 12, 15, 10),), dynamo=True, verify=True)
print(ep)
```

[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Check the ONNX model...
[torch.onnx] Check the ONNX model... ✅
[torch.onnx] Execute the model with ONNX Runtime...
[torch.onnx] Execute the model with ONNX Runtime... ✅
[torch.onnx] Verify output accuracy...
[torch.onnx] Verify output accuracy... ✅
ONNXProgram(
model=
<
ir_version=10,
opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
producer_name='pytorch',
producer_version='2.7.0.dev20250115+cu124',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[1,10]>
),
outputs=(
%"linear"<FLOAT,[1,10]>
),
initializers=(
%"linear.weight"<FLOAT,[10,10]>,
%"linear.bias"<FLOAT,[10]>
),
) {
0 | # node_Gemm_0
%"linear"<FLOAT,[1,10]> ⬅️ ::Gemm(%"x", %"linear.weight", %"linear.bias") {beta=1.0, transB=True, alpha=1.0, transA=0}
return %"linear"<FLOAT,[1,10]>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::Rank(
inputs=(
%"input"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"tmp"<?,?> ⬅️ ::Shape(%"input")
1 | # n1
%"return_val"<?,?> ⬅️ ::Size(%"tmp")
return %"return_val"<?,?>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::IsScalar(
inputs=(
%"input"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"tmp"<?,?> ⬅️ ::Shape(%"input")
1 | # n1
%"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
2 | # n2
%"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
3 | # n3
%"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
return %"return_val"<?,?>
}
,
exported_program=
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_linear_weight: "f32[10, 10]", p_linear_bias: "f32[10]", x: "f32[1, 10]"):
# File: /home/justinchu/anaconda3/envs/onnx/lib/python3.13/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[1, 10]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
return (linear,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None)])
Range constraints: {}
)
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TestModule([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Check the ONNX model...
[torch.onnx] Check the ONNX model... ✅
[torch.onnx] Execute the model with ONNX Runtime...
[torch.onnx] Execute the model with ONNX Runtime... ✅
[torch.onnx] Verify output accuracy...
[torch.onnx] Verify output accuracy... ✅
ONNXProgram(
model=
<
ir_version=10,
opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
producer_name='pytorch',
producer_version='2.7.0.dev20250115+cu124',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[1,12,15,10]>
),
outputs=(
%"linear"<FLOAT,[1,12,15,10]>
),
initializers=(
%"linear.weight"<FLOAT,[10,10]>,
%"linear.bias"<FLOAT,[10]>
),
) {
0 | # node_Transpose_0
%"val_0"<?,?> ⬅️ ::Transpose(%"linear.weight") {perm=[1, 0]}
1 | # node_MatMul_1
%"val_1"<?,?> ⬅️ ::MatMul(%"x", %"val_0")
2 | # node_Add_2
%"linear"<FLOAT,[1,12,15,10]> ⬅️ ::Add(%"val_1", %"linear.bias")
return %"linear"<FLOAT,[1,12,15,10]>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::Rank(
inputs=(
%"input"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"tmp"<?,?> ⬅️ ::Shape(%"input")
1 | # n1
%"return_val"<?,?> ⬅️ ::Size(%"tmp")
return %"return_val"<?,?>
}
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::IsScalar(
inputs=(
%"input"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"tmp"<?,?> ⬅️ ::Shape(%"input")
1 | # n1
%"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
2 | # n2
%"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
3 | # n3
%"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
return %"return_val"<?,?>
}
,
exported_program=
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_linear_weight: "f32[10, 10]", p_linear_bias: "f32[10]", x: "f32[1, 12, 15, 10]"):
# File: /home/justinchu/anaconda3/envs/onnx/lib/python3.13/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[1, 12, 15, 10]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
return (linear,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None)])
Range constraints: {}
)
❌ 51 Tests Failed (2 ❄️ flaky tests also reported).
https://github.com/microsoft/onnxscript/pull/1821/files also disabled linear_bias, though. Do we want that back as well?
Dynamic axes work too:

```python
model = TestModule()
ep = torch.onnx.export(model, (torch.randn(2, 10),), dynamic_shapes=(
    {0: torch.export.Dim.DYNAMIC},
), dynamo=True, verify=True)
print(ep)
ep = torch.onnx.export(model, (torch.randn(2, 12, 15, 10),), dynamo=True, dynamic_shapes=(
    {0: torch.export.Dim.DYNAMIC},
), verify=True)
print(ep)
```
I combined the implementations.
### Description

This PR adds fusions for [Google's SigLIP model](https://huggingface.co/google/siglip-base-patch16-224/) and Microsoft's internal conformer-encoder model.

Here is an example of how to run the ORT transformer optimizer for the SigLIP model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1152 --use_external_data_format --opt_level 0 --disable_shape_inference
```

Here is an example of how to run the ORT transformer optimizer for the conformer-encoder model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type conformer --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0 --disable_shape_inference --convert_attribute
```

### Motivation and Context

This PR helps optimize multi-modal models that use SigLIP for the vision encoder and conformer-encoder for the speech encoder.

This PR uses changes from the following PRs:
- pytorch/pytorch#144801
- microsoft/onnxscript#2018
- microsoft/onnxscript#2019
- microsoft/onnxscript#2020
- microsoft/onnxscript#2021
- microsoft/onnxscript#2022
- microsoft/onnxscript#2024
- microsoft/onnxscript#2025
- microsoft/onnxscript#2029
- microsoft/onnxscript#2033

### Introduction of ONNX Script

This PR introduces [ONNX Script](https://github.com/microsoft/onnxscript) into the ORT transformer optimizer as an optional step via the `fold_transpose_initializers()` method of the `DynamoOnnxHelper` class.
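The `fold_transpose_initializers()` method is only named above, so as a rough idea of what such a pass does, here is a conceptual sketch; the function name, signature, and behavior below are assumptions about the technique, not the actual ORT source. The idea is to pre-transpose each initializer that feeds a Transpose node and delete the node, which removes exactly the `Transpose(linear.weight)` nodes produced by the MatMul path in this PR.

```python
# Conceptual sketch of folding Transpose(initializer) into the weight itself.
# An assumption about what such a pass does, not the actual ORT method.
import numpy as np
import onnx
from onnx import numpy_helper


def fold_transpose_initializers_sketch(model: onnx.ModelProto) -> None:
    inits = {init.name: init for init in model.graph.initializer}
    for node in list(model.graph.node):
        if node.op_type != "Transpose" or node.input[0] not in inits:
            continue
        perm = next((list(a.ints) for a in node.attribute if a.name == "perm"), None)
        array = numpy_helper.to_array(inits[node.input[0]])
        # np.transpose with perm=None reverses the axes, matching the
        # default behavior of the ONNX Transpose op.
        folded = np.transpose(array, perm)
        # Register the pre-transposed weight under the Transpose output name
        # and drop the now-redundant node (the original initializer is left
        # in place in case other nodes still consume it).
        model.graph.initializer.append(numpy_helper.from_array(folded, node.output[0]))
        model.graph.node.remove(node)
```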
### Description This PR adds fusions for [Google's SigLIP model](https://huggingface.co/google/siglip-base-patch16-224/) and Microsoft's internal conformer-encoder model. Here is an example of how to run the ORT transformer optimizer for the SigLIP model. ``` $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers $ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1152 --use_external_data_format --opt_level 0 --disable_shape_inference ``` Here is an example of how to run the ORT transformer optimizer for the conformer-encoder model. ``` $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers $ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type conformer --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0 --disable_shape_inference --convert_attribute ``` ### Motivation and Context This PR helps optimize multi-modal models that use SigLIP for the vision encoder and conformer-encoder for the speech encoder. This PR uses changes from the following PRs: - pytorch/pytorch#144801 - microsoft/onnxscript#2018 - microsoft/onnxscript#2019 - microsoft/onnxscript#2020 - microsoft/onnxscript#2021 - microsoft/onnxscript#2022 - microsoft/onnxscript#2024 - microsoft/onnxscript#2025 - microsoft/onnxscript#2029 - microsoft/onnxscript#2033 ### Introduction of ONNX Script This PR introduces [ONNX Script](https://github.com/microsoft/onnxscript) into the ORT transformer optimizer as an optional step via the `fold_transpose_initializers()` method of the `DynamoOnnxHelper` class.
### Description This PR adds fusions for [Google's SigLIP model](https://huggingface.co/google/siglip-base-patch16-224/) and Microsoft's internal conformer-encoder model. Here is an example of how to run the ORT transformer optimizer for the SigLIP model. ``` $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers $ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1152 --use_external_data_format --opt_level 0 --disable_shape_inference ``` Here is an example of how to run the ORT transformer optimizer for the conformer-encoder model. ``` $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers $ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type conformer --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0 --disable_shape_inference --convert_attribute ``` ### Motivation and Context This PR helps optimize multi-modal models that use SigLIP for the vision encoder and conformer-encoder for the speech encoder. This PR uses changes from the following PRs: - pytorch/pytorch#144801 - microsoft/onnxscript#2018 - microsoft/onnxscript#2019 - microsoft/onnxscript#2020 - microsoft/onnxscript#2021 - microsoft/onnxscript#2022 - microsoft/onnxscript#2024 - microsoft/onnxscript#2025 - microsoft/onnxscript#2029 - microsoft/onnxscript#2033 ### Introduction of ONNX Script This PR introduces [ONNX Script](https://github.com/microsoft/onnxscript) into the ORT transformer optimizer as an optional step via the `fold_transpose_initializers()` method of the `DynamoOnnxHelper` class.
Use MatMul when the input is not rank 2, to avoid decomposition to addmm.
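A minimal sketch of that dispatch, for illustration only (the function name `linear_sketch` and the eager use of `opset18` are assumptions, not the actual torchlib source): emit a single Gemm when the input is rank 2, since Gemm accepts only 2-D inputs, and fall back to Transpose + MatMul + Add otherwise, because MatMul broadcasts over leading batch dimensions without any reshape.

```python
# Illustrative sketch only -- not the actual torchlib implementation.
from onnxscript import opset18 as op


def linear_sketch(input, weight, bias):
    if len(input.shape) == 2:
        # Rank-2 input: a single Gemm suffices; transB=1 folds the weight
        # transpose into the op (matching the exported Gemm above).
        return op.Gemm(input, weight, bias, transB=1)
    # Higher-rank input: MatMul broadcasts over leading dimensions, so no
    # flattening to 2-D (and hence no decomposition to addmm) is needed.
    result = op.MatMul(input, op.Transpose(weight, perm=[1, 0]))
    return op.Add(result, bias)
```

This mirrors the two exported graphs above: a single Gemm node for the `[1, 10]` input, and Transpose + MatMul + Add for the `[1, 12, 15, 10]` input.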