Skip to content

Commit

Permalink
Merge with main
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Jan 16, 2025
1 parent e63ec29 commit 5e4c1c2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions onnxscript/rewriter/onnxruntime/xformers/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,16 @@ def rewrite(
)


_rule1 = MultiHeadAttention.rule("MHA_2dmm", use_2d_matmul=True)
_rule1 = MultiHeadAttention.rule("MHA_2dmm", use_2d_matmul=False)

mha_rules = pattern.RewriteRuleSet([_rule1])

debug: bool = True

def fuse_mha(model: ir.Model) -> int:
count = mha_rules.apply_to_model(model)
print(f"MHA count: {count}")
if count == 0 and debug:
mha_rules.apply_to_model(model, debug=True)
else:
print(f"MHA count: {count}")
return count

0 comments on commit 5e4c1c2

Please sign in to comment.