Skip to content

Commit

Permalink
py: add support for huggingface imports
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Jan 17, 2024
1 parent fc3ccbc commit 0497cf4
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 2 deletions.
35 changes: 35 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,41 @@ version = registry.get_model_version("my-model", "v2.0")
experiment = registry.get_model_artifact("my-model", "v2.0")
```

### Importing from Hugging Face Hub

To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an
extra (available as `model-registry[hf]`).
Models can then be imported with `register_hf_model`:

```py
hf_model = registry.register_hf_model(
"hf-namespace/hf-model", # HF repo
"relative/path/to/model/file.onnx",
version="1.2.3",
model_name="my-model",
description="lorem ipsum",
model_format_name="onnx",
model_format_version="1",
)
```

There are caveats to be noted when using this method:

- It's only possible to import a single model file per Hugging Face Hub repo right now.
- If the model you want to import is in a global namespace (i.e. the repo name contains no `/`), you should provide an author, e.g.:

```py
hf_model = registry.register_hf_model(
"gpt2", # this model implicitly has no author
"onnx/decoder_model.onnx",
author="OpenAI", # Defaults to unknown in the absence of an author
version="1.0.0",
description="gpt-2 model",
model_format_name="onnx",
model_format_version="1",
)
```

## Development

Common tasks, such as building documentation and running tests, can be executed using [`nox`](https://github.com/wntrblm/nox) sessions.
Expand Down
9 changes: 8 additions & 1 deletion clients/python/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ def mypy(session: Session) -> None:
def tests(session: Session) -> None:
"""Run the test suite."""
session.install(".")
session.install("coverage[toml]", "pytest", "pytest-cov", "pygments", "testcontainers")
session.install(
"coverage[toml]",
"pytest",
"pytest-cov",
"pygments",
"testcontainers",
"huggingface-hub",
)
try:
session.run(
"pytest",
Expand Down
108 changes: 107 additions & 1 deletion clients/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ ml-metadata = "^1.14.0"
# ml-metadata = { url = "https://github.com/tarilabs/ml-metadata-remote/releases/download/1.14.0/ml_metadata-1.14.0-py3-none-any.whl" }
typing-extensions = "^4.8"

huggingface-hub = { version = "^0.20.1", optional = true }

[tool.poetry.extras]
hf = ["huggingface-hub"]

[tool.poetry.group.dev.dependencies]
sphinx = "^7.2.6"
furo = "^2023.9.10"
Expand Down
95 changes: 95 additions & 0 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Standard client for the model registry."""
from __future__ import annotations

from typing import get_args
from warnings import warn

from .core import ModelRegistryAPIClient
from .exceptions import StoreException
from .store import ScalarType
Expand Down Expand Up @@ -120,6 +123,98 @@ def register_model(

return rm

def register_hf_model(
    self,
    repo: str,
    path: str,
    *,
    version: str,
    model_format_name: str,
    model_format_version: str,
    author: str | None = None,
    model_name: str | None = None,
    description: str | None = None,
    git_ref: str = "main",
) -> RegisteredModel:
    """Register a Hugging Face model.

    This imports a model from Hugging Face hub and registers it in the model registry.
    Note that the model is not downloaded.

    Args:
        repo: Name of the repository from Hugging Face hub.
        path: Path of the model file, relative to the repository root.

    Keyword Args:
        version: Version of the model. Has to be unique.
        model_format_name: Name of the model format.
        model_format_version: Version of the model format.
        author: Author of the model. Defaults to the author reported by the hub for the
            repo, or `"unknown"` (with a warning) when the hub reports none.
        model_name: Name of the model. Defaults to the repo id reported by the hub.
        description: Description of the model.
        git_ref: Git reference to use. Defaults to `main`.

    Returns:
        Registered model.

    Raises:
        StoreException: If `huggingface_hub` is not installed, or if the repository or
            the revision cannot be found on the hub.
    """
    # huggingface_hub is an optional extra, so it is imported lazily here.
    try:
        from huggingface_hub import HfApi, hf_hub_url, utils
    except ImportError as e:
        msg = "huggingface_hub is not installed"
        raise StoreException(msg) from e

    api = HfApi()
    try:
        model_info = api.model_info(repo, revision=git_ref)
    except utils.RepositoryNotFoundError as e:
        msg = f"Repository {repo} does not exist"
        raise StoreException(msg) from e
    except utils.RevisionNotFoundError as e:
        # TODO: as all hf-hub client calls default to using main, should we provide a tip?
        msg = f"Revision {git_ref} does not exist"
        raise StoreException(msg) from e

    # Resolve the author once: explicit argument > hub-reported author > "unknown".
    if author:
        model_author = author
    elif model_info.author is not None:
        model_author = model_info.author
    else:
        # model author can be None if the repo is in a "global" namespace (i.e. no / in repo).
        model_author = "unknown"
        warn(
            "Model author is unknown. This is likely because the model is in a global namespace.",
            stacklevel=2,
        )

    source_uri = hf_hub_url(repo, path, revision=git_ref)
    metadata = {
        "repo": repo,
        "source_uri": source_uri,
        "model_origin": "huggingface_hub",
        "model_author": model_author,
    }
    # card_data is the new field, but let's use the old one for backwards compatibility.
    if card_data := model_info.cardData:
        scalar_types = get_args(ScalarType)
        metadata.update(
            {
                k: v
                for k, v in card_data.to_dict().items()
                # Card fields come from the repo's model card; never let them overwrite
                # the provenance keys set above.
                # TODO: (#151) preserve tags, possibly other complex metadata
                if k not in metadata and isinstance(v, scalar_types)
            }
        )

    return self.register_model(
        model_name or model_info.id,
        source_uri,
        author=model_author,
        version=version,
        model_format_name=model_format_name,
        model_format_version=model_format_version,
        description=description,
        storage_path=path,
        metadata=metadata,
    )

def get_registered_model(self, name: str) -> RegisteredModel | None:
"""Get a registered model.
Expand Down
34 changes: 34 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,37 @@ def test_get(mr_client: ModelRegistry):
assert mv.id == _mv.id
assert (_ma := mr_client.get_model_artifact(name, version))
assert ma.id == _ma.id


def test_hf_import(mr_client: ModelRegistry):
    """Import `gpt2` from Hugging Face Hub with an explicit author.

    Skipped when `huggingface_hub` is not importable. After registering, the model
    version and the model artifact must both be retrievable by name and version.
    """
    pytest.importorskip("huggingface_hub")
    repo, model_version = "gpt2", "1.2.3"

    registered = mr_client.register_hf_model(
        repo,
        "onnx/decoder_model.onnx",
        author="test author",
        version=model_version,
        model_format_name="test format",
        model_format_version="test version",
    )
    assert registered

    # The repo name doubles as the lookup name, since no model_name was passed.
    assert mr_client.get_model_version(repo, model_version)
    assert mr_client.get_model_artifact(repo, model_version)


def test_hf_import_missing_author(mr_client: ModelRegistry):
    """Import `gpt2` without an author and expect the "unknown" fallback.

    Skipped when `huggingface_hub` is not importable. Registration must emit a warning
    matching "author is unknown", and the stored model version must carry the author
    `"unknown"`.
    """
    pytest.importorskip("huggingface_hub")
    repo, model_version = "gpt2", "1.2.3"

    with pytest.warns(match=r".*author is unknown.*"):
        registered = mr_client.register_hf_model(
            repo,
            "onnx/decoder_model.onnx",
            version=model_version,
            model_format_name="test format",
            model_format_version="test version",
        )
        assert registered

    stored_version = mr_client.get_model_version(repo, model_version)
    assert stored_version
    assert stored_version.author == "unknown"

0 comments on commit 0497cf4

Please sign in to comment.