From 2164c9deb3d2cfe80ae282e32de0def19d95463e Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Fri, 19 Jul 2024 12:30:37 -0400
Subject: [PATCH 1/5] gguf-py : fix some metadata name extraction edge cases

* convert_lora : use the lora dir for the model card path

---
 convert_hf_to_gguf.py          |  4 +++-
 convert_lora_to_gguf.py        | 24 +++++++++++++++++-------
 gguf-py/gguf/metadata.py       | 25 +++++++++++++++++++++----
 gguf-py/tests/test_metadata.py | 18 ++++++++++++++++--
 4 files changed, 57 insertions(+), 14 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index e3c8aac3d32aa..03b84f24da001 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -62,6 +62,7 @@ class Model:
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
+    dir_model_card: Path

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
@@ -90,6 +91,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path |
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
+        self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py

         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -345,7 +347,7 @@ def prepare_metadata(self, vocab_only: bool):

         total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()

-        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, self.dir_model_card, total_params)

         # Fallback to model directory name if metadata name is still missing
         if self.metadata.name is None:
diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index 66e8da37cba7c..2b84009f9b9a1 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -304,12 +304,6 @@ def parse_args() -> argparse.Namespace:
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
     hparams = Model.load_hparams(dir_base_model)
-
-    with open(lora_config, "r") as f:
-        lparams: dict[str, Any] = json.load(f)
-
-    alpha: float = lparams["lora_alpha"]
-
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -320,12 +314,21 @@ def parse_args() -> argparse.Namespace:
         class LoraModel(model_class):
             model_arch = model_class.model_arch

+            lora_alpha: float
+
+            def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
+
+                super().__init__(*args, **kwargs)
+
+                self.dir_model_card = dir_lora_model
+                self.lora_alpha = float(lora_alpha)
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

             def set_gguf_parameters(self):
-                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                 super().set_gguf_parameters()

             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
@@ -368,6 +371,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)

+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
+    alpha: float = lparams["lora_alpha"]
+
     model_instance = LoraModel(
         dir_base_model,
         ftype,
@@ -376,6 +384,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         use_temp_file=False,
         eager=args.no_lazy,
         dry_run=args.dry_run,
+        dir_lora_model=dir_lora,
+        lora_alpha=alpha,
     )

     logger.info("Exporting model...")
diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py
index bac6ebfb3777a..a3607ef1dccbe 100644
--- a/gguf-py/gguf/metadata.py
+++ b/gguf-py/gguf/metadata.py
@@ -44,7 +44,7 @@ class Metadata:
     datasets: Optional[list[str]] = None

     @staticmethod
-    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
+    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, model_card_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
         # This grabs as many contextual authorship metadata as possible from the model repository
         # making any conversion as required to match the gguf kv store metadata format
         # as well as giving users the ability to override any authorship metadata that may be incorrect
@@ -52,11 +52,14 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
         # Create a new Metadata instance
         metadata = Metadata()

-        model_card = Metadata.load_model_card(model_path)
+        if model_card_path is None:
+            model_card_path = model_path
+
+        model_card = Metadata.load_model_card(model_card_path)
         hf_params = Metadata.load_hf_parameters(model_path)

         # heuristics
-        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
+        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_card_path, total_params)

         # Metadata Override File Provided
         # This is based on LLM_KV_NAMES mapping in llama.cpp
@@ -177,6 +180,12 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
             org_component = None

         name_parts: list[str] = model_full_name_component.split('-')
+
+        # Remove empty parts
+        for i in reversed(range(len(name_parts))):
+            if len(name_parts[i]) == 0:
+                del name_parts[i]
+
         name_types: list[
             set[Literal["basename", "size_label", "finetune", "version", "type"]]
         ] = [set() for _ in name_parts]
@@ -227,6 +236,13 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
                 if part.lower() == "lora":
                     name_parts[i] = "LoRA"

+        # Ignore word-based size labels when there is at least a number-based one present
+        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
+            for n, t in zip(name_parts, name_types):
+                if "size_label" in t:
+                    if all(c.isalpha() for c in n):
+                        t.remove("size_label")
+
         at_start = True
         # Find the basename through the annotated name
         for part, t in zip(name_parts, name_types):
@@ -247,7 +263,8 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
                 break

         basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
-        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        # Deduplicate size labels using order-preserving 'dict' (a plain 'set' would not keep insertion order)
+        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
         finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
         # TODO: should the basename version always be excluded?
         # TODO: should multiple versions be joined together?
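
Note for reviewers: a minimal standalone sketch of why the size-label
deduplication above uses an order-preserving `dict` rather than a `set`
(the `parts` value here is a made-up example, not taken from the tests):

    # CPython >= 3.7 guarantees dicts iterate in insertion order;
    # sets give no ordering guarantee at all.
    parts = ["7B", "A2.7B", "7B"]

    # order-preserving deduplication, as used for size_label
    print("-".join(dict.fromkeys(parts).keys()))  # 7B-A2.7B

    # a set also deduplicates, but the join order would be arbitrary
    print("-".join(set(parts)))  # e.g. A2.7B-7B (order not guaranteed)
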
diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py
index 3fac8218883f1..80ed35363d9b3 100755
--- a/gguf-py/tests/test_metadata.py
+++ b/gguf-py/tests/test_metadata.py
@@ -54,7 +54,7 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
                          ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))

-        # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
+        # Non standard naming
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
                          ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))

@@ -71,7 +71,7 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
                          ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))

-        # None standard and not easy to disambiguate
+        # Non standard and not easy to disambiguate
         self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
                          ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))

@@ -123,6 +123,20 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
                          ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))

+        # Ignore full-text size labels when there are number-based ones, and deduplicate size labels
+        self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
+                         ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
+
+        # Version at the end with a long basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Base-2407"),
+                         ('Mistral-Nemo-Base-2407', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+
+        ## Invalid cases ##
+
+        # Starts with a dash and contains consecutive dashes
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
+                         ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],

From 912e6fa5c6bb286478ac7df53057482d9abb7db4 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Fri, 19 Jul 2024 13:46:41 -0400
Subject: [PATCH 2/5] gguf-py : more metadata edge case fixes

Multiple finetune versions are now joined together, and the removal
of the basename annotation on trailing versions is more robust.
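
For example, a name like "Mini-InternVL-Chat-2B-V1-5" has two version-like
parts. A standalone sketch of the old vs. new joining behavior (with
hypothetical annotations mirroring what the heuristic assigns, not a call
into the real get_model_id_components()):

    name_parts = ["Mini", "InternVL", "Chat", "2B", "V1", "5"]
    name_types = [{"basename"}, {"basename"}, {"finetune"},
                  {"size_label"}, {"version"}, {"version"}]

    # before: only the last version part was kept
    old = ([v for v, t in zip(name_parts, name_types)
            if "version" in t and "basename" not in t] or [None])[-1]
    print(old)  # 5

    # after: all non-basename version parts are joined
    new = "-".join(v for v, t in zip(name_parts, name_types)
                   if "version" in t and "basename" not in t) or None
    print(new)  # V1-5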
---
 gguf-py/gguf/metadata.py       |  9 ++++-----
 gguf-py/tests/test_metadata.py | 22 ++++++++++++++++++----
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py
index a3607ef1dccbe..f13cf5403bde0 100644
--- a/gguf-py/gguf/metadata.py
+++ b/gguf-py/gguf/metadata.py
@@ -256,9 +256,8 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =

         # Remove the basename annotation from trailing version
         for part, t in zip(reversed(name_parts), reversed(name_types)):
-            if "basename" in t:
-                if len(t) > 1:
-                    t.remove("basename")
+            if "basename" in t and len(t) > 1:
+                t.remove("basename")
             else:
                 break

@@ -267,8 +266,8 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
         size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
         finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
         # TODO: should the basename version always be excluded?
-        # TODO: should multiple versions be joined together?
-        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+        # NOTE: multiple finetune versions are joined together
+        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None

         if size_label is None and finetune is None and version is None:
             # Too ambiguous, output nothing
diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py
index 80ed35363d9b3..74863b40dcc8d 100755
--- a/gguf-py/tests/test_metadata.py
+++ b/gguf-py/tests/test_metadata.py
@@ -127,9 +127,23 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
                          ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))

-        # Version at the end with a long basename
-        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Base-2407"),
-                         ('Mistral-Nemo-Base-2407', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+        # Instruct in a name without a size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
+                         ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))
+
+        # Non-obvious splitting relying on 'chat' keyword
+        self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
+                         ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))
+
+        # Multiple versions
+        self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
+                         ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))
+
+        # Too ambiguous
+        # TODO: should "base" be a 'finetune' or 'size_label'?
+        # (in this case it should be a size label, but other models use it to signal that they are not finetuned)
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
+                         ('Florence-2-base', 'microsoft', None, None, None, None))

         ## Invalid cases ##

@@ -148,7 +162,7 @@ def test_apply_metadata_heuristic_from_model_card(self):
         }
         got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
         expect = gguf.Metadata()
-        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
         expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
         expect.languages=['en']
         expect.datasets=['teknium/OpenHermes-2.5']

From a3d154b2604650424d8bec105928033dfed61222 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 20 Jul 2024 15:57:46 -0400
Subject: [PATCH 3/5] gguf-py : add more name metadata extraction tests

---
 gguf-py/tests/test_metadata.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py
index 74863b40dcc8d..950e5383c28f7 100755
--- a/gguf-py/tests/test_metadata.py
+++ b/gguf-py/tests/test_metadata.py
@@ -139,6 +139,14 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
                          ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))

+        # TODO: DPO in the name
+        self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
+                         ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))
+
+        # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
+                         ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))
+
         # Too ambiguous
         # TODO: should "base" be a 'finetune' or 'size_label'?
         # (in this case it should be a size label, but other models use it to signal that they are not finetuned)

From bf8e71b0c08b0d6c3dbc9e5cccf55d6bf06dfc4b Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 20 Jul 2024 16:40:58 -0400
Subject: [PATCH 4/5] convert_lora : fix default filename

The default filename was previously hardcoded.
* convert_hf : Model.fname_out can no longer be None

---
 convert_hf_to_gguf.py          | 36 ++++++++++++++--------------------
 convert_lora_to_gguf.py        |  2 +-
 gguf-py/gguf/metadata.py       | 19 +++++++++---------
 gguf-py/tests/test_metadata.py |  9 +++++++++
 4 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 03b84f24da001..fa7ea609a97c7 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -48,7 +48,7 @@ class Model:

     dir_model: Path
     ftype: gguf.LlamaFileType
-    fname_out: Path | None
+    fname_out: Path
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
@@ -67,7 +67,7 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
@@ -347,7 +347,7 @@ def prepare_metadata(self, vocab_only: bool):

         total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()

-        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, self.dir_model_card, total_params)
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)

         # Fallback to model directory name if metadata name is still missing
         if self.metadata.name is None:
@@ -361,27 +361,22 @@ def prepare_metadata(self, vocab_only: bool):
         output_type: str = self.ftype.name.partition("_")[2]

         # Filename Output
-        # Note: `not is_dir()` is used because `.is_file()` will not detect
-        #       file template strings as it doesn't actually exist as a file
-        if self.fname_out is not None and not self.fname_out.is_dir():
-            # Output path is a custom defined templated filename
-
-            # Process templated file name with the output ftype, useful with the "auto" ftype
-            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
-        else:
+        if self.fname_out.is_dir():
             # Generate default filename based on model specification and available metadata
             if not vocab_only:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
             else:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")

-            # Check if preferred output directory path was provided
-            if self.fname_out is not None and self.fname_out.is_dir():
-                # output path is a directory
-                self.fname_out = self.fname_out / f"{fname_default}.gguf"
-            else:
-                # output in the same directory as the model by default
-                self.fname_out = self.dir_model / f"{fname_default}.gguf"
+            # Use the default filename
+            self.fname_out = self.fname_out / f"{fname_default}.gguf"
+        else:
+            # Output path is a custom defined templated filename
+            # Note: `not is_dir()` is used because `.is_file()` will not detect
+            #       file template strings as it doesn't actually exist as a file
+
+            # Process templated file name with the output ftype, useful with the "auto" ftype
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)

         self.set_type()
@@ -3626,10 +3621,10 @@ def main() -> None:
             logger.error("Error: Cannot use temp file when splitting")
            sys.exit(1)

-    fname_out = None
-
     if args.outfile is not None:
         fname_out = args.outfile
+    else:
+        fname_out = dir_model

     logger.info(f"Loading model: {dir_model.name}")

@@ -3660,7 +3655,6 @@ def main() -> None:
         else:
             logger.info("Exporting model...")
             model_instance.write()
-            assert model_instance.fname_out is not None
             out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
             logger.info(f"Model successfully exported to {out_path}")

diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index 2b84009f9b9a1..a88d0d4a978a9 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -290,7 +290,7 @@ def parse_args() -> argparse.Namespace:
         fname_out = args.outfile
     else:
         # output in the same directory as the model by default
-        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
+        fname_out = dir_lora

     if os.path.exists(input_model):
         # lazy import load_file only if lora is in safetensors format.
diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py
index f13cf5403bde0..15189f7177500 100644
--- a/gguf-py/gguf/metadata.py
+++ b/gguf-py/gguf/metadata.py
@@ -44,7 +44,7 @@ class Metadata:
     datasets: Optional[list[str]] = None

     @staticmethod
-    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, model_card_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
+    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
         # This grabs as many contextual authorship metadata as possible from the model repository
         # making any conversion as required to match the gguf kv store metadata format
         # as well as giving users the ability to override any authorship metadata that may be incorrect
@@ -52,14 +52,12 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
         # Create a new Metadata instance
         metadata = Metadata()

-        if model_card_path is None:
-            model_card_path = model_path
-
-        model_card = Metadata.load_model_card(model_card_path)
+        model_card = Metadata.load_model_card(model_path)
         hf_params = Metadata.load_hf_parameters(model_path)
+        # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter

         # heuristics
-        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_card_path, total_params)
+        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)

         # Metadata Override File Provided
         # This is based on LLM_KV_NAMES mapping in llama.cpp
@@ -232,11 +230,14 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int =
                 name_parts[i] = part
             # Some easy to recognize finetune names
             elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
-                name_types[i].add("finetune")
-                if part.lower() == "lora":
-                    name_parts[i] = "LoRA"
+                if total_params < 0 and part.lower() == "lora":
+                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
+                    name_types[i].add("type")
+                else:
+                    name_types[i].add("finetune")

         # Ignore word-based size labels when there is at least a number-based one present
+        # TODO: should word-based size labels always be removed instead?
         if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
             for n, t in zip(name_parts, name_types):
                 if "size_label" in t:
diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py
index 950e5383c28f7..81a2a30ae60f4 100755
--- a/gguf-py/tests/test_metadata.py
+++ b/gguf-py/tests/test_metadata.py
@@ -159,6 +159,15 @@ def test_get_model_id_components(self):
         self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
                          ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))

+        ## LoRA ##
+
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))
+
+        # Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))
+
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],

From 1932a1b871e97e25496985bf4b3b53d3e7227a16 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 20 Jul 2024 16:47:43 -0400
Subject: [PATCH 5/5] gguf-py : do not use title case for naming convention

Some models use acronyms in lowercase, which can't be title-cased like
other words, so it's best to simply use the same case as in the
original model name.

Note that the size label still has an uppercased suffix to make it
distinguishable from the context size of a finetune.

---
 gguf-py/gguf/utility.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index ef76831b521ee..40d59b75ee04e 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -50,15 +50,15 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention

     if base_name is not None:
-        name = base_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = base_name.strip().replace(' ', '-').replace('/', '-')
     elif model_name is not None:
-        name = model_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = model_name.strip().replace(' ', '-').replace('/', '-')
     else:
         name = "ggml-model"

     parameters = f"-{size_label}" if size_label is not None else ""

-    finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
+    finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""

     version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
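
Note: a before/after sketch of the casing change above, using a made-up
base name with a lowercase acronym (not one of the project's test cases):

    base_name = "gpt2 xl"

    # before this patch: .title() mangles lowercase acronyms
    print(base_name.strip().title().replace(' ', '-').replace('/', '-'))  # Gpt2-Xl

    # after this patch: the original casing is preserved
    print(base_name.strip().replace(' ', '-').replace('/', '-'))  # gpt2-xl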