forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Qualcomm AI Engine Direct - ESRGAN Enablement (pytorch#3751)
Summary: - OSS model enablement: esgan (https://github.com/ai-forever/Real-ESRGAN) - support upsample_nearest2d / pixel_unshuffle / leaky_relu / prelu - test cases for newly added operators Pull Request resolved: pytorch#3751 Reviewed By: kirklandsign Differential Revision: D57896517 Pulled By: cccclai fbshipit-source-id: 12e7911ef28dbb15604ada4fe5bc6f848b819eab
- Loading branch information
1 parent
d9194d1
commit 8f65b6d
Showing
17 changed files
with
943 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright (c) Qualcomm Innovation Center, Inc. | ||
# All rights reserved | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import Dict | ||
|
||
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper | ||
|
||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
from .node_visitor import get_parameter, NodeVisitor, register_node_visitor | ||
from .qnn_constants import OpPRelu, QNN_OP_PACKAGE_NAME_QTI_AISW | ||
|
||
|
||
@register_node_visitor | ||
class PReLU(NodeVisitor): | ||
target = ["aten.leaky_relu.default", "aten.prelu.default"] | ||
|
||
def __init__(self, *args) -> None: | ||
super().__init__(*args) | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], | ||
) -> PyQnnWrapper.PyQnnOpWrapper: | ||
input_node = node.args[0] | ||
input_tensor = self.get_tensor(input_node, node) | ||
prelu_inp_tensor_wrapper = self.define_tensor( | ||
input_node, | ||
input_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, | ||
nodes_to_wrappers, | ||
is_input_tensor=True, | ||
) | ||
|
||
if node.target.__name__ == "aten.leaky_relu.default": | ||
coeff = 1e-2 if len(node.args) < 2 else node.args[1] | ||
coeff_tensor = torch.full(input_tensor.shape, coeff).to(torch.float32) | ||
else: | ||
coeff_node = node.args[1] | ||
coeff_tensor = torch.zeros(input_node.meta["val"].shape) | ||
coeff = get_parameter(coeff_node, self.edge_program) | ||
# per-channel activation | ||
if coeff_node.meta["val"].shape[0] > 1: | ||
for i in range(input_node.meta["val"].shape[1]): | ||
coeff_tensor = coeff_tensor.index_fill( | ||
1, torch.tensor([i]), coeff[i] | ||
) | ||
if "axis_order" in input_node.meta: | ||
axis_order = input_node.meta["axis_order"] | ||
coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous() | ||
# simple min-max quantization | ||
coeff = torch.max(coeff).item() | ||
else: | ||
coeff = coeff.item() | ||
coeff_tensor = torch.full(input_tensor.shape, coeff).to(torch.float32) | ||
|
||
# 'graph', 'name', 'op', 'target', 'args', and 'kwargs' | ||
scalar_node = torch.fx.Node( | ||
node.graph, | ||
node.name + "_runtime_scalar", | ||
"call_function", | ||
exir_ops.edge.aten.full.default, | ||
(), # args | ||
{}, # kwargs | ||
) | ||
if pow_quant_attrs := node.meta.get("quant_attrs"): | ||
quant_attrs = pow_quant_attrs.copy() | ||
quant_range = quant_attrs["quant_max"] - quant_attrs["quant_min"] | ||
# coeff is guaranteed to be positive | ||
quant_attrs["zero_point"] = 0 | ||
quant_attrs["scale"] = coeff / quant_range | ||
scalar_node.meta["quant_attrs"] = quant_attrs | ||
|
||
scalar_tensor_wrapper = self.define_tensor( | ||
scalar_node, | ||
coeff_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, | ||
nodes_to_wrappers, | ||
is_input_tensor=True, | ||
) | ||
prelu_input_tensors = [prelu_inp_tensor_wrapper, scalar_tensor_wrapper] | ||
|
||
output_tensor = self.get_tensor(node, node) | ||
output_tensor_wrapper = self.define_tensor( | ||
node, | ||
output_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, | ||
nodes_to_wrappers, | ||
is_input_tensor=False, | ||
) | ||
prelu_output_tensors = [output_tensor_wrapper] | ||
|
||
prelu_op = PyQnnWrapper.PyQnnOpWrapper( | ||
node.name, | ||
QNN_OP_PACKAGE_NAME_QTI_AISW, | ||
OpPRelu.op_name, | ||
) | ||
prelu_op.AddInputTensors(prelu_input_tensors) | ||
prelu_op.AddOutputTensors(prelu_output_tensors) | ||
|
||
return prelu_op |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) Qualcomm Innovation Center, Inc. | ||
# All rights reserved | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Dict | ||
|
||
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from .node_visitor import NodeVisitor, register_node_visitor | ||
from .qnn_constants import OpSpaceToDepth, QNN_OP_PACKAGE_NAME_QTI_AISW | ||
|
||
|
||
@register_node_visitor | ||
class SpaceToDepthVisitor(NodeVisitor): | ||
target = ["aten.pixel_unshuffle.default"] | ||
|
||
def __init__(self, *args) -> None: | ||
super().__init__(*args) | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], | ||
) -> PyQnnWrapper.PyQnnOpWrapper: | ||
input_node = node.args[0] | ||
input_tensor = self.get_tensor(input_node, node) | ||
input_tensor_wrapper = self.define_tensor( | ||
input_node, | ||
input_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, | ||
nodes_to_wrappers, | ||
is_input_tensor=True, | ||
) | ||
|
||
output_tensor = self.get_tensor(node, node) | ||
output_tensor_wrapper = self.define_tensor( | ||
node, | ||
output_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, | ||
nodes_to_wrappers, | ||
is_input_tensor=False, | ||
) | ||
|
||
block_size = [] | ||
for index in range(1, 3): | ||
block_size.append(input_tensor.shape[index] / output_tensor.shape[index]) | ||
block_size = np.array(block_size, dtype=np.uint32) | ||
block_size_shape = [2] | ||
|
||
space_to_depth_op = PyQnnWrapper.PyQnnOpWrapper( | ||
node.name, | ||
QNN_OP_PACKAGE_NAME_QTI_AISW, | ||
OpSpaceToDepth.op_name, | ||
) | ||
space_to_depth_op.AddInputTensors([input_tensor_wrapper]) | ||
space_to_depth_op.AddOutputTensors([output_tensor_wrapper]) | ||
space_to_depth_op.AddTensorParam( | ||
OpSpaceToDepth.param_block_size, | ||
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, | ||
len(block_size.shape), | ||
block_size_shape, | ||
block_size, | ||
True, | ||
) | ||
space_to_depth_op.AddScalarParam( | ||
OpSpaceToDepth.param_mode, | ||
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, | ||
{"data": np.uint32(OpSpaceToDepth.Mode.CRD)}, | ||
) | ||
|
||
return space_to_depth_op |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) Qualcomm Innovation Center, Inc. | ||
# All rights reserved | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import Dict | ||
|
||
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper | ||
|
||
import torch | ||
|
||
from .node_visitor import NodeVisitor, register_node_visitor | ||
from .qnn_constants import OpResizeNearestNeighbor, QNN_OP_PACKAGE_NAME_QTI_AISW | ||
|
||
|
||
@register_node_visitor | ||
class ResizeBilinear(NodeVisitor): | ||
target = ["aten.upsample_nearest2d.default"] | ||
|
||
def __init__(self, *args) -> None: | ||
super().__init__(*args) | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], | ||
) -> PyQnnWrapper.PyQnnOpWrapper: | ||
input_node = node.args[0] | ||
input_tensor = self.get_tensor(input_node, node) | ||
input_tensor_wrapper = self.define_tensor( | ||
input_node, | ||
input_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, | ||
nodes_to_wrappers, | ||
is_input_tensor=True, | ||
) | ||
|
||
output_tensor = self.get_tensor(node, node) | ||
output_tensor_wrapper = self.define_tensor( | ||
node, | ||
output_tensor, | ||
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, | ||
nodes_to_wrappers, | ||
is_input_tensor=False, | ||
) | ||
|
||
reisze_nearest_op = PyQnnWrapper.PyQnnOpWrapper( | ||
node.name, | ||
QNN_OP_PACKAGE_NAME_QTI_AISW, | ||
OpResizeNearestNeighbor.op_name, | ||
) | ||
reisze_nearest_op.AddInputTensors([input_tensor_wrapper]) | ||
reisze_nearest_op.AddOutputTensors([output_tensor_wrapper]) | ||
# align_corners is guaranteed to be false | ||
reisze_nearest_op.AddScalarParam( | ||
OpResizeNearestNeighbor.param_align_corners, | ||
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, | ||
{"data": False}, | ||
) | ||
reisze_nearest_op.AddScalarParam( | ||
OpResizeNearestNeighbor.param_half_pixel_centers, | ||
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, | ||
{"data": True}, | ||
) | ||
|
||
return reisze_nearest_op |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.