diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index 256dfe5a77..1b853c60c5 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -14,6 +14,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -43,6 +44,7 @@
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,
@@ -71,6 +73,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -100,6 +103,7 @@
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,
diff --git a/backends/qualcomm/builders/op_cos.py b/backends/qualcomm/builders/op_cos.py
new file mode 100644
index 0000000000..98caed10d1
--- /dev/null
+++ b/backends/qualcomm/builders/op_cos.py
@@ -0,0 +1,56 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseCos, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Cos(NodeVisitor):
+    target = ["aten.cos.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
+        cos_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseCos.op_name,
+        )
+        cos_op.AddInputTensors([input_tensor_wrapper])
+        cos_op.AddOutputTensors([output_tensor_wrapper])
+
+        return cos_op
diff --git a/backends/qualcomm/builders/op_sin.py b/backends/qualcomm/builders/op_sin.py
new file mode 100644
index 0000000000..40e466f59e
--- /dev/null
+++ b/backends/qualcomm/builders/op_sin.py
@@ -0,0 +1,56 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseSin, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Sin(NodeVisitor):
+    target = ["aten.sin.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
+        sin_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseSin.op_name,
+        )
+        sin_op.AddInputTensors([input_tensor_wrapper])
+        sin_op.AddOutputTensors([output_tensor_wrapper])
+
+        return sin_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index ebcae4a515..cb48cf38ba 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -85,6 +85,11 @@ class OpElementWiseCeil:
     op_name = "ElementWiseCeil"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseCos:
+    op_name: str = "ElementWiseCos"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseDivide:
     op_name: str = "ElementWiseDivide"
@@ -113,6 +118,11 @@ class OpElementWiseRsqrt:
     op_name: str = "ElementWiseRsqrt"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseSin:
+    op_name: str = "ElementWiseSin"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseSubtract:
     op_name = "ElementWiseSubtract"
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 275da567e8..68d512a4e0 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -271,6 +271,16 @@ def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.cos.default])
+def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
+@register_annotator([torch.ops.aten.sin.default])
+def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.tanh.default])
 def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 0ed66329c3..7bc8ae7417 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -427,6 +427,14 @@ def forward(self, x):
         return topk_values
 
 
+class Cos(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.cos(x)
+
+
 class Div(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -889,6 +897,14 @@ def forward(self, x):
         return torch.sigmoid(x)
 
 
+class Sin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.sin(x)
+
+
 class SimpleModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 10917cdd6b..99b16811b2 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -143,6 +143,11 @@ def test_qnn_backend_conv_transpose2d(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_cos(self):
+        module = Cos()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405
         x = torch.randn(5)
@@ -465,6 +470,11 @@ def test_qnn_backend_sigmoid(self):
         sample_input = (torch.randn([1, 3, 3, 3]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_sin(self):
+        module = Sin()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_select_copy(self):
         module = SelectCopy()  # noqa: F405
         sample_input = (torch.randn([1, 3, 3, 3]),)
@@ -825,6 +835,12 @@ def test_qnn_backend_conv_transpose2d(self):
                 module = self.get_qdq_module(module, sample_input)
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_cos(self):
+        module = Cos()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405
         x = torch.randn(5)
@@ -1201,6 +1217,12 @@ def test_qnn_backend_sigmoid(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_sin(self):
+        module = Sin()  # noqa: F405
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_slice_copy(self):
         modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
         sample_input = (
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index dc517764f8..38d7c9a64e 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -256,6 +256,28 @@ def build_executorch_binary(
     custom_pass_config=frozenset(),
     qat_training_data=None,
 ):
+    """
+    Generate an ExecuTorch binary for Qualcomm platforms.
+
+    Args:
+        model (torch.nn.Module): The model to be converted into an ExecuTorch binary.
+        inputs (Tuple[torch.Tensor]): Sample input tensors required for model export.
+        soc_model (QcomChipset): The target Qualcomm System on Chip (SoC) model.
+        file_name (str): Name for the output binary file (.pte).
+        dataset (List[torch.Tensor] | Callable): A dataset for quantization calibration.
+        skip_node_id_set (set, optional): Set of node IDs to be skipped during partitioning.
+        skip_node_op_set (set, optional): Set of operator names to be skipped during partitioning.
+        quant_dtype (QuantDtype, optional): Data type for quantization.
+        custom_quantizer (Callable, optional): Custom quantizer to use in place of the default one.
+        shared_buffer (bool, optional): Enables the zero-copy shared-buffer mechanism to optimize runtime memory allocation.
+        metadata (dict, optional): A dictionary that maps each method name to a constant value in eager mode.
+        dump_intermediate_outputs (bool, optional): Enables dumping of the model's intermediate outputs.
+        custom_pass_config (frozenset, optional): Set of custom passes for model processing.
+        qat_training_data (List[torch.Tensor], optional): A dataset used to calibrate the model during quantization-aware training.
+
+    Returns:
+        None: The function writes the output to the specified .pte file.
+    """
     if quant_dtype is not None:
         captured_model = torch.export.export(model, inputs).module()
         if qat_training_data:
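
For reference, a minimal sketch (not part of this patch) of how the new sin/cos support could be exercised end to end through `build_executorch_binary`, following the same model shapes as the new tests. The `QcomChipset` import path and the `SM8550` member are assumptions; adjust them to your checkout and target device.

```python
import torch

# Assumed import locations; adjust to your checkout.
from executorch.backends.qualcomm.utils.utils import QcomChipset
from executorch.examples.qualcomm.utils import build_executorch_binary


class SinCos(torch.nn.Module):
    def forward(self, x):
        # Lowers to the new ElementWiseSin / ElementWiseCos QNN ops
        # (plus ElementWiseAdd, which the backend already supports).
        return torch.sin(x) + torch.cos(x)


sample_input = (torch.randn(2, 5, 1, 3),)

# Writes sincos_qnn.pte. `dataset` supplies calibration samples; since no
# quant_dtype is passed here, the model is lowered in floating point.
build_executorch_binary(
    SinCos(),
    sample_input,
    QcomChipset.SM8550,  # hypothetical target SoC
    "sincos_qnn",
    [sample_input],
)
```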