From 4b3ffc4ae5f7a86f34e26c005086889032b40d15 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Fri, 11 Oct 2024 16:24:25 -0700
Subject: [PATCH] Add Vulkan Quantizer to Llama export lib (#6169)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/6169

TSIA. Note that only 8-bit weight-only quantization is supported for now,
since `VulkanQuantizer` does not yet support 4-bit weight-only quantization.

ghstack-source-id: 247613963
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D64249615

fbshipit-source-id: 33ef8d06e56838da5f7866832e50fe74d8878811
---
 examples/models/llama2/export_llama_lib.py |  8 ++++++++
 extension/llm/export/TARGETS               |  1 +
 extension/llm/export/quantizer_lib.py      | 19 +++++++++++++++++++
 3 files changed, 28 insertions(+)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index dce864adfe..cdac837fd2 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -41,6 +41,7 @@
     get_pt2e_quantization_params,
     get_pt2e_quantizers,
     get_qnn_quantizer,
+    get_vulkan_quantizer,
 )
 
 from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -147,6 +148,7 @@ def build_args_parser() -> argparse.ArgumentParser:
             "coreml_8a_c4w",
             "coreml_baseline_8a_c8w",
             "coreml_baseline_8a_c4w",
+            "vulkan_8w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -548,6 +550,12 @@ def get_quantizer_and_quant_params(args):
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
         coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
         quantizers.append(coreml_quantizer)
+    if args.vulkan and args.pt2e_quantize:
+        assert (
+            len(quantizers) == 0
+        ), "Should not enable both vulkan and other quantizers"
+        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
 
     return pt2e_quant_params, quantizers, quant_dtype
diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS
index e4ade20228..866cfe56ea 100644
--- a/extension/llm/export/TARGETS
+++ b/extension/llm/export/TARGETS
@@ -31,6 +31,7 @@ runtime.python_library(
         "//executorch/backends/qualcomm/quantizer:quantizer",
         "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
         "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
+        "//executorch/backends/vulkan/quantizer:vulkan_quantizer",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 30701e4fa5..fd368d73f1 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -260,3 +260,22 @@ def get_coreml_quantizer(pt2e_quantize: str):
         raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")
 
     return quantizer
+
+
+def get_vulkan_quantizer(pt2e_quantize: str):
+    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
+        get_weight_quantization_config,
+        VulkanQuantizer,
+    )
+
+    if pt2e_quantize == "vulkan_8w":
+        config = get_weight_quantization_config(
+            is_per_channel=True,
+            weight_qmin=-128,
+            weight_qmax=127,
+        )
+    else:
+        raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")
+
+    quantizer = VulkanQuantizer().set_global(config)
+    return quantizer
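
Usage sketch: the new helper returns a standard PT2E `Quantizer`, so it plugs into the usual
prepare/convert flow. Below is a minimal sketch, assuming a PyTorch build that ships
`torch.export.export_for_training` and the `torch.ao.quantization.quantize_pt2e` APIs;
`LinearModel` and its example inputs are placeholders and not part of this patch.

# Hypothetical usage sketch, not part of the patch: drive the new Vulkan
# quantizer through the standard PT2E prepare/convert flow.
import torch
from executorch.extension.llm.export.quantizer_lib import get_vulkan_quantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class LinearModel(torch.nn.Module):
    # Placeholder module standing in for a real Llama model.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = LinearModel().eval()
example_inputs = (torch.randn(1, 16),)  # placeholder inputs

# "vulkan_8w" is the only spec the new helper accepts; anything else raises.
quantizer = get_vulkan_quantizer("vulkan_8w")

# Capture, annotate, and convert with the PT2E APIs.
captured = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # weight-only config, so this run is just a sanity check
quantized = convert_pt2e(prepared)

In the Llama export flow, the same path is reached by passing `--vulkan --pt2e_quantize vulkan_8w`
to the export script, which routes through the new branch in `get_quantizer_and_quant_params`.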