Skip to content

Commit

Permalink
Add Vulkan Quantizer to Llama export lib (pytorch#6169)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#6169

TSIA.

Note that only 8 bit weight only quantization is supported for now since `VulkanQuantizer` does not support 4 bit weight only quantization at the moment.
ghstack-source-id: 247613963
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D64249615

fbshipit-source-id: 33ef8d06e56838da5f7866832e50fe74d8878811
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Oct 11, 2024
1 parent 236e60d commit 4b3ffc4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
8 changes: 8 additions & 0 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
get_pt2e_quantization_params,
get_pt2e_quantizers,
get_qnn_quantizer,
get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace

Expand Down Expand Up @@ -147,6 +148,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"coreml_8a_c4w",
"coreml_baseline_8a_c8w",
"coreml_baseline_8a_c4w",
"vulkan_8w",
],
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
)
Expand Down Expand Up @@ -548,6 +550,12 @@ def get_quantizer_and_quant_params(args):
assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
quantizers.append(coreml_quantizer)
if args.vulkan and args.pt2e_quantize:
assert (
len(quantizers) == 0
), "Should not enable both vulkan and other quantizers"
vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
quantizers.append(vulkan_quantizer)
logging.info(f"Applying quantizers: {quantizers}")
return pt2e_quant_params, quantizers, quant_dtype

Expand Down
1 change: 1 addition & 0 deletions extension/llm/export/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ runtime.python_library(
"//executorch/backends/qualcomm/quantizer:quantizer",
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/backends/vulkan/quantizer:vulkan_quantizer",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_details",
Expand Down
19 changes: 19 additions & 0 deletions extension/llm/export/quantizer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,22 @@ def get_coreml_quantizer(pt2e_quantize: str):
raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")

return quantizer


def get_vulkan_quantizer(pt2e_quantize: str):
    """Build a ``VulkanQuantizer`` for the requested quantization spec.

    Only ``"vulkan_8w"`` (8-bit, per-channel, weight-only) is supported;
    any other spec string raises ``ValueError``.
    """
    # Imported lazily so the Vulkan backend is only required when requested.
    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        get_weight_quantization_config,
        VulkanQuantizer,
    )

    if pt2e_quantize != "vulkan_8w":
        raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")

    # Signed 8-bit weight range [-128, 127], quantized per output channel.
    weight_config = get_weight_quantization_config(
        is_per_channel=True,
        weight_qmin=-128,
        weight_qmax=127,
    )
    return VulkanQuantizer().set_global(weight_config)

0 comments on commit 4b3ffc4

Please sign in to comment.