Skip to content

Commit

Permalink
Fix uniqueness constraint with SqlRegisteredModel.name
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Nov 1, 2023
1 parent 7f31a20 commit 0b39054
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


## in progress
- Fix uniqueness constraint with `SqlRegisteredModel.name`. Thanks, @andnig.

## 2023-10-11 0.2.0
- Update to [MLflow 2.7](https://github.com/mlflow/mlflow/releases/tag/v2.7.0)
Expand Down
2 changes: 2 additions & 0 deletions mlflow_cratedb/patch/mlflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ def polyfill_uniqueness_constraints():
TODO: Submit patch to `crate-python`, to be enabled by a
dialect parameter `crate_polyfill_unique` or such.
"""
from mlflow.store.model_registry.dbmodels.models import SqlRegisteredModel
from mlflow.store.tracking.dbmodels.models import SqlExperiment

listen(SqlExperiment, "before_insert", check_uniqueness_factory(SqlExperiment, "name"))
listen(SqlRegisteredModel, "before_insert", check_uniqueness_factory(SqlRegisteredModel, "name"))


def polyfill_refresh_after_dml():
Expand Down
57 changes: 57 additions & 0 deletions tests/test_tracking_issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import re

import mlflow
import mlflow.sklearn
from mlflow.models import Model
from mlflow.models.model import ModelInfo


def test_log_model_twice(tracking_store, reset_database):
    """
    Log a model under the same registered model name twice in a row.

    Problem
    -------
    Calling `mlflow.sklearn.log_model` twice caused problems:
    UPDATE statement on table 'registered_models' expected to update 1 row(s); 2 were matched.

    Solution
    --------
    Add a uniqueness constraint.

    References
    ----------
    - https://github.com/crate-workbench/mlflow-cratedb/issues/46
    """

    # Point MLflow at the tracking backend provided by the fixture.
    os.environ["MLFLOW_TRACKING_URI"] = tracking_store.db_uri

    # An experiment has to be selected by name before anything is logged.
    mlflow.set_experiment("test_log_model")

    # Both invocations must yield a run-scoped URI ending in the artifact path.
    model_uri_pattern = r".*runs:/[0-9a-z]+/testdrive-artifact.*"

    def emulate_sklearn_log_model(metadata=None):
        # Stand-in for `mlflow.sklearn.log_model`; no estimator object is passed.
        return Model.log(
            artifact_path="testdrive-artifact",
            flavor=mlflow.sklearn,
            registered_model_name="testdrive-artifact-model",
            await_registration_for=0.01,
            metadata=metadata,
            sk_model=None,
        )

    # Log twice with different metadata; the second call targets the already
    # registered model name and is expected to succeed as well.
    expected_metadata_per_call = (
        {"status": "update-1", "training": True},
        {"status": "update-2", "knowledge": "excellent"},
    )
    for expected_metadata in expected_metadata_per_call:
        # Hand over a copy, so the comparison below uses an untouched reference.
        model_info = emulate_sklearn_log_model(metadata=dict(expected_metadata))
        assert isinstance(model_info, ModelInfo)
        assert re.match(model_uri_pattern, model_info.model_uri)
        assert model_info.metadata == expected_metadata

0 comments on commit 0b39054

Please sign in to comment.