Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 602220793
  • Loading branch information
MediaPipe Team authored and copybara-github committed Jan 28, 2024
1 parent ecf4d4f commit 850da4e
Show file tree
Hide file tree
Showing 7 changed files with 885 additions and 1 deletion.
45 changes: 45 additions & 0 deletions mediapipe/model_maker/python/llm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,33 @@ py_library(
srcs = ["converter_base.py"],
)

# Factory library: creates the concrete checkpoint loader / model writer
# instances, hence it depends on every converter implementation below.
py_library(
name = "converter_factory",
srcs = ["converter_factory.py"],
deps = [
":converter_base",
":pytorch_converter",
":safetensors_converter",
":weight_bins_writer",
],
)

# Loader for PyTorch-format checkpoints; implements the base interface
# from :converter_base.
py_library(
name = "pytorch_converter",
srcs = ["pytorch_converter.py"],
deps = [
":converter_base",
],
)

# Loader for safetensors-format checkpoints; implements the base interface
# from :converter_base.
py_library(
name = "safetensors_converter",
srcs = ["safetensors_converter.py"],
deps = [
":converter_base",
],
)

py_library(
name = "quantization_util",
srcs = ["quantization_util.py"],
Expand All @@ -35,3 +62,21 @@ py_test(
":quantization_util",
],
)

# Model writer that emits weights as binary files; uses :quantization_util
# for weight quantization.
py_library(
name = "weight_bins_writer",
srcs = ["weight_bins_writer.py"],
deps = [
":converter_base",
":quantization_util",
],
)

# Unit tests for :weight_bins_writer.
py_test(
name = "weight_bins_writer_test",
srcs = ["weight_bins_writer_test.py"],
srcs_version = "PY3",
deps = [
":weight_bins_writer",
],
)
2 changes: 1 addition & 1 deletion mediapipe/model_maker/python/llm/converter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
self._embedding_quant_bits = embedding_quant_bits

def map_to_actions(self, layer_name: str) -> Optional[QuantizationAction]:
"""""Maps the layer weights to quantization actions.
"""Maps the layer weights to quantization actions.
Args:
layer_name: A string representing the name of the layer weight. Note that
Expand Down
78 changes: 78 additions & 0 deletions mediapipe/model_maker/python/llm/converter_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility library that helps create the converter instances."""

from mediapipe.model_maker.python.llm import converter_base
from mediapipe.model_maker.python.llm import pytorch_converter
from mediapipe.model_maker.python.llm import safetensors_converter
from mediapipe.model_maker.python.llm import weight_bins_writer


def create_ckpt_loader(
    ckpt_format: str, *args, **kwargs
) -> converter_base.CkptLoaderBase:
  """Creates the checkpoint loader.

  Args:
    ckpt_format: A string that indicates the input checkpoint format, e.g.
      "pytorch" or "safetensors".
    *args: Unused; accepted for interface compatibility.
    **kwargs: Additional arguments to be passed into the loader. Must contain
      "ckpt_path", "is_symmetric", "attention_quant_bits",
      "feedforward_quant_bits", "embedding_quant_bits" and "special_model".

  Returns:
    A created CkptLoader instance.

  Raises:
    ValueError: If `ckpt_format` is not a supported checkpoint format.
    KeyError: If a required keyword argument is missing.
  """
  del args
  # Both supported loaders take the same constructor arguments; extract them
  # once instead of duplicating the block per branch.
  loader_kwargs = dict(
      ckpt_path=kwargs["ckpt_path"],
      is_symmetric=kwargs["is_symmetric"],
      attention_quant_bits=kwargs["attention_quant_bits"],
      feedforward_quant_bits=kwargs["feedforward_quant_bits"],
      embedding_quant_bits=kwargs["embedding_quant_bits"],
      special_model=kwargs["special_model"],
  )
  if ckpt_format == "pytorch":
    return pytorch_converter.PytorchCkptLoader(**loader_kwargs)
  elif ckpt_format == "safetensors":
    return safetensors_converter.SafetensorsCkptLoader(**loader_kwargs)
  else:
    raise ValueError(f"Unknown checkpoint format: {ckpt_format}")


def create_writer(
    writer_type: str, *args, **kwargs
) -> converter_base.ModelWriterBase:
  """Creates the model writer.

  Args:
    writer_type: A string that indicates which model writer to create, e.g.
      "weight_bins".
    *args: Unused; accepted for interface compatibility.
    **kwargs: Additional arguments to be passed into the writer. Must contain
      "output_dir" and "backend".

  Returns:
    A created ModelWriter instance.

  Raises:
    ValueError: If `writer_type` is not a supported writer type.
    KeyError: If a required keyword argument is missing.
  """
  del args
  if writer_type == "weight_bins":
    return weight_bins_writer.WeightBinsWriter(
        output_dir=kwargs["output_dir"], backend=kwargs["backend"]
    )
  raise ValueError(f"Unknown writer type: {writer_type}")
Loading

0 comments on commit 850da4e

Please sign in to comment.