Skip to content

Commit

Permalink
py: extract model id component for base model and datasets if using h…
Browse files Browse the repository at this point in the history
…uggingface url
  • Loading branch information
mofosyne committed Aug 21, 2024
1 parent 1d295cf commit b8d0e03
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
39 changes: 38 additions & 1 deletion gguf-py/gguf/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,27 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
if isinstance(model_id, str):
if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
base_model["repo_url"] = model_id

# Check if Hugging Face ID is present in URL
if "huggingface.co" in model_id:
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
if match:
model_id_component = match.group(1)
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)

# Populate model dictionary with extracted components
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version

else:
# Likely a Hugging Face ID
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)

# Populate model dictionary with extracted components
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
Expand Down Expand Up @@ -405,11 +423,29 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
# NOTE: model size of base model is assumed to be similar to the size of the current model
dataset = {}
if isinstance(dataset_id, str):
if dataset_id.startswith("http://") or dataset_id.startswith("https://") or dataset_id.startswith("ssh://"):
if dataset_id.startswith(("http://", "https://", "ssh://")):
dataset["repo_url"] = dataset_id

# Check if Hugging Face ID is present in URL
if "huggingface.co" in dataset_id:
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
if match:
dataset_id_component = match.group(1)
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)

# Populate dataset dictionary with extracted components
if dataset_name_component is not None:
dataset["name"] = Metadata.id_to_title(dataset_name_component)
if org_component is not None:
dataset["organization"] = Metadata.id_to_title(org_component)
if version is not None:
dataset["version"] = version

else:
# Likely a Hugging Face ID
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)

# Populate dataset dictionary with extracted components
if dataset_name_component is not None:
dataset["name"] = Metadata.id_to_title(dataset_name_component)
if org_component is not None:
Expand All @@ -418,6 +454,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
dataset["version"] = version
if org_component is not None and dataset_name_component is not None:
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"

elif isinstance(dataset_id, dict):
dataset = dataset_id
else:
Expand Down
8 changes: 4 additions & 4 deletions gguf-py/tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,14 @@ def test_apply_metadata_heuristic_from_model_card(self):
self.assertEqual(got, expect)

# Base Model spec is inferred from model id
model_card = {'base_models': ['teknium/OpenHermes-2.5']}
model_card = {'base_models': 'teknium/OpenHermes-2.5'}
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect)

# Base Model spec is only url
model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
expect = gguf.Metadata(base_models=[{'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect)

Expand All @@ -204,14 +204,14 @@ def test_apply_metadata_heuristic_from_model_card(self):
self.assertEqual(got, expect)

# Dataset spec is inferred from model id
model_card = {'datasets': ['teknium/OpenHermes-2.5']}
model_card = {'datasets': 'teknium/OpenHermes-2.5'}
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect)

# Dataset spec is only url
model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
expect = gguf.Metadata(datasets=[{'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect)

Expand Down

0 comments on commit b8d0e03

Please sign in to comment.