-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path{% if not metric_compare_promotion %}promote_latest_version.py{% endif %}
46 lines (37 loc) · 1.59 KB
/
{% if not metric_compare_promotion %}promote_latest_version.py{% endif %}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# {% include 'template/license_header' %}
from zenml import get_step_context, step
from zenml import Model
from zenml.logger import get_logger
from utils import promote_in_model_registry
logger = get_logger(__name__)
@step
def promote_latest_version(
    mlflow_model_name: str,
    target_env: str,
) -> None:
    """Promote the latest trained model to the target stage.

    This is an example of a model promotion step. It moves the model
    version attached to the current step context into ``target_env`` in
    the Model Control Plane and then mirrors that promotion in the model
    registry through ``promote_in_model_registry``.

    Args:
        mlflow_model_name: Name of the model in the model registry;
            forwarded as ``model_name`` to ``promote_in_model_registry``.
        target_env: Stage to promote to. Passed unchanged to
            ``set_stage`` and capitalized for the model registry call.
    """
    ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###

    # Get model version numbers from Model Control Plane
    latest_version = get_step_context().model
    current_version = Model(name=latest_version.name, version=target_env)
    logger.info(f"Promoting latest model version `{latest_version}`")

    # Read the registry numbers BEFORE the stage is reassigned below.
    # `current_version` is addressed by stage name, so dereferencing it
    # after `set_stage(..., force=True)` would pick up the version that was
    # just promoted instead of the one that held the stage until now.
    # NOTE(review): assumes `Model` resolves its version on first attribute
    # access (lazy lookup) - confirm against the installed ZenML version.
    latest_registry_number = latest_version.run_metadata["model_registry_version"]
    if current_version.number is None:
        # Nothing occupies the target stage yet: report the latest version
        # as the current one as well.
        current_registry_number = latest_registry_number
    else:
        current_registry_number = current_version.run_metadata["model_registry_version"]

    # Promote in Model Control Plane
    latest_version.set_stage(stage=target_env, force=True)
    logger.info(f"Current model version was promoted to '{target_env}'.")

    # Promote in Model Registry
    promote_in_model_registry(
        latest_version=latest_registry_number,
        current_version=current_registry_number,
        model_name=mlflow_model_name,
        target_env=target_env.capitalize(),
    )
    ### YOUR CODE ENDS HERE ###