Skip to content

Commit

Permalink
gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jan 18, 2025
1 parent de19f32 commit 61a474e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def aten_linear(input: TFloat, weight: TFloat, bias: TFloat | None = None) -> TF

if len(input.shape) == 2:
# Use Gemm for the rank 2 input
return op.Gemm(input, weight, transB=True)
return op.Gemm(input, weight, bias, transB=True)
weight_transposed = op.Transpose(weight, perm=[1, 0])
mul = op.MatMul(input, weight_transposed)
if bias is None:
Expand Down
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,7 @@ def _where_input_wrangler(
tolerance={torch.float16: (8e-2, 1e-4)},
),
TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear),
TorchLibOpInfo(
"nn.functional.unfold",
nn_ops.aten_im2col,
Expand Down

0 comments on commit 61a474e

Please sign in to comment.