diff --git a/mediapipe/model_maker/python/llm/BUILD b/mediapipe/model_maker/python/llm/BUILD
index 8af4bb61bb..7ea093a4c7 100644
--- a/mediapipe/model_maker/python/llm/BUILD
+++ b/mediapipe/model_maker/python/llm/BUILD
@@ -22,6 +22,33 @@ py_library(
     srcs = ["converter_base.py"],
 )
 
+py_library(
+    name = "converter_factory",
+    srcs = ["converter_factory.py"],
+    deps = [
+        ":converter_base",
+        ":pytorch_converter",
+        ":safetensors_converter",
+        ":weight_bins_writer",
+    ],
+)
+
+py_library(
+    name = "pytorch_converter",
+    srcs = ["pytorch_converter.py"],
+    deps = [
+        ":converter_base",
+    ],
+)
+
+py_library(
+    name = "safetensors_converter",
+    srcs = ["safetensors_converter.py"],
+    deps = [
+        ":converter_base",
+    ],
+)
+
 py_library(
     name = "quantization_util",
     srcs = ["quantization_util.py"],
@@ -35,3 +62,21 @@ py_test(
         ":quantization_util",
     ],
 )
+
+py_library(
+    name = "weight_bins_writer",
+    srcs = ["weight_bins_writer.py"],
+    deps = [
+        ":converter_base",
+        ":quantization_util",
+    ],
+)
+
+py_test(
+    name = "weight_bins_writer_test",
+    srcs = ["weight_bins_writer_test.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":weight_bins_writer",
+    ],
+)
diff --git a/mediapipe/model_maker/python/llm/converter_base.py b/mediapipe/model_maker/python/llm/converter_base.py
index 6441af3be0..6bff55a2f3 100644
--- a/mediapipe/model_maker/python/llm/converter_base.py
+++ b/mediapipe/model_maker/python/llm/converter_base.py
@@ -131,7 +131,7 @@ def __init__(
     self._embedding_quant_bits = embedding_quant_bits
 
   def map_to_actions(self, layer_name: str) -> Optional[QuantizationAction]:
-    """""Maps the layer weights to quantization actions.
+    """Maps the layer weights to quantization actions.
 
     Args:
       layer_name: A string representing the name of the layer weight. Note that
diff --git a/mediapipe/model_maker/python/llm/converter_factory.py b/mediapipe/model_maker/python/llm/converter_factory.py
new file mode 100644
index 0000000000..882dabfd6b
--- /dev/null
+++ b/mediapipe/model_maker/python/llm/converter_factory.py
@@ -0,0 +1,78 @@
+# Copyright 2024 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility library that helps create the converter instances."""
+
+from mediapipe.model_maker.python.llm import converter_base
+from mediapipe.model_maker.python.llm import pytorch_converter
+from mediapipe.model_maker.python.llm import safetensors_converter
+from mediapipe.model_maker.python.llm import weight_bins_writer
+
+
+def create_ckpt_loader(
+    ckpt_format: str, *args, **kwargs
+) -> converter_base.CkptLoaderBase:
+  """Creates the checkpoint loader.
+
+  Args:
+    ckpt_format: A string that indicates the input checkpoint format.
+    *args: Additional arguments to be passed into the loader.
+    **kwargs: Additional arguments to be passed into the loader.
+
+  Returns:
+    A created CkptLoader instance.
+  """
+  del args
+  if ckpt_format == "pytorch":
+    return pytorch_converter.PytorchCkptLoader(
+        ckpt_path=kwargs["ckpt_path"],
+        is_symmetric=kwargs["is_symmetric"],
+        attention_quant_bits=kwargs["attention_quant_bits"],
+        feedforward_quant_bits=kwargs["feedforward_quant_bits"],
+        embedding_quant_bits=kwargs["embedding_quant_bits"],
+        special_model=kwargs["special_model"],
+    )
+  elif ckpt_format == "safetensors":
+    return safetensors_converter.SafetensorsCkptLoader(
+        ckpt_path=kwargs["ckpt_path"],
+        is_symmetric=kwargs["is_symmetric"],
+        attention_quant_bits=kwargs["attention_quant_bits"],
+        feedforward_quant_bits=kwargs["feedforward_quant_bits"],
+        embedding_quant_bits=kwargs["embedding_quant_bits"],
+        special_model=kwargs["special_model"],
+    )
+  else:
+    raise ValueError(f"Unknown checkpoint format: {ckpt_format}")
+
+
+def create_writer(
+    writer_type: str, *args, **kwargs
+) -> converter_base.ModelWriterBase:
+  """Creates the model writer.
+
+  Args:
+    writer_type: A string that indicates which model writer to create.
+    *args: Additional arguments to be passed into the writer.
+    **kwargs: Additional arguments to be passed into the writer.
+
+  Returns:
+    A created ModelWriter instance.
+  """
+  del args
+  if writer_type == "weight_bins":
+    return weight_bins_writer.WeightBinsWriter(
+        output_dir=kwargs["output_dir"], backend=kwargs["backend"]
+    )
+  else:
+    raise ValueError(f"Unknown writer type: {writer_type}")
diff --git a/mediapipe/model_maker/python/llm/pytorch_converter.py b/mediapipe/model_maker/python/llm/pytorch_converter.py
new file mode 100644
index 0000000000..4836cc72fe
--- /dev/null
+++ b/mediapipe/model_maker/python/llm/pytorch_converter.py
@@ -0,0 +1,273 @@
+# Copyright 2024 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CkptLoader implementation for loading the Pytorch file."""
+
+import enum
+import os
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+from mediapipe.model_maker.python.llm import converter_base
+
+
+class LayerType(enum.Enum):
+  """Enum for layer type."""
+
+  NONE = 0
+  ATTENTION = 1  # Layer is part of the attention module.
+  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
+  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
+  LAYER_NORM = (
+      4  # Layer is layer normalization before and after attention layer.
+  )
+
+  @classmethod
+  def get_layer_type(cls, layer_name: str):
+    """Gets the layer type of the given layer name."""
+    ffn_layers = [
+        "mlp",
+    ]
+    attn_layers = [
+        "self_attention",
+    ]
+    emb_layers = [
+        "word_embeddings",
+        "lm_head",
+    ]
+    layer_norms = [
+        "input_layernorm",
+        "post_attention_layernorm",
+        "ln_f",
+    ]
+    if any(sub_name in layer_name for sub_name in attn_layers):
+      return LayerType.ATTENTION
+    if any(sub_name in layer_name for sub_name in ffn_layers):
+      return LayerType.FEEDFORWARD
+    if any(sub_name in layer_name for sub_name in emb_layers):
+      return LayerType.EMBEDDING
+    if any(sub_name in layer_name for sub_name in layer_norms):
+      return LayerType.LAYER_NORM
+    else:
+      return LayerType.NONE
+
+
+class FalconMapper(converter_base.LayerActionMapperBase):
+  """LayerActionMapper for handling the Falcon-rw-1b model."""
+
+  # We don't quantize the embedding, final MLP, and layer norm for the Falcon
+  # model.
+  NON_QUANTIZED_LAYERS = [
+      "transformer.word_embeddings.weight",
+      "transformer.ln_f",
+      "lm_head",
+      "input_layernorm",
+      "post_attention_layernorm",
+  ]
+
+  def map_to_actions(
+      self, layer_name: str
+  ) -> Optional[converter_base.QuantizationAction]:
+    """Map the given layer name to actions."""
+    quantize_axis = None
+    quantize_bits = None
+    if all(name not in layer_name for name in self.NON_QUANTIZED_LAYERS) and (
+        layer_name.endswith(".weight")
+    ):
+      layer_type = LayerType.get_layer_type(layer_name)
+      quantize_axis = [0]
+      if layer_type == LayerType.FEEDFORWARD:
+        quantize_bits = self._feedforward_quant_bits
+      elif layer_type == LayerType.ATTENTION:
+        quantize_bits = self._attention_quant_bits
+      elif layer_type == LayerType.EMBEDDING:
+        quantize_bits = self._embedding_quant_bits
+
+    return converter_base.QuantizationAction(
+        tensor_name=layer_name,
+        target_name=layer_name,
+        quantize_axis=quantize_axis,
+        quantize_bits=quantize_bits,
+        pack_dim=0,
+    )
+
+  def update_target_name(self, target_name: str) -> str:
+    """Updates the target name to match the tensor name convention."""
+    layer_type = LayerType.get_layer_type(target_name)
+
+    target_name = target_name.replace(
+        "transformer.h.", "params.lm.transformer.x_layers_"
+    )
+
+    if layer_type == LayerType.FEEDFORWARD:
+      target_name = target_name.replace(".weight", ".linear.w")
+      target_name = target_name.replace(".bias", ".bias.b")
+      target_name = target_name.replace(
+          "mlp.dense_h_to_4h", "ff_layer.ffn_layer1"
+      )
+      target_name = target_name.replace(
+          "mlp.dense_4h_to_h", "ff_layer.ffn_layer2"
+      )
+    elif layer_type == LayerType.ATTENTION:
+      target_name = target_name.replace("dense", "post")
+      target_name = target_name.replace(".weight", ".linear.w")
+      target_name = target_name.replace(".bias", ".bias.b")
+    elif layer_type == LayerType.EMBEDDING:
+      target_name = target_name.replace(
+          "transformer.word_embeddings", "params.lm.token_embedding"
+      )
+      target_name = target_name.replace(
+          "lm_head", "params.lm.softmax.logits_ffn"
+      )
+      target_name = target_name.replace(".weight", ".w")
+    elif layer_type == LayerType.LAYER_NORM:
+      target_name = target_name.replace("input_layernorm", "pre_layer_norm")
+      target_name = target_name.replace(
+          "pre_layer_norm.weight", "pre_layer_norm.scale"
+      )
+      target_name = target_name.replace(
+          "post_attention_layernorm", "post_layer_norm"
+      )
+      target_name = target_name.replace(
+          "post_layer_norm.weight", "post_layer_norm.scale"
+      )
+      target_name = target_name.replace(
+          "transformer.ln_f.weight", "params.lm.final_ln.scale"
+      )
+      target_name = target_name.replace(
+          "transformer.ln_f.bias", "params.lm.final_ln.bias"
+      )
+
+    return target_name
+
+
+class PytorchCkptLoader(converter_base.CkptLoaderBase):
+  """CkptLoader implementation for loading the Pytorch model."""
+
+  def __init__(
+      self,
+      ckpt_path: str,
+      is_symmetric: bool,
+      attention_quant_bits: int,
+      feedforward_quant_bits: int,
+      embedding_quant_bits: int,
+      special_model: str,
+  ):
+    """Initializes the loader.
+
+    Args:
+      ckpt_path: The filepath to the PyTorch checkpoint file.
+      is_symmetric: Whether to apply symmetric or asymmetric quantization.
+      attention_quant_bits: An integer that specifies the target quantization
+        bits (support 8 or 4) for the attention layers.
+      feedforward_quant_bits: An integer that specifies the target quantization
+        bits (support 8 or 4) for the feedforward layers in each Transformer
+        block.
+      embedding_quant_bits: An integer that specifies the target quantization
+        bits (support 8 or 4) for the embedding (and the final projection)
+        layers.
+      special_model: A string that indicates which input model is used and
+        whether any special treatment is needed.
+    """
+    super().__init__(
+        ckpt_path,
+        is_symmetric,
+        attention_quant_bits,
+        feedforward_quant_bits,
+        embedding_quant_bits,
+    )
+
+    self._special_model = special_model
+    if special_model in ["FALCON_RW_1B"]:
+      self.mapper = FalconMapper(
+          is_symmetric,
+          attention_quant_bits,
+          feedforward_quant_bits,
+      )
+    else:
+      raise ValueError(f"Unknown special model: {special_model}")
+
+    self._ckpt_path = ckpt_path
+    if not os.path.exists(self._ckpt_path):
+      raise ValueError(f"{self._ckpt_path} does not exist.")
+    self._model = torch.load(self._ckpt_path, map_location=torch.device("cpu"))
+
+  def load_to_actions(self):
+    tensor_names = self._model.keys()
+    actions = []
+    for tensor_name in tensor_names:
+      tensor_value = (
+          self._model[tensor_name]
+          .to(torch.float32)
+          .t()
+          .contiguous()
+          .detach()
+          .cpu()
+          .numpy()
+      )
+      if (
+          isinstance(self.mapper, FalconMapper)
+          and "query_key_value" in tensor_name
+      ):
+        qkv_tensors = self._decompose_falcon_qkv(tensor_value)
+        for tensor, qkv_name in zip(qkv_tensors, ["q", "k", "v"]):
+          decomposed_name = tensor_name.replace("query_key_value", qkv_name)
+          action = self.mapper.map_to_actions(decomposed_name)
+          action.tensor_value = tensor
+          action.target_name = self.mapper.update_target_name(decomposed_name)
+          actions.append(action)
+      else:
+        action = self.mapper.map_to_actions(tensor_name)
+        if action is None:
+          continue
+        action.tensor_value = tensor_value
+        action.target_name = self.mapper.update_target_name(tensor_name)
+        actions.append(action)
+    return actions
+
+  def _decompose_falcon_qkv(self, tensor_value: np.ndarray) -> List[np.ndarray]:
+    """Decomposes combined qkv tensor used in falcon model into separate q, k and v tensors."""
+    chunk_size = 64
+    hidden_size = 2048
+
+    tensor_value = tensor_value.transpose()
+
+    q_tensor = np.zeros(
+        (hidden_size,)
+        + ((hidden_size,) if len(tensor_value.shape) == 2 else ()),
+        dtype=tensor_value.dtype,
+    )
+    k_tensor = np.zeros_like(q_tensor, dtype=tensor_value.dtype)
+    v_tensor = np.zeros_like(k_tensor, dtype=tensor_value.dtype)
+
+    j = 0
+    for i in range(0 * chunk_size, hidden_size * 3, chunk_size * 3):
+      q_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
+      j += chunk_size
+
+    j = 0
+    for i in range(1 * chunk_size, hidden_size * 3, chunk_size * 3):
+      k_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
+      j += chunk_size
+
+    j = 0
+    for i in range(2 * chunk_size, hidden_size * 3, chunk_size * 3):
+      v_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
+      j += chunk_size
+
+    return [
+        np.ascontiguousarray(q_tensor.transpose()),
+        np.ascontiguousarray(k_tensor.transpose()),
+        np.ascontiguousarray(v_tensor.transpose()),
+    ]
diff --git a/mediapipe/model_maker/python/llm/safetensors_converter.py b/mediapipe/model_maker/python/llm/safetensors_converter.py
new file mode 100644
index 0000000000..3419169b0b
--- /dev/null
+++ b/mediapipe/model_maker/python/llm/safetensors_converter.py
@@ -0,0 +1,315 @@
+# Copyright 2024 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CkptLoader implementation for loading the Safetensors."""
+
+import array
+import enum
+import json
+import os
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+from mediapipe.model_maker.python.llm import converter_base
+
+
+class LayerType(enum.Enum):
+  """Enum for layer type."""
+
+  NONE = 0
+  ATTENTION = 1  # Layer is part of the attention module.
+  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
+  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
+  LAYER_NORM = (
+      4  # Layer is layer normalization before and after attention layer.
+  )
+
+  @classmethod
+  def get_layer_type(cls, layer_name: str):
+    """Gets the layer type of the given layer name."""
+    ffn_layers = [
+        "mlp",
+    ]
+    attn_layers = [
+        "self_attn",
+    ]
+    emb_layers = [
+        "embed_tokens",
+        "lm_head",
+    ]
+    layer_norms = [
+        "input_layernorm",
+        "post_attention_layernorm",
+        "final_layernorm",
+    ]
+    if any(sub_name in layer_name for sub_name in attn_layers):
+      return LayerType.ATTENTION
+    if any(sub_name in layer_name for sub_name in ffn_layers):
+      return LayerType.FEEDFORWARD
+    if any(sub_name in layer_name for sub_name in emb_layers):
+      return LayerType.EMBEDDING
+    if any(sub_name in layer_name for sub_name in layer_norms):
+      return LayerType.LAYER_NORM
+    else:
+      return LayerType.NONE
+
+
+class StablelmMapper(converter_base.LayerActionMapperBase):
+  """LayerActionMapper for handling the StableLM model."""
+
+  # We don't quantize layer norm for the StableLM model.
+  NON_QUANTIZED_LAYERS = [
+      "model.norm.weight",
+      "input_layernorm",
+      "post_attention_layernorm",
+  ]
+
+  def map_to_actions(
+      self, layer_name: str
+  ) -> Optional[converter_base.QuantizationAction]:
+    """Map the given layer name to actions."""
+    quantize_axis = None
+    quantize_bits = None
+    layer_type = LayerType.get_layer_type(layer_name)
+
+    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
+      quantize_axis = [0]
+      if layer_type == LayerType.FEEDFORWARD:
+        quantize_bits = self._feedforward_quant_bits
+      elif layer_type == LayerType.ATTENTION:
+        quantize_bits = self._attention_quant_bits
+      elif layer_type == LayerType.EMBEDDING:
+        quantize_bits = self._embedding_quant_bits
+    target_name = self.update_target_name(layer_name)
+
+    return converter_base.QuantizationAction(
+        tensor_name=layer_name,
+        target_name=target_name,
+        quantize_axis=quantize_axis,
+        quantize_bits=quantize_bits,
+        pack_dim=0,
+    )
+
+  def update_target_name(self, target_name: str) -> str:
+    """Updates the target name to match the tensor name convention."""
+    target_name = target_name.replace(
+        "model.layers.", "params.lm.transformer.x_layers_"
+    )
+    target_name = target_name.replace("mlp.up_proj", "ff_layer.ffn_layer1")
+    target_name = target_name.replace("mlp.down_proj", "ff_layer.ffn_layer2")
+    target_name = target_name.replace(
+        "mlp.gate_proj", "ff_layer.ffn_layer1_gate"
+    )
+    target_name = target_name.replace("input_layernorm", "pre_layer_norm")
+    target_name = target_name.replace(
+        "pre_layer_norm.weight", "pre_layer_norm.scale"
+    )
+    target_name = target_name.replace(
+        "post_attention_layernorm", "post_layer_norm"
+    )
+    target_name = target_name.replace(
+        "post_layer_norm.weight", "post_layer_norm.scale"
+    )
+    target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
+    target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
+    target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
+    target_name = target_name.replace("self_attn.o_proj", "self_attention.post")
+    target_name = target_name.replace(
+        "model.embed_tokens", "params.lm.token_embedding"
+    )
+    target_name = target_name.replace("model.norm", "params.lm.final_ln")
+    target_name = target_name.replace("final_ln.weight", "final_ln.scale")
+    target_name = target_name.replace("lm_head", "params.lm.softmax.logits_ffn")
+    target_name = target_name.replace(".weight", ".w")
+
+    return target_name
+
+
+class PhiMapper(converter_base.LayerActionMapperBase):
+  """LayerActionMapper for handling the Phi model."""
+
+  def map_to_actions(
+      self, layer_name: str
+  ) -> Optional[converter_base.QuantizationAction]:
+    """Map the given layer name to actions."""
+    quantize_axis = None
+    quantize_bits = None
+    layer_type = LayerType.get_layer_type(layer_name)
+
+    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
+      quantize_axis = [0]
+      if layer_type == LayerType.FEEDFORWARD:
+        quantize_bits = self._feedforward_quant_bits
+      elif layer_type == LayerType.ATTENTION:
+        quantize_bits = self._attention_quant_bits
+      elif layer_type == LayerType.EMBEDDING:
+        quantize_bits = self._embedding_quant_bits
+    target_name = self.update_target_name(layer_name)
+
+    return converter_base.QuantizationAction(
+        tensor_name=layer_name,
+        target_name=target_name,
+        quantize_axis=quantize_axis,
+        quantize_bits=quantize_bits,
+        pack_dim=0,
+    )
+
+  def update_target_name(self, target_name: str) -> str:
+    """Updates the target name to match the tensor name convention."""
+    target_name = target_name.replace(
+        "model.layers.", "params.lm.transformer.x_layers_"
+    )
+
+    layer_type = LayerType.get_layer_type(target_name)
+    if layer_type == LayerType.FEEDFORWARD:
+      target_name = target_name.replace(".weight", ".linear.w")
+      target_name = target_name.replace(".bias", ".bias.b")
+      target_name = target_name.replace("mlp.fc1", "ff_layer.ffn_layer1")
+      target_name = target_name.replace("mlp.fc2", "ff_layer.ffn_layer2")
+
+    elif layer_type == LayerType.ATTENTION:
+      target_name = target_name.replace(".weight", ".linear.w")
+      target_name = target_name.replace(".bias", ".bias.b")
+      target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
+      target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
+      target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
+      target_name = target_name.replace(
+          "self_attn.dense", "self_attention.post"
+      )
+    elif layer_type == LayerType.EMBEDDING:
+      target_name = target_name.replace(
+          "model.embed_tokens", "params.lm.token_embedding"
+      )
+      target_name = target_name.replace(
+          "lm_head", "params.lm.softmax.logits_ffn"
+      )
+      target_name = target_name.replace(
+          "logits_ffn.weight", "logits_ffn.linear.w"
+      )
+      target_name = target_name.replace("logits_ffn.bias", "logits_ffn.bias.b")
+    elif layer_type == LayerType.LAYER_NORM:
+      target_name = target_name.replace("input_layernorm", "pre_layer_norm")
+      target_name = target_name.replace(
+          "pre_layer_norm.weight", "pre_layer_norm.scale"
+      )
+      target_name = target_name.replace(
+          "model.final_layernorm", "params.lm.final_ln"
+      )
+      target_name = target_name.replace("final_ln.weight", "final_ln.scale")
+      target_name = target_name.replace(".weight", ".w")
+    return target_name
+
+
+DTYPE_MAP = {
+    "F16": torch.float16,
+    "BF16": torch.bfloat16,
+    "F32": torch.float32,
+}
+
+
+class SafetensorsCkptLoader(converter_base.CkptLoaderBase):
+  """CkptLoader implementation for loading the Safetensors."""
+
+  _HEAD_BYTES = 8
+
+  def __init__(
+      self,
+      ckpt_path: str,
+      is_symmetric: bool,
+      attention_quant_bits: int,
+      feedforward_quant_bits: int,
+      embedding_quant_bits: int,
+      special_model: str,
+  ):
+    """Initializes the loader.
+
+    Args:
+      ckpt_path: The filepath to the safetensors file.
+      is_symmetric: Whether to apply symmetric or asymmetric quantization.
+      attention_quant_bits: An integer that specifies the target quantization
+        bits (support 8 or 4) for the attention layers.
+      feedforward_quant_bits: An integer that specifies the target quantization
+        bits (support 8 or 4) for the feedforward layers in each Transformer
+        block.
+      embedding_quant_bits: An integer that specifies the target quantization
+        bits (support 8 or 4) for the embedding (and the final projection)
+        layers.
+      special_model: A string that indicates which input model is used and
+        whether any special treatment is needed.
+    """
+    super().__init__(
+        ckpt_path,
+        is_symmetric,
+        attention_quant_bits,
+        feedforward_quant_bits,
+        embedding_quant_bits,
+    )
+
+    self._special_model = special_model
+    if special_model in ["STABLELM_4E1T_3B"]:
+      self.mapper = StablelmMapper(
+          is_symmetric,
+          attention_quant_bits,
+          feedforward_quant_bits,
+          embedding_quant_bits,
+      )
+    elif special_model in ["PHI_2"]:
+      self.mapper = PhiMapper(
+          is_symmetric,
+          attention_quant_bits,
+          feedforward_quant_bits,
+          embedding_quant_bits,
+      )
+    else:
+      raise ValueError(f"Unknown special model: {special_model}")
+
+    self._ckpt_path = ckpt_path
+    if not os.path.exists(self._ckpt_path):
+      raise ValueError(f"{self._ckpt_path} does not exist.")
+    with open(self._ckpt_path, "rb") as f:
+      head_bytes = f.read(self._HEAD_BYTES)
+      metadata_bytes_num = np.frombuffer(head_bytes, dtype=np.uint64)[0]
+      metadata_bytes = f.read(metadata_bytes_num)
+      self.layers_info = json.loads(metadata_bytes)
+      self.metadata_bytes_num = metadata_bytes_num
+
+  def load_to_actions(self) -> List[converter_base.QuantizationAction]:
+    tensor_names = self.layers_info.keys()
+    actions = []
+    for tensor_name in tensor_names:
+      if tensor_name == "__metadata__":
+        continue
+      action = self.mapper.map_to_actions(tensor_name)
+      if action is None:
+        continue
+      action.tensor_value = self._read_tensor_as_numpy(tensor_name)
+      actions.append(action)
+    return actions
+
+  def _read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
+    """Reads a tensor from the model file as a numpy array with np.float32 type."""
+    tensor_info = self.layers_info[tensor_name]
+    with open(self._ckpt_path, "rb") as f:
+      shape = tensor_info["shape"]
+      dtype = tensor_info["dtype"]
+      if dtype not in DTYPE_MAP:
+        raise ValueError(f"{dtype} is not supported.")
+      data_offsets = tensor_info["data_offsets"]
+      f.seek(int(self._HEAD_BYTES + self.metadata_bytes_num + data_offsets[0]))
+      tensor_bytes = f.read(data_offsets[1] - data_offsets[0])
+      raw_tensor = torch.frombuffer(
+          array.array("b", tensor_bytes), dtype=DTYPE_MAP[dtype]
+      ).reshape(shape)
+      return raw_tensor.float().t().contiguous().numpy()
diff --git a/mediapipe/model_maker/python/llm/weight_bins_writer.py b/mediapipe/model_maker/python/llm/weight_bins_writer.py
new file mode 100644
index 0000000000..37e6573da5
--- /dev/null
+++ b/mediapipe/model_maker/python/llm/weight_bins_writer.py
@@ -0,0 +1,111 @@
+# Copyright 2024 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ModelWriter for writing a set of weights as binary files."""
+
+import contextlib
+import os
+from typing import Dict, Tuple
+
+import numpy as np
+
+from mediapipe.model_maker.python.llm import converter_base
+from mediapipe.model_maker.python.llm import quantization_util
+
+
+@contextlib.contextmanager
+def filemanager(filename: str, mode: str):
+  try:
+    with open(filename, mode) as f:
+      yield f
+  finally:
+    pass
+
+
+def removeprefix(s, prefix):
+  """Removes the prefix from a string."""
+  if s.startswith(prefix):
+    return s[len(prefix) :]
+  return s
+
+
+class WeightBinsWriter(converter_base.ModelWriterBase):
+  """A ModelWriter for writing a set of weights as binary files."""
+
+  def get_weight_info(self, var_name: str, weight: np.ndarray) -> str:
+    """Gets the string that describes the weights."""
+    dtype_str = str(weight.dtype)
+    shape_str = '_'.join(map(str, weight.shape))
+    return f'mdl_vars.{var_name}.{dtype_str}.{shape_str}\n'
+
+  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
+    """Writes the variables to binary files, one for each layer.
+
+    Args:
+      variables: A dictionary that maps from the target variable names to the
+        quantized tensor values along with a boolean that indicates whether to
+        pack the values (only applicable for the 4-bit quantized tensors).
+    """
+    weights_info = []
+    for var_name, value in variables.items():
+      output = value[0]
+      if value[1]:
+        # Squeeze the tensor to make sure it is a 1D array for packing.
+        output = np.expand_dims(np.ravel(output), axis=-1)
+        # Extra pack needed for 4 bit. We always pack the weights along the
+        # first dimension since the tensor has already been squeezed.
+        output = quantization_util.pack_4bit(output, 0)
+      if 'combined_qkv' in var_name:
+        var_name = removeprefix(var_name, 'mld_vars.')
+        var_name_q = var_name.replace('combined_qkv', 'q')
+        var_name_k = var_name.replace('combined_qkv', 'k')
+        var_name_v = var_name.replace('combined_qkv', 'v')
+        if output.shape[0] == 3:
+          weight_q, weight_k, weight_v = output
+          assert weight_q.shape == weight_k.shape == weight_v.shape
+        else:  # LoRA right weight is shared across q, k, v
+          weight_q = weight_k = weight_v = output
+        weights_info.append(self.get_weight_info(var_name_q, weight_q))
+        path_q = os.path.join(self._output_dir, var_name_q)
+        with filemanager(path_q, 'wb') as f:
+          f.write(weight_q.tobytes())
+        weights_info.append(self.get_weight_info(var_name_k, weight_k))
+        path_k = os.path.join(self._output_dir, var_name_k)
+        with filemanager(path_k, 'wb') as f:
+          f.write(weight_k.tobytes())
+        path_v = os.path.join(self._output_dir, var_name_v)
+        with filemanager(path_v, 'wb') as f:
+          f.write(weight_v.tobytes())
+        weights_info.append(self.get_weight_info(var_name_v, weight_v))
+      else:
+        if 'key' in var_name:
+          var_name = var_name.replace('key', 'k')
+        if 'query' in var_name:
+          var_name = var_name.replace('query', 'q')
+        if 'value' in var_name:
+          var_name = var_name.replace('value', 'v')
+        path = os.path.join(
+            self._output_dir, removeprefix(var_name, 'mdl_vars.')
+        )
+        with filemanager(path, 'wb') as f:
+          f.write(output.tobytes())
+        weights_info.append(self.get_weight_info(var_name, output))
+
+    # Sort weights_info
+    weights_info.sort()
+    with filemanager(
+        os.path.join(self._output_dir, 'layer_info.txt'), 'w'
+    ) as finfo:
+      for line in weights_info:
+        finfo.write(line + '\n')
diff --git a/mediapipe/model_maker/python/llm/weight_bins_writer_test.py b/mediapipe/model_maker/python/llm/weight_bins_writer_test.py
new file mode 100644
index 0000000000..76b34e2b3f
--- /dev/null
+++ b/mediapipe/model_maker/python/llm/weight_bins_writer_test.py
@@ -0,0 +1,62 @@
+# Copyright 2024 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for weight_bins_writer."""
+
+import os
+
+from absl import flags
+from absl.testing import absltest
+from absl.testing import parameterized
+import numpy as np
+
+from mediapipe.model_maker.python.llm import weight_bins_writer
+
+
+class WeightBinsWriterTest(parameterized.TestCase):
+
+  def test_get_weight_info(self):
+    output_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
+    writer = weight_bins_writer.WeightBinsWriter(
+        output_dir=output_dir, backend='xnnpack'
+    )
+    var_name = 'params.lm.softmax.logits_ffn.linear.w'
+    weight_info = writer.get_weight_info(
+        var_name, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
+    )
+    self.assertEqual(
+        weight_info,
+        'mdl_vars.params.lm.softmax.logits_ffn.linear.w.float32.2_3\n',
+    )
+
+  def test_load_to_actions(self):
+    output_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
+    writer = weight_bins_writer.WeightBinsWriter(
+        output_dir=output_dir, backend='xnnpack'
+    )
+    variables = {
+        'mdl_vars.params.lm.softmax.logits_ffn.linear.w': (
+            np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
+            False,
+        ),
+    }
+    writer.write_variables(variables)
+    file_size = os.path.getsize(
+        os.path.join(output_dir, 'params.lm.softmax.logits_ffn.linear.w')
+    )
+    self.assertEqual(file_size, 6 * 4)
+
+
+if __name__ == '__main__':
+  absltest.main()
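
For reference, below is a minimal sketch of how the pieces added in this patch are wired together: the factory creates a format-specific loader, the loader produces QuantizationAction objects, and the weight-bins writer dumps one binary file per variable. The checkpoint path, output directory, and the pass-through of unquantized float32 tensors into write_variables are illustrative assumptions; the quantization step that would normally be applied to each action is not part of this change.

import os

from mediapipe.model_maker.python.llm import converter_factory

# Create a loader for a safetensors checkpoint (hypothetical path).
loader = converter_factory.create_ckpt_loader(
    "safetensors",
    ckpt_path="/tmp/stablelm/model.safetensors",
    is_symmetric=True,
    attention_quant_bits=8,
    feedforward_quant_bits=8,
    embedding_quant_bits=8,
    special_model="STABLELM_4E1T_3B",
)
actions = loader.load_to_actions()  # list of QuantizationAction

# Write each tensor as its own binary file plus a layer_info.txt index.
# A real conversion would quantize each action's tensor first; here the
# float32 values are written as-is with packing disabled (the False flag),
# purely to illustrate the writer API.
output_dir = "/tmp/stablelm_weights"  # hypothetical directory
os.makedirs(output_dir, exist_ok=True)
writer = converter_factory.create_writer(
    "weight_bins", output_dir=output_dir, backend="xnnpack"
)
writer.write_variables(
    {a.target_name: (a.tensor_value, False) for a in actions}
)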