diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index c5472668ca..830c47ab86 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -96,7 +96,13 @@ def quantize( if calibration_tasks is None: calibration_tasks = ["wikitext"] - from torchao.quantization.GPTQ import InputRecorder + try: + # torchao 0.3+ + # pyre-ignore + from torchao._eval import InputRecorder + except ImportError: + from torchao.quantization.GPTQ import InputRecorder + from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer if tokenizer_path is None: @@ -107,7 +113,7 @@ def quantize( ) inputs = ( - InputRecorder( + InputRecorder( # pyre-ignore tokenizer, calibration_seq_length, None, # input_prep_func