Py: Add support for importing models from Hugging Face Hub #260

Merged 7 commits on Jan 17, 2024
61 changes: 52 additions & 9 deletions clients/python/README.md
@@ -12,15 +12,23 @@ from model_registry import ModelRegistry
 
 registry = ModelRegistry(server_address="server-address", port=9090, author="author")
 
-model = registry.register_model("my-model",
-                                "s3://path/to/model",
-                                model_format_name="onnx",
-                                model_format_version="1",
-                                storage_key="aws-connection-path",
-                                storage_path="to/model",
-                                version="v2.0",
-                                description="lorem ipsum",
-)
+model = registry.register_model(
+    "my-model", # model name
+    "s3://path/to/model", # model URI
+    version="2.0.0",
+    description="lorem ipsum",
+    model_format_name="onnx",
+    model_format_version="1",
+    storage_key="aws-connection-path",
+    storage_path="path/to/model",
+    metadata={
+        # can be one of the following types
+        "int_key": 1,
+        "bool_key": False,
+        "float_key": 3.14,
+        "str_key": "str_value",
+    }
+)
 
 model = registry.get_registered_model("my-model")
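
Note on the hunk above: the new `metadata` argument is documented as accepting `int`, `bool`, `float`, and `str` values. A minimal sketch of a client-side check for that constraint could look like the following. `validate_metadata` and `ALLOWED_TYPES` are hypothetical names used for illustration only; they are not part of `model_registry` or of this PR, and the library may enforce the constraint differently.

```python
from typing import Any, Mapping

# Value types listed in the README comment for metadata entries.
ALLOWED_TYPES = (int, bool, float, str)


def validate_metadata(metadata: Mapping[str, Any]) -> None:
    """Raise TypeError if any metadata value has an unsupported type.

    Hypothetical helper for illustration; not part of model_registry.
    """
    for key, value in metadata.items():
        if not isinstance(value, ALLOWED_TYPES):
            raise TypeError(
                f"metadata[{key!r}] has unsupported type {type(value).__name__}"
            )


# The sample metadata from the README passes the check without raising.
validate_metadata(
    {"int_key": 1, "bool_key": False, "float_key": 3.14, "str_key": "str_value"}
)
```

Running such a check before calling `register_model` would surface an unsupported value (for example a list or a nested dict) as a local `TypeError` rather than as a server-side failure.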

@@ -29,6 +37,41 @@ version = registry.get_model_version("my-model", "v2.0")
 experiment = registry.get_model_artifact("my-model", "v2.0")
 ```
 
+### Importing from Hugging Face Hub
+
+To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an
+extra (available as `model-registry[hf]`).
+Models can be imported with
+
+```py
+hf_model = registry.register_hf_model(
+    "hf-namespace/hf-model", # HF repo
+    "relative/path/to/model/file.onnx",
+    version="1.2.3",
+    model_name="my-model",
+    description="lorem ipsum",
+    model_format_name="onnx",
+    model_format_version="1",
+)
+```
+
+There are caveats to be noted when using this method:
+
+- It's only possible to import a single model file per Hugging Face Hub repo right now.
+- If the model you want to import is in a global namespace, you should provide an author, e.g.
+
+  ```py
+  hf_model = registry.register_hf_model(
+      "gpt2", # this model implicitly has no author
+      "onnx/decoder_model.onnx",
+      author="OpenAI", # Defaults to unknown in the absence of an author
+      version="1.0.0",
+      description="gpt-2 model",
+      model_format_name="onnx",
+      model_format_version="1",
+  )
+  ```
+
 ## Development
 
 Common tasks, such as building documentation and running tests, can be executed using [`nox`](https://github.com/wntrblm/nox) sessions.
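
Note on the global-namespace caveat in the README changes above: a repo id such as `gpt2` has no `namespace/` prefix, so there is nothing to take an author from. The sketch below illustrates one way that fallback could work. `resolve_author` is a hypothetical helper, and deriving the author from the namespace part of the repo id is an assumption made for illustration; the README only states that the author defaults to unknown when none is available.

```python
from typing import Optional


def resolve_author(repo: str, author: Optional[str] = None) -> str:
    """Pick an author for a Hugging Face Hub repo id.

    Hypothetical helper, not the PR's implementation. An explicit author
    wins, then the namespace part of "namespace/name" (an assumption),
    and finally the "unknown" fallback mentioned in the README.
    """
    if author:
        return author
    namespace, sep, _name = repo.partition("/")
    return namespace if sep else "unknown"


print(resolve_author("hf-namespace/hf-model"))
print(resolve_author("gpt2"))
print(resolve_author("gpt2", author="OpenAI"))
```

Under these assumptions, passing `author="OpenAI"` for `gpt2`, as the README example does, is what keeps the registered model from being attributed to an unknown author.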
9 changes: 8 additions & 1 deletion clients/python/noxfile.py
@@ -49,7 +49,14 @@ def mypy(session: Session) -> None:
 def tests(session: Session) -> None:
     """Run the test suite."""
     session.install(".")
-    session.install("coverage[toml]", "pytest", "pytest-cov", "pygments", "testcontainers")
+    session.install(
+        "coverage[toml]",
+        "pytest",
+        "pytest-cov",
+        "pygments",
+        "testcontainers",
+        "huggingface-hub",
+    )
     try:
         session.run(
             "pytest",
108 changes: 107 additions & 1 deletion clients/python/poetry.lock

Generated file; diff not rendered by default.

5 changes: 5 additions & 0 deletions clients/python/pyproject.toml
@@ -16,6 +16,11 @@ ml-metadata = "^1.14.0"
 # ml-metadata = { url = "https://github.com/tarilabs/ml-metadata-remote/releases/download/1.14.0/ml_metadata-1.14.0-py3-none-any.whl" }
 typing-extensions = "^4.8"
 
+huggingface-hub = { version = "^0.20.1", optional = true }
+
+[tool.poetry.extras]
+hf = ["huggingface-hub"]
+
 [tool.poetry.group.dev.dependencies]
 sphinx = "^7.2.6"
 furo = "^2023.9.10"
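
Note on the hunk above: because `huggingface-hub` is declared as optional and exposed through the `hf` extra, code that needs it has to cope with the package being absent. A common pattern is to check for the module at call time and fail with an install hint. The sketch below shows that pattern using only the standard library; `require_optional` is a hypothetical helper and not necessarily how this PR guards the import, and the `pip install` command in the message is an assumption about how users install the client.

```python
import importlib
import importlib.util
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency or fail with an install hint.

    Hypothetical helper for illustration; not part of model_registry.
    """
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"{module_name} is not installed; "
            f'install it with: pip install "model-registry[{extra}]"'
        )
    return importlib.import_module(module_name)


# Intended use (raises ImportError when the extra is not installed):
# hub = require_optional("huggingface_hub", "hf")
```

Deferring the check to call time keeps the base install free of the Hugging Face dependency while still giving users who call the import path an actionable error.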