Skip to content

Commit

Permalink
py: default metadata capture environment vars
Browse files Browse the repository at this point in the history
  • Loading branch information
tarilabs committed Feb 14, 2024
1 parent dbcacad commit 387d88a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 5 deletions.
5 changes: 5 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ version = registry.get_model_version("my-model", "v2.0")
experiment = registry.get_model_artifact("my-model", "v2.0")
```

### Default values for metadata

If not supplied, `metadata` values defaults to a predefined set of conventional values.
Reference the technical documentation in the pydoc of the client.

### Importing from Hugging Face Hub

To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an
Expand Down
21 changes: 18 additions & 3 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Standard client for the model registry."""
from __future__ import annotations

import os
from typing import get_args
from warnings import warn

Expand Down Expand Up @@ -98,8 +99,8 @@ def register_model(
storage_key: Storage key.
storage_path: Storage path.
service_account_name: Service account name.
metadata: Additional version metadata.
metadata: Additional version metadata. Defaults to values returned by `default_metadata()`.
Returns:
Registered model.
"""
Expand All @@ -109,7 +110,7 @@ def register_model(
version,
author or self._author,
description=description,
metadata=metadata or {},
metadata=metadata or self.default_metadata(),
)
self._register_model_artifact(
mv,
Expand All @@ -122,6 +123,19 @@ def register_model(
)

return rm

def default_metadata(self) -> dict[str, ScalarType]:
"""Default metadata valorisations
When not explicitly supplied by the end users, these valorisations will be used
by default.
Returns:
default metadata valorisations.
"""
return {
key: os.environ[key] for key in ['AWS_S3_ENDPOINT', 'AWS_S3_BUCKET', 'AWS_DEFAULT_REGION'] if key in os.environ
}

def register_hf_model(
self,
Expand Down Expand Up @@ -188,6 +202,7 @@ def register_hf_model(
model_author = author
source_uri = hf_hub_url(repo, path, revision=git_ref)
metadata = {
**self.default_metadata(),
"repo": repo,
"source_uri": source_uri,
"model_origin": "huggingface_hub",
Expand Down
66 changes: 64 additions & 2 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import os
from model_registry import ModelRegistry
from model_registry.core import ModelRegistryAPIClient
from model_registry.exceptions import StoreException
Expand Down Expand Up @@ -46,13 +47,15 @@ def test_register_existing_version(mr_client: ModelRegistry):
def test_get(mr_client: ModelRegistry):
name = "test_model"
version = "1.0.0"
metadata = {"a": 1, "b": "2"}

rm = mr_client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=version,
metadata=metadata
)

assert (_rm := mr_client.get_registered_model(name))
Expand All @@ -64,22 +67,81 @@ def test_get(mr_client: ModelRegistry):

assert (_mv := mr_client.get_model_version(name, version))
assert mv.id == _mv.id
assert mv.metadata == metadata
assert (_ma := mr_client.get_model_artifact(name, version))
assert ma.id == _ma.id


def test_default_md(mr_client: ModelRegistry):
name = "test_model"
version = "1.0.0"
env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"}
for k, v in env_values.items():
os.environ[k] = v

assert mr_client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=version,
# ensure leave empty metadata
)
assert (mv := mr_client.get_model_version(name, version))
assert mv.metadata == env_values

for k in env_values.keys():
os.environ.pop(k)


def test_hf_import(mr_client: ModelRegistry):
pytest.importorskip("huggingface_hub")
name = "openai-community/gpt2"
version = "1.2.3"
author = "test author"

assert mr_client.register_hf_model(
name,
"onnx/decoder_model.onnx",
author=author,
version=version,
model_format_name="test format",
model_format_version="test version",
)
assert (mv := mr_client.get_model_version(name, version))
assert mv.author == author
assert mv.metadata["model_author"] == author
assert mv.metadata["model_origin"] == "huggingface_hub"
assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx"
assert mv.metadata["repo"] == name
assert mr_client.get_model_artifact(name, version)


def test_hf_import_default_env(mr_client: ModelRegistry):
"""Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata
"""
pytest.importorskip("huggingface_hub")
name = "openai-community/gpt2"
version = "1.2.3"
author = "test author"
env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"}
for k, v in env_values.items():
os.environ[k] = v

assert mr_client.register_hf_model(
name,
"onnx/decoder_model.onnx",
author="test author",
author=author,
version=version,
model_format_name="test format",
model_format_version="test version",
)
assert mr_client.get_model_version(name, version)
assert (mv := mr_client.get_model_version(name, version))
assert mv.metadata["model_author"] == author
assert mv.metadata["model_origin"] == "huggingface_hub"
assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx"
assert mv.metadata["repo"] == name
assert mr_client.get_model_artifact(name, version)

for k in env_values.keys():
os.environ.pop(k)

0 comments on commit 387d88a

Please sign in to comment.