Skip to content

Commit

Permalink
Move arm.passes to arm._passes (pytorch#5918)
Browse files Browse the repository at this point in the history
Summary:
Changing arm.passes to arm._passes to indicate that these passes are not covered under the API stability guarantee.

Pull Request resolved: pytorch#5918

Reviewed By: malfet, helunwencser

Differential Revision: D63926055

fbshipit-source-id: 141a5be9f3a81e75784825357bacbab91904620c
(cherry picked from commit 83c95df)
  • Loading branch information
tarun292 committed Oct 10, 2024
1 parent 40358fa commit 14b594f
Show file tree
Hide file tree
Showing 17 changed files with 226 additions and 9 deletions.
4 changes: 2 additions & 2 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ python_library(
typing = True,
deps = [
":arm_backend",
"//executorch/backends/arm/passes:passes",
"//executorch/backends/arm/_passes:passes",
"//executorch/exir:lib",
],
)
Expand All @@ -27,7 +27,7 @@ python_library(
":arm_vela",
"//executorch/backends/arm/operators:lib",
"//executorch/backends/arm/operators:node_visitor",
"//executorch/backends/arm/passes:passes",
"//executorch/backends/arm/_passes:passes",
],
)

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
# pyre-unsafe

import torch
from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import (
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm.passes.convert_split_to_slice import (
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
Expand Down
66 changes: 66 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from torch._ops import OpOverload


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload,
args: tuple = (),
kwargs: Optional[dict] = None,
quantize: bool = False,
q_params: Optional[tuple] = None,
):
"""
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
If quantize is true and q_params is not None, a q dq pair is inserted after the newly created node.
"""

node = graph.create_node(
"call_function",
op_target,
args=args,
kwargs=kwargs or {},
)
if quantize and q_params:
return insert_q_dq_pair(graph, node, q_params)
return node


def insert_q_dq_pair(
graph: torch.fx.Graph,
anchor: torch.fx.Node,
q_params: tuple,
):
"""
Inserts a q dq node pair after the node 'anchor'.
"""

with graph.inserting_after(anchor):
q = create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(), # We add the argument last
)
q.meta = anchor.meta
with graph.inserting_after(q):
dq = create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q,) + q_params,
)
dq.meta = q.meta
anchor.replace_all_uses_with(dq)
# We add this last so the replace all uses above does not replace the quantized
# node's first use
q.args = (anchor,) + q_params
return dq
35 changes: 35 additions & 0 deletions backends/arm/_passes/cast_int64_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CastInt64ToInt32Pass(ExportPass):
def __init__(self, exported_program: torch.export.ExportedProgram):
super(CastInt64ToInt32Pass, self).__init__()
self.exported_program = exported_program

def _to_int32(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
fake_tensor = node.meta["val"]
if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
if node.meta["val"].dtype == torch.int64:
node.meta["val"] = node.meta["val"].to(torch.int32)
buffer_name = (
self.exported_program.graph_signature.inputs_to_buffers[
node.name
]
)
new_tensor = self.exported_program.state_dict[buffer_name].to(
torch.int32
)
self.exported_program.state_dict[buffer_name] = new_tensor

def call(self, graph_module: torch.fx.GraphModule):
self._to_int32(graph_module)
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
45 changes: 45 additions & 0 deletions backends/arm/_passes/decompose_div_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_div_decomposition(op) -> tuple:
"""
Returns the the (reciprocal_op, mul_op), where the ops depends on if
the div op is in exir_ops torch.ops.aten.
"""
if op == exir_ops.edge.aten.div.Tensor:
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
if op == torch.ops.aten.div.Tensor:
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
raise RuntimeError(f"Can't get div decomposition for op {op}")


class DecomposeDivPass(ExportPass):
"""
This pass decomposes div into a mul and a reciprocal node.
Example:
y = div(a,b)
Becomes:
x = reciprocal(b)
y = mul(a,x)
"""

def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
return super().call_operator(op, args, kwargs, meta)

reciprocal_op, mul_op = get_div_decomposition(op)

numerator = args[0]
denominator = args[1]
reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)

return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
File renamed without changes.
69 changes: 69 additions & 0 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Union

import torch
from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.fx import GraphModule, Node


class ScalarsToAttributePass(ExportPass):
"""
For ops in 'targeted_ops', convert inputs that are scalar values
to attribute Nodes that output the same value.
"""

targeted_ops = [
torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.div.Tensor,
]

def call(self, graph_module: GraphModule) -> PassResult:
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.op != "call_function" or n.target not in self.targeted_ops:
continue

biggest_rank = 1
for arg in n.args:
if isinstance(arg, Node):
_, shape, _ = extract_tensor_meta(arg.meta)
biggest_rank = max(biggest_rank, len(shape))

new_args = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue

prefix = "_tensor_constant_"
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(graph_module)
float_tensor = torch.tensor(
float(cast(Union[int, float], arg))
).reshape((1,) * biggest_rank)
graph_module.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode

with graph_module.graph.inserting_before(n):
get_attr_node = graph_module.graph.create_node(
"get_attr", tensor_constant_name, (), {}
)
get_attr_node.meta["val"] = fake_mode.from_tensor(
float_tensor, static_shapes=True
)
new_args.append(get_attr_node)
n.args = tuple(new_args)

graph_module.recompile()
return PassResult(graph_module, True)
File renamed without changes.
4 changes: 3 additions & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_output import process_output
from executorch.backends.arm.operators.op_placeholder import process_placeholder
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm._passes.arm_pass_manager import (
ArmPassManager,
) # usort: skip
from executorch.backends.arm.tosa_utils import (
dbg_fail,
dbg_tosa_dump,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from typing import final, List

import torch
from executorch.backends.arm.arm_backend import ArmBackend
from executorch.backends.arm.passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/passes/test_meandim_to_averagepool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest

import torch
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)

Expand Down

0 comments on commit 14b594f

Please sign in to comment.