Disclaimer: At present, we do not offer any backward compatibility guarantees for any APIs. We are currently in a development phase, and as such, we reserve the right to modify interfaces and implementations.
This backend is implemented on top of the Qualcomm AI Engine Direct SDK. Please follow the tutorial to set up the environment, and to build and run ExecuTorch models with this backend (Qualcomm AI Engine Direct is also referred to as QNN in the source and documentation).
A website version of the tutorial is here.
Please check `generate_qnn_executorch_compiler_spec()` in `utils.py` for the supported SoCs and inference types.
- Snapdragon 8 Gen 1
- Snapdragon 8 Gen 1+
- Snapdragon 8 Gen 2
- Snapdragon 8 Gen 3
Currently, users cannot add additional chipset models because the chipset ID is not accessible to community users. If there are specific chipset models you wish to add, please contact one of the authors listed in the Code Reviews section at the bottom of this page.
- Quantized
- FP16
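Putting the SoC and inference-type choices together, a compiler-spec setup might look like the sketch below. The import paths and names (`QcomChipset`, `generate_htp_compiler_spec`, `generate_qnn_executorch_compiler_spec`) reflect the ExecuTorch source at the time of writing and may change between versions, so treat this as an illustration rather than a stable API; `generate_qnn_executorch_compiler_spec()` in `utils.py` remains the authoritative reference.

```python
# Sketch: building compiler specs for the QNN backend.
# NOTE: exact import paths may vary between ExecuTorch versions; check
# generate_qnn_executorch_compiler_spec() in utils.py for the current signature.
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)

# FP16 inference on the HTP backend; pass use_fp16=False for quantized models.
backend_options = generate_htp_compiler_spec(use_fp16=True)

# Pick one of the supported SoCs, e.g. Snapdragon 8 Gen 3 (SM8650).
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,
    backend_options=backend_options,
)
```

The resulting `compiler_specs` is what gets handed to the partitioner/lowering flow described in the tutorial.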
backends/qualcomm
├── aot # Code for generating the QNN context binary (AoT Part).
| ├── wrappers # Wrappers of QNN data structures for ease of use.
| └── python # Python interface for using QNN libraries.
├── builders # Code for lowering each operator (AoT Part).
├── partition # QNN Partitioner (AoT Part).
├── _passes # Various private passes helping lower models to the QNN backend (AoT Part).
├── python # Place to put pybind artifacts for accessing QNN APIs, structures, etc. (AoT Part).
├── quantizer # QNN Quantizer.
├── runtime # The QNN runtime responsible for compiling a model on x64.
| | # It is also the runtime responsible for executing compiled
| | # models on a device.
| └── backends # Backends supported by QNN.
| └── htpbackend
| ├── aarch64 # Configuration required to run on device. (Device Part).
| └── x86_64 # Configuration required to compile graph on host. (AoT Part).
├── scripts # Misc supporting scripts, not related to core functionality.
├── serialization # Contains files related to serializing QNN compiler options and SoC information
├── tests # Unit tests and model tests go here.
└── utils # Miscellaneous utilities.
examples/qualcomm
├── executor_runner # A general runner that is capable of running most of the basic models.
├── oss_scripts # Scripts for OSS (Open Source Software) models and customized runners for some specific models.
├── qaihub_scripts # Scripts for Qaihub models and corresponding customized runner for these models.
└── scripts # Scripts for models provided by executorch.
Please see this README.md.
Further, an example build script is provided as build.sh.
If you encounter a problem, it would be great to include reproduction steps so that maintainers can investigate. Please also follow the policy when filing issues.
PRs are always welcome to help improve the codebase in a comprehensive manner. Before submitting changes, please:
- Check the Coding Style:
  Make sure your code follows the style guides and passes the lint checks.
- Add Unit Tests:
  The following is an example of adding a test case after creating a new operator builder. Please navigate to the backends/qualcomm/tests folder and put a minimal example module in model.py. For example:

  ```python
  class IndexPut(torch.nn.Module):
      ...

  # please insert the implementation in alphabetical order
  class LayerNorm(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)

      def forward(self, x):
          return self.layer_norm(x)

  class LeakyReLUDefault(torch.nn.Module):
      ...
  ```
  Also extend the TestQNNFloatingPointOperator and TestQNNQuantizedOperator sections in test_qnn_delegate.py. For example:

  ```python
  class TestQNNQuantizedOperator(TestQNN):
      def test_qnn_backend_interpolate_nearest_2d(self):
          ...

      # please insert the implementation in alphabetical order
      def test_qnn_backend_layer_norm(self):
          module = LayerNorm()  # noqa: F405
          sample_input = (torch.randn(196, 768),)
          module = self.get_qdq_module(module, sample_input)
          self.lower_module_and_test_output(module, sample_input)

      def test_qnn_backend_leaky_relu(self):
          ...
  ```
- Verify Unit Test Results:

  ```bash
  cd $PATH_TO_EXECUTORCH
  # example usage of performing the unit test
  python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_layer_norm -s $DEVICE_SERIAL -m SM8650 -b build-android/ -a $PATH_TO_TEST_ARTIFACTS
  ```
  The test graph is expected to contain exactly one delegated node, with only placeholder and output nodes remaining. Check the execution report for more information.
- Code Reviews:
  Please ping authors in Qualcomm AI Engine Direct related PRs for review; possible candidates are listed below:
Thanks again for your contribution!