Skip to content

Commit

Permalink
Add amax partitioner constraints (pytorch#1064)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1064

As title

Reviewed By: mcr229

Differential Revision: D50195550

fbshipit-source-id: 62c77ca5a7a2d52b308bf1cbf930edd8529c7b96
  • Loading branch information
digantdesai authored and facebook-github-bot committed Oct 23, 2023
1 parent c27a59a commit f4d9c4a
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions backends/xnnpack/partition/xnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,17 @@ def slice_copy(node: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa

return True

@_constraint(exir_ops.edge.aten.amax.default)
def amax(node: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa
"""
A: Only with keep_dim == True
B: Only support with dim == 2 or dim == 3
valid iff, A && B
"""
is_keep_dim = (len(node.args) == 3) and (cast(bool, node.args[3]) is True)
dim_arg_val = cast(int, node.args[1])
return is_keep_dim and (dim_arg_val == 2 or dim_arg_val == 3)


class XnnpackFloatingPointPartitioner(Partitioner):
"""
Expand Down

0 comments on commit f4d9c4a

Please sign in to comment.