Commit
Internal changes.
PiperOrigin-RevId: 624328298
MediaPipe Team authored and copybara-github committed Apr 12, 2024
1 parent 2285e30 commit 4e5b396
Showing 1 changed file with 19 additions and 1 deletion.
mediapipe/tasks/python/genai/converter/llm_converter.py (19 additions, 1 deletion)
@@ -38,6 +38,12 @@ class ConversionConfig(object):
       activation overflow issue when running in 16-bit floating point mode. To
       solve this, we need to scale down the weights of certain layers. See
       go/llm-on-device-fp16 for more detailed explanation.
+    lora_rank: An integer representing the rank of LoRA. If not provided, then
+      the converter assumes there are no LoRA weights. Note that only the GPU
+      backend supports LoRA.
+    lora_output_tflite_file: A string indicating the name of the generated
+      tflite file for the LoRA weights. Only applicable when lora_rank is not
+      zero.
   """

@@ -55,6 +61,8 @@ def __init__(
       vocab_model_file: str = '',
       output_tflite_file: Optional[str] = None,
       fp16_scale: Optional[float] = None,
+      lora_rank: Optional[int] = None,
+      lora_output_tflite_file: Optional[str] = None,
   ):
     self.input_ckpt = input_ckpt
     self.ckpt_format = ckpt_format
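The two new constructor arguments slot in after fp16_scale. Below is a minimal sketch of a LoRA-aware configuration; it assumes the remaining ConversionConfig fields referenced elsewhere in this file (input_ckpt, ckpt_format, model_type, backend, output_dir) keep their existing names, and all paths and the model type are placeholders.

# Hedged sketch: field names other than lora_rank / lora_output_tflite_file
# are assumed from the surrounding file; paths and model type are placeholders.
from mediapipe.tasks.python.genai.converter import llm_converter

config = llm_converter.ConversionConfig(
    input_ckpt='/tmp/base_model/',            # placeholder checkpoint dir
    ckpt_format='safetensors',                # placeholder format
    model_type='GEMMA_2B',                    # placeholder model type
    backend='gpu',                            # LoRA is GPU-only (see the guard below)
    output_dir='/tmp/intermediate/',          # where weight bins are written
    vocab_model_file='/tmp/base_model/',      # placeholder tokenizer path
    output_tflite_file='/tmp/model.tflite',   # base-model output
    lora_rank=4,                              # enables LoRA conversion
    lora_output_tflite_file='/tmp/model_lora.tflite',
)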
@@ -142,10 +150,14 @@ def combined_weight_bins_to_tflite(
     weight_path: str,
     output_tflite_file: str,
     vocab_model_file: str,
+    lora_rank: Optional[int] = None,
+    lora_weight_path: Optional[str] = None,
+    lora_output_tflite_file: Optional[str] = None,
 ):
   """Combines weight files to tflite file."""
   # TODO: Figure out whether to clean up the weight files after this.
   if backend == 'cpu':
+    if lora_rank is not None:
+      logging.fatal('LoRA is not supported for CPU backend.')
     model_ckpt_util.GenerateCpuTfLite(
         model_type,
         weight_path,
@@ -160,6 +172,9 @@ def combined_weight_bins_to_tflite(
         vocab_model_file,
         True,
         output_tflite_file,
+        0 if lora_rank is None else lora_rank,
+        '' if lora_weight_path is None else lora_weight_path,
+        '' if lora_output_tflite_file is None else lora_output_tflite_file,
     )
   else:
     raise ValueError('Unsupported backend: %s' % backend)
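The GPU branch passes the LoRA values to model_ckpt_util.GenerateGpuTfLite as concrete scalars and strings rather than Optionals, mapping None to the sentinels 0 and ''. A small self-contained sketch of that mapping (the helper name is illustrative, not part of the module):

from typing import Optional, Tuple

def lora_sentinels(
    lora_rank: Optional[int],
    lora_weight_path: Optional[str],
    lora_output_tflite_file: Optional[str],
) -> Tuple[int, str, str]:
  """Maps unset LoRA options to the sentinel values the binding expects."""
  return (
      0 if lora_rank is None else lora_rank,
      '' if lora_weight_path is None else lora_weight_path,
      '' if lora_output_tflite_file is None else lora_output_tflite_file,
  )

# No LoRA configured: rank 0 and empty paths mean "skip LoRA".
assert lora_sentinels(None, None, None) == (0, '', '')
# LoRA configured: values pass through unchanged.
assert lora_sentinels(4, '/tmp/w/', '/tmp/lora.tflite') == (
    4, '/tmp/w/', '/tmp/lora.tflite')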
@@ -251,4 +266,7 @@ def convert_checkpoint(config: ConversionConfig) -> None:
       weight_path=config.output_dir,
       output_tflite_file=config.output_tflite_file,
       vocab_model_file=vocab_model_path,
+      lora_rank=config.lora_rank,
+      lora_weight_path=config.output_dir,
+      lora_output_tflite_file=config.lora_output_tflite_file,
   )
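Note that convert_checkpoint reuses config.output_dir as lora_weight_path, so the converter evidently expects the LoRA weight bins to sit in the same directory as the base-model bins it writes. A hedged usage sketch, continuing the config from the earlier example; the listed outputs are inferred from the parameters, not verified:

# Hedged sketch: runs the full conversion with the LoRA-enabled config from
# the earlier example. Output file names follow that config's placeholders.
llm_converter.convert_checkpoint(config)

# Expected artifacts (inferred, not verified):
#   /tmp/model.tflite       - base model for the GPU backend
#   /tmp/model_lora.tflite  - separately packaged LoRA weights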
