From 4e5b396d8098de82b38eab9e628a8502e92d3bad Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 12 Apr 2024 16:45:10 -0700
Subject: [PATCH] Internal changes.

PiperOrigin-RevId: 624328298
---
 .../python/genai/converter/llm_converter.py  | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/mediapipe/tasks/python/genai/converter/llm_converter.py b/mediapipe/tasks/python/genai/converter/llm_converter.py
index fdf5b6d1a3..34b3227f4f 100644
--- a/mediapipe/tasks/python/genai/converter/llm_converter.py
+++ b/mediapipe/tasks/python/genai/converter/llm_converter.py
@@ -38,6 +38,12 @@ class ConversionConfig(object):
       activation overflow issue when running in 16-bit floating point mode.
       To solve this, we need to scale down the weights of certain layers. See
       go/llm-on-device-fp16 for more detailed explanation.
+    lora_rank: An integer representing the rank of LoRA. If not provided, the
+      converter assumes there are no LoRA weights. Note that only the GPU
+      backend supports LoRA.
+    lora_output_tflite_file: A string indicating the name of the generated
+      tflite file for the LoRA weights. Only applicable when lora_rank is not
+      zero.
   """
 
   def __init__(
@@ -55,6 +61,8 @@ def __init__(
       vocab_model_file: str = '',
       output_tflite_file: Optional[str] = None,
       fp16_scale: Optional[float] = None,
+      lora_rank: Optional[int] = None,
+      lora_output_tflite_file: Optional[str] = None,
   ):
     self.input_ckpt = input_ckpt
     self.ckpt_format = ckpt_format
@@ -142,10 +150,14 @@ def combined_weight_bins_to_tflite(
     weight_path: str,
     output_tflite_file: str,
     vocab_model_file: str,
+    lora_rank: Optional[int] = None,
+    lora_weight_path: Optional[str] = None,
+    lora_output_tflite_file: Optional[str] = None,
 ):
   """Combines weight files to tflite file."""
-  # TODO: Figure out whether to clean up the weight files after this.
   if backend == 'cpu':
+    if lora_rank is not None:
+      logging.fatal('LoRA is not supported for CPU backend.')
     model_ckpt_util.GenerateCpuTfLite(
         model_type,
         weight_path,
@@ -160,6 +172,9 @@
         vocab_model_file,
         True,
         output_tflite_file,
+        0 if lora_rank is None else lora_rank,
+        '' if lora_weight_path is None else lora_weight_path,
+        '' if lora_output_tflite_file is None else lora_output_tflite_file,
     )
   else:
     raise ValueError('Unsupported backend: %s' % backend)
@@ -251,4 +266,7 @@ def convert_checkpoint(config: ConversionConfig) -> None:
       weight_path=config.output_dir,
       output_tflite_file=config.output_tflite_file,
       vocab_model_file=vocab_model_path,
+      lora_rank=config.lora_rank,
+      lora_weight_path=config.output_dir,
+      lora_output_tflite_file=config.lora_output_tflite_file,
   )
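
A minimal end-to-end invocation with the new LoRA options might look like the
sketch below. The import path is inferred from the file location in this diff,
and the model type, rank, and paths are illustrative placeholders rather than
library defaults:

  from mediapipe.tasks.python.genai import converter

  config = converter.ConversionConfig(
      input_ckpt='/tmp/gemma-2b/',  # placeholder checkpoint directory
      ckpt_format='safetensors',  # placeholder; match your checkpoint format
      model_type='GEMMA_2B',  # placeholder model identifier
      backend='gpu',  # per this change, LoRA is only supported on GPU
      output_dir='/tmp/gemma-2b/intermediate/',
      vocab_model_file='/tmp/gemma-2b/tokenizer.model',
      output_tflite_file='/tmp/gemma-2b/model.tflite',
      lora_rank=4,  # rank of the LoRA weights; None disables LoRA
      lora_output_tflite_file='/tmp/gemma-2b/lora.tflite',
  )
  converter.convert_checkpoint(config)

Leaving lora_rank as None preserves the previous behavior: the GPU conversion
call receives a rank of 0 with empty LoRA paths, so no LoRA tflite file is
generated.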