Skip to content

Commit

Permalink
add unit test for op_add (pytorch#7087)
Browse files Browse the repository at this point in the history
add op_add shapes to generate as binaries (pytorch#7087)

Summary:
generates the add model pte’s for cadence to execute on. will use graph builder in later diffs

Test Plan:
Imported from GitHub, without a `Test Plan:` line.
{F1968254537}

Reviewed By: hsharma35

Differential Revision: D66510372

Pulled By: zonglinpeng
  • Loading branch information
zonglinpeng authored Nov 27, 2024
1 parent dedf77b commit 5785fc3
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 7 deletions.
20 changes: 20 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ python_library(
],
)

python_library(
name = "export_example",
srcs = [
"export_example.py",
],
deps = [
":passes",
":utils",
":ops_registrations",
":replace_ops",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/runtime:runtime",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/backends/transforms:decompose_sdpa",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/exir:lib",
"//executorch/devtools:lib",
],
)

python_library(
name = "pass_utils",
Expand Down
14 changes: 8 additions & 6 deletions backends/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def export_model(
model: nn.Module,
example_inputs: Tuple[Any, ...],
file_name: str = "CadenceDemoModel",
run_and_compare: bool = True,
):
# create work directory for outputs and model binary
working_dir = tempfile.mkdtemp(dir="/tmp")
Expand Down Expand Up @@ -112,9 +113,10 @@ def export_model(
)

# TODO: move to test infra
runtime.run_and_compare(
executorch_prog=exec_prog,
inputs=example_inputs,
ref_outputs=ref_outputs,
working_dir=working_dir,
)
if run_and_compare:
runtime.run_and_compare(
executorch_prog=exec_prog,
inputs=example_inputs,
ref_outputs=ref_outputs,
working_dir=working_dir,
)
3 changes: 2 additions & 1 deletion backends/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def print_ops_info(

# Print the final ops and their counts in a tabular format
logging.info(
tabulate(
"\n"
+ tabulate(
sorted_ops_count,
headers=[
"Final Operators ", # one character longer than the longest op name
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ python_library(
srcs = [
"__init__.py",
"executor.py",
"runtime.py",
"utils.py"
] + glob([
"xtsc-cfg/**/*",
]),
Expand Down
26 changes: 26 additions & 0 deletions examples/cadence/operators/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("odai_jarvis")


python_unittest(
name = "test_add_op",
srcs = [
"test_add_op.py",
],
typing = True,
supports_static_listing = False,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:export_example",
"//executorch/backends/cadence/aot:compiler",
],
)
115 changes: 115 additions & 0 deletions examples/cadence/operators/test_add_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import unittest
from typing import Tuple

from parameterized import parameterized

from executorch.backends.cadence.aot.ops_registrations import * # noqa

import torch
import torch.nn as nn
from executorch.backends.cadence.aot.export_example import export_model


class ATenOpTestCases(unittest.TestCase):
@parameterized.expand(
[
[(7, 5, 6), (7, 5, 6)],
[(7, 5, 6), (1)],
[(1), (7, 5, 6)],
[(1), (7, 5, 6), 2.23],
[(1), (7, 5, 6), -1.0],
[(1), (7, 5, 6), -2.23],
[(7, 5, 6), (7, 5, 6), 1.23],
[(6, 7), (6, 7)],
[(6, 7), (6, 7), 2],
# Broadcast tests (should be optimized on G3)
[(1, 32, 64), (1, 1, 64)],
[(1, 32, 64), (64)],
[(1, 1, 32), (32)],
[(16, 1, 16), (1, 1, 16)],
[(16, 1, 16), (16)],
[(1, 4, 8, 8), (1, 1, 8, 8)],
[(1, 4, 8, 8), (8, 8)],
# Broadcast tests (should go to portable ops)
[(1, 10, 1, 8), (4, 1, 4, 1)],
[(1, 1, 16), (1, 8, 1), 2.5],
# # aten.upsample_nearest2d tests
[(5, 6, 6, 8), (5, 6, 6, 8)],
[(1, 1, 12, 16), (1, 1, 12, 16)],
]
)
def test_aten_add_out(
self, Xshape: Tuple[int], Yshape: Tuple[int], alpha: float = 1
) -> None:
class AddTensor(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.add(x, y, alpha=self.alpha)

model = AddTensor(alpha)

X = torch.randn(Xshape)
Y = torch.randn(Yshape)

model.eval()
export_model(
model, (X, Y), file_name=self._testMethodName, run_and_compare=False
)

@parameterized.expand(
[
[(7, 5, 6), (7, 5, 6)],
[(7, 5, 6), (1)],
[(1), (7, 5, 6)],
[(1), (7, 5, 6), 2.23],
[(1), (7, 5, 6), -1.0],
[(1), (7, 5, 6), -2.23],
[(7, 5, 6), (7, 5, 6), 1.23],
[(6, 7), (6, 7)],
[(6, 7), (6, 7), 2],
# Broadcast tests (should be optimized on G3)
[(1, 32, 64), (1, 1, 64)],
[(1, 32, 64), (64)],
[(1, 1, 32), (32)],
[(16, 1, 16), (1, 1, 16)],
[(16, 1, 16), (16)],
[(1, 4, 8, 8), (1, 1, 8, 8)],
[(1, 4, 8, 8), (8, 8)],
# Broadcast tests (should go to portable ops)
[(1, 10, 1, 8), (4, 1, 4, 1)],
[(1, 1, 16), (1, 8, 1), 2.5],
# # aten.upsample_nearest2d tests
[(5, 6, 6, 8), (5, 6, 6, 8)],
[(1, 1, 12, 16), (1, 1, 12, 16)],
]
)
def test_aten_add_scalar_out(
self, Xshape: Tuple[int], Yshape: Tuple[int], alpha: float = 1
) -> None:
# Tensor-Scalar addition
class AddScalar(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: float):
return torch.add(x, y, alpha=self.alpha)

model = AddScalar(alpha)

X = torch.randn(Xshape)
Y = 2.34

model.eval()
export_model(
model, (X, Y), file_name=self._testMethodName, run_and_compare=False
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5785fc3

Please sign in to comment.