Skip to content

Commit

Permalink
✨ Sanitize parameters names to avoid logging error with unsupported c…
Browse files Browse the repository at this point in the history
…haracters (#595)

* add ability to blacklist params. also sanitise keys

* rm dropping secret params

* rm lint

* rm empty l

* rm docs

* add regex

* add tests

* resolve nitpicks

* fix typo

* fix test on windows

* remove platform specific tests for sanitization

* add changelog entries

---------

Co-authored-by: Yolan Honoré-Rougé <[email protected]>
  • Loading branch information
pascalwhoop and Galileo-Galilei authored Oct 23, 2024
1 parent 516cb4b commit fc584c9
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## [Unreleased]

### Added

- :sparkles: Implement missing PipelineML filtering functionalities to let kedro display resume hints and avoid breaking kedro-viz ([#377, Calychas](https://github.com/Galileo-Galilei/kedro-mlflow/pull/377), [#601, Calychas](https://github.com/Galileo-Galilei/kedro-mlflow/pull/601))
- :sparkles: Sanitize parameter names containing unsupported characters to avoid mlflow errors when logging ([#595, pascalwhoop](https://github.com/Galileo-Galilei/kedro-mlflow/pull/595))

## [0.13.2] - 2024-10-15

### Fixed
Expand Down
28 changes: 28 additions & 0 deletions kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import re
from logging import Logger, getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -303,6 +305,11 @@ def before_node_run(
d=params_inputs, recursive=self.recursive, sep=self.sep
)

# sanitize params inputs to avoid mlflow errors
params_inputs = {
self.sanitize_param_name(k): v for k, v in params_inputs.items()
}

# logging parameters based on defined strategy
for k, v in params_inputs.items():
self._log_param(k, v)
Expand Down Expand Up @@ -446,5 +453,26 @@ def on_pipeline_error(
# hence it should not be modified. this is only a safeguard
switch_catalog_logging(catalog, True)

def sanitize_param_name(self, name: str) -> str:
    """Return ``name`` with every character mlflow rejects replaced by ``_``.

    The accepted character set mirrors the validation regex taken from the
    MLFlow codebase (slashes, word characters, periods, dashes, spaces and,
    except on Windows, colons):
    https://github.com/mlflow/mlflow/blob/e40e782b6fcab473159e6d4fee85bc0fc10f78fd/mlflow/utils/validation.py#L140C1-L148C44

    Args:
        name: The parameter name that is about to be logged.

    Returns:
        ``name`` itself when it only contains accepted characters, otherwise
        a copy in which each unsupported character is an underscore. A
        warning is emitted through ``self._logger`` whenever a rename occurs.
    """
    # on windows, colons ':' are not accepted
    invalid_characters = r"[^/\w.\- ]" if is_windows() else r"[^/\w.\- :]"

    # Substitute unconditionally instead of pre-checking with an anchored
    # ``^...$`` match: ``$`` also matches just before a trailing newline, so
    # a name ending in "\n" would slip through the pre-check untouched.
    sanitized_name = re.sub(invalid_characters, "_", name)
    if sanitized_name != name:
        self._logger.warning(
            f"'{name}' is not a valid name for a mlflow parameter. It is renamed as '{sanitized_name}'"
        )
    return sanitized_name


def is_windows():
    """Tell whether the interpreter is running on Windows (``os.name`` is "nt")."""
    platform_name = os.name
    return platform_name == "nt"


# Module-level MlflowHook instance created at import time.
# NOTE(review): presumably this is the object kedro picks up as the plugin
# hook -- confirm against the project's entry point declaration.
mlflow_hook = MlflowHook()
131 changes: 131 additions & 0 deletions tests/framework/hooks/test_hook_log_parameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from typing import Dict

Expand Down Expand Up @@ -74,6 +75,136 @@ def dummy_catalog():
return catalog


@pytest.mark.parametrize(
    "param_name,expected_name",
    [
        ("valid_param", "valid_param"),
        ("valid-param", "valid-param"),
        ("invalid/param", "invalid/param"),
        ("invalid.param", "invalid.param"),
        ("[invalid]$param", "_invalid__param"),
    ],
)
def test_parameter_name_sanitization(
    kedro_project, dummy_run_params, param_name, expected_name
):
    """Check that a node parameter is logged to mlflow under its sanitized name."""
    tracking_uri = (kedro_project / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        project_context = session.load_context()
        hook = MlflowHook()
        hook.after_context_created(project_context)

        with mlflow.start_run():
            hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=Pipeline([]),
                catalog=DataCatalog(),
            )
            # The node itself is a placeholder: only the ``inputs`` mapping
            # handed to the hook carries the parameter under test.
            identity_node = node(func=lambda x: x, inputs={"x": "a"}, outputs=None)
            hook.before_node_run(
                node=identity_node,
                catalog=DataCatalog(),
                inputs={f"params:{param_name}": "test_value"},
                is_async=False,
            )
            # Must be read while the run is still active.
            run_id = mlflow.active_run().info.run_id

        logged_params = MlflowClient(tracking_uri).get_run(run_id).data.params
        assert expected_name in logged_params
        assert logged_params[expected_name] == "test_value"


@pytest.mark.skipif(
    os.name != "nt", reason="Windows does not log params with colon symbol"
)
def test_parameter_name_with_colon_sanitization_on_windows(
    kedro_project, dummy_run_params
):
    """Check that, on Windows, a colon in a parameter name is logged as an underscore."""
    raw_name, expected_name = "valid:param", "valid_param"

    tracking_uri = (kedro_project / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        project_context = session.load_context()
        hook = MlflowHook()
        hook.after_context_created(project_context)

        with mlflow.start_run():
            hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=Pipeline([]),
                catalog=DataCatalog(),
            )
            # Placeholder node: the hook only inspects the ``inputs`` mapping.
            identity_node = node(func=lambda x: x, inputs={"x": "a"}, outputs=None)
            hook.before_node_run(
                node=identity_node,
                catalog=DataCatalog(),
                inputs={f"params:{raw_name}": "test_value"},
                is_async=False,
            )
            # Must be read while the run is still active.
            run_id = mlflow.active_run().info.run_id

        logged_params = MlflowClient(tracking_uri).get_run(run_id).data.params
        assert expected_name in logged_params
        assert logged_params[expected_name] == "test_value"


@pytest.mark.skipif(
    os.name == "nt", reason="Linux and Mac do log params with colon symbol"
)
def test_parameter_name_with_colon_sanitization_on_mac_linux(
    kedro_project, dummy_run_params
):
    """Check that, outside Windows, a colon in a parameter name is logged unchanged."""
    raw_name = "valid:param"
    expected_name = "valid:param"

    tracking_uri = (kedro_project / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        project_context = session.load_context()
        hook = MlflowHook()
        hook.after_context_created(project_context)

        with mlflow.start_run():
            hook.before_pipeline_run(
                run_params=dummy_run_params,
                pipeline=Pipeline([]),
                catalog=DataCatalog(),
            )
            # Placeholder node: the hook only inspects the ``inputs`` mapping.
            identity_node = node(func=lambda x: x, inputs={"x": "a"}, outputs=None)
            hook.before_node_run(
                node=identity_node,
                catalog=DataCatalog(),
                inputs={f"params:{raw_name}": "test_value"},
                is_async=False,
            )
            # Must be read while the run is still active.
            run_id = mlflow.active_run().info.run_id

        logged_params = MlflowClient(tracking_uri).get_run(run_id).data.params
        assert expected_name in logged_params
        assert logged_params[expected_name] == "test_value"


def test_pipeline_run_hook_getting_configs(
kedro_project,
dummy_run_params,
Expand Down

0 comments on commit fc584c9

Please sign in to comment.