Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔊 Log model version information when a model is loaded from the registry (#552) #588

Merged
merged 8 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- :sparkles: Implement missing ``PipelineML`` filtering functionalities to let ``kedro`` display resume hints and avoid breaking ``kedro-viz`` ([#377](https://github.com/Galileo-Galilei/kedro-mlflow/pull/377), [#601, Calychas](https://github.com/Galileo-Galilei/kedro-mlflow/pull/601))
- :sparkles: Sanitize parameters name with unsupported characters to avoid ``mlflow`` errors when logging ([#595, pascalwhoop](https://github.com/Galileo-Galilei/kedro-mlflow/pull/595))
- :loud_sound: Add logs about the exact ``run_id`` loaded within a ``MlflowRegistryDataset`` because some URI are confusing (e.g. ``latest``) and hard to debug ([#552](https://github.com/Galileo-Galilei/kedro-mlflow/pull/552))

### Changed

Expand Down
15 changes: 14 additions & 1 deletion kedro_mlflow/io/models/mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import Logger, getLogger
from typing import Any, Dict, Optional, Union

from kedro.io.core import DatasetError
Expand Down Expand Up @@ -67,6 +68,10 @@ def __init__(
else f"models:/{model_name}/{stage_or_version}"
)

@property
def _logger(self) -> Logger:
return getLogger(__name__)

def _load(self) -> Any:
"""Loads an MLflow model from local path or from MLflow run.

Expand All @@ -77,10 +82,18 @@ def _load(self) -> Any:
# If `run_id` is specified, pull the model from MLflow.
# TODO: enable loading from another mlflow conf (with a client with another tracking uri)
# Alternatively, use local path to load the model.
return self._mlflow_model_module.load_model(
model = self._mlflow_model_module.load_model(
model_uri=self.model_uri, **self._load_args
)

# log some info because "latest" model is not very informative
# the model itself does not have information about its registry
# because the same run can be registered under several different names
# in the registry. See https://github.com/Galileo-Galilei/kedro-mlflow/issues/552

self._logger.info(f"Loading model from run_id='{model.metadata.run_id}'")
return model

def _save(self, model: Any) -> None:
raise NotImplementedError(
"The 'save' method is not implemented for MlflowModelRegistryDataset. You can pass 'registered_model_name' argument in 'MLflowModelTrackingDataset(..., save_args={registered_model_name='my_model'}' to save and register a model in the same step. "
Expand Down
1 change: 0 additions & 1 deletion tests/framework/cli/test_cli_modelify.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def test_modelify_with_pip_requirements(monkeypatch, kp_for_modelify):
runs_list_before_cmd = context.mlflow.server._mlflow_client.search_runs(
context.mlflow.tracking.experiment._experiment.experiment_id
)
print(runs_list_before_cmd)
cli_runner = CliRunner()

result = cli_runner.invoke(
Expand Down
46 changes: 46 additions & 0 deletions tests/io/models/test_mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,52 @@ def test_mlflow_model_registry_alias_and_stage_or_version_fails(tmp_path):
)


# this test is failing because of long standing issues like this :
# https://github.com/pytest-dev/pytest/issues/7335
# https://github.com/pytest-dev/pytest/issues/5160
# To make logging occur, we need to from kedro.framework.projcet import LOGGING at the beginning
# ironically, the sderr error reported by pytest shows that logging actually occurs!
# If I remove with mlflow.start_run(), caplog is indeed not empty, it seems mlflow flushes the internal loger
# probably related to https://github.com/mlflow/mlflow/issues/4957
@pytest.mark.xfail
def test_mlflow_model_registry_logs_run_id(caplog, tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
# are stored in a relative mlruns/ folder so we need to have
# the same working directory that the one of the tracking uri
monkeypatch.chdir(tmp_path)
tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
mlflow.set_tracking_uri(tracking_and_registry_uri)
mlflow.set_registry_uri(tracking_and_registry_uri)

# setup: we train 2 version of a model under a single
# registered model and stage the 2nd one
run_ids = {}
for i in range(2):
with mlflow.start_run():
model = DecisionTreeClassifier()
mlflow.sklearn.log_model(
model, artifact_path="demo_model", registered_model_name="demo_model"
)
run_ids[i + 1] = mlflow.active_run().info.run_id

# case 1: no version is provided, we take the last one

ml_ds = MlflowModelRegistryDataset(model_name="demo_model", stage_or_version=1)
ml_ds.load()

# caplog.text, caplog.messages, caplog.records are all empty ???, but th stderr will show them
assert run_ids[1] in caplog.text

# case 2: a stage is provided, we take the last model with this stage
ml_ds = MlflowModelRegistryDataset(
model_name="demo_model", stage_or_version="latest"
)
ml_ds._load()

assert run_ids[2] in caplog.text


def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
Expand Down
Loading