From d72467351751e8ee4eaf5d9f2c39a3936fe60b3d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 30 May 2024 07:06:31 -0700 Subject: [PATCH] Fix GPTQ import error after torchao refactor (#3760) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/3760 Fix broken import after https://github.com/pytorch/ao/pull/275 Reviewed By: jerryzh168 Differential Revision: D57888168 fbshipit-source-id: 51a63131ae14e362991ef962df325ec24f958e2d --- .../models/llama2/source_transformation/quantize.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index c5472668ca..830c47ab86 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -96,7 +96,13 @@ def quantize( if calibration_tasks is None: calibration_tasks = ["wikitext"] - from torchao.quantization.GPTQ import InputRecorder + try: + # torchao 0.3+ + # pyre-ignore + from torchao._eval import InputRecorder + except ImportError: + from torchao.quantization.GPTQ import InputRecorder + from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer if tokenizer_path is None: @@ -107,7 +113,7 @@ def quantize( ) inputs = ( - InputRecorder( + InputRecorder( # pyre-ignore tokenizer, calibration_seq_length, None, # input_prep_func