handle inference parameters
Calychas committed Aug 12, 2024
1 parent 0ec5458 commit 309a513
Showing 4 changed files with 57 additions and 20 deletions.
1 change: 1 addition & 0 deletions kedro_mlflow/framework/hooks/mlflow_hook.py
@@ -371,6 +371,7 @@ def after_pipeline_run(
pipeline=pipeline.inference,
catalog=catalog,
input_name=pipeline.input_name,
+ params_input_name=pipeline.params_input_name,
**pipeline.kpm_kwargs,
)
artifacts = kedro_pipeline_model.extract_pipeline_artifacts(
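For context, a paraphrased sketch of the surrounding `after_pipeline_run` logic (not the verbatim hook code): the inference pipeline is wrapped in a `KedroPipelineModel` and logged as an MLflow pyfunc model, and the new keyword simply forwards `params_input_name` from the `PipelineML` object. The exact `log_model` keywords are assumed to come from `pipeline.log_model_kwargs`.

import mlflow

from kedro_mlflow.mlflow import KedroPipelineModel

# `pipeline` and `catalog` are the hook's arguments (a PipelineML and the run's DataCatalog)
kedro_pipeline_model = KedroPipelineModel(
    pipeline=pipeline.inference,
    catalog=catalog,
    input_name=pipeline.input_name,
    params_input_name=pipeline.params_input_name,  # forwarded by this commit
    **pipeline.kpm_kwargs,
)
# the real hook may also pass a folder used to persist "params:" datasets
artifacts = kedro_pipeline_model.extract_pipeline_artifacts()
mlflow.pyfunc.log_model(
    python_model=kedro_pipeline_model,
    artifacts=artifacts,
    **pipeline.log_model_kwargs,  # expected to carry artifact_path, signature, etc.
)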
22 changes: 17 additions & 5 deletions kedro_mlflow/mlflow/kedro_pipeline_model.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
- from typing import Dict, Optional, Union
+ from typing import Dict, Optional, Union, Any

from kedro.framework.hooks import _create_hook_manager
from kedro.io import DataCatalog, MemoryDataset
@@ -20,6 +20,7 @@ def __init__(
input_name: str,
runner: Optional[AbstractRunner] = None,
copy_mode: Optional[Union[Dict[str, str], str]] = "assign",
+ params_input_name: Optional[str] = None,
):
"""[summary]
@@ -30,6 +31,8 @@ def __init__(
catalog (DataCatalog): The DataCatalog associated
to the PipelineML
input_name (str): TODO
runner (Optional[AbstractRunner], optional): The kedro
AbstractRunner to use. Defaults to SequentialRunner if
None.
@@ -45,12 +48,16 @@ def __init__(
- a dictionary with (dataset name, copy_mode) key/values
pairs. The associated mode must be a valid kedro mode
("deepcopy", "copy" and "assign") for each. Defaults to None.
+ params_input_name (Optional[str]): The name of the dataset in the catalog
+ which receives the `params` dict passed to `predict` at inference time.
"""

self.pipeline = (
pipeline.inference if isinstance(pipeline, PipelineML) else pipeline
)
self.input_name = input_name
+ self.params_input_name = params_input_name
self.initial_catalog = self._extract_pipeline_catalog(catalog)

nb_outputs = len(self.pipeline.outputs())
@@ -107,7 +114,7 @@ def copy_mode(self, copy_mode):
def _extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:
sub_catalog = DataCatalog()
for dataset_name in self.pipeline.inputs():
- if dataset_name == self.input_name:
+ if dataset_name in (self.input_name, self.params_input_name):
# there is no obligation that this dataset is persisted
# and even if it is, we keep only an empty memory dataset to avoid
# extra unnecessary dependencies: this dataset will be replaced at
@@ -145,7 +152,7 @@ def extract_pipeline_artifacts(
):
artifacts = {}
for name, dataset in self.initial_catalog._datasets.items():
- if name != self.input_name:
+ if name not in (self.input_name, self.params_input_name):
if name.startswith("params:"):
# we need to persist it locally for mlflow access
absolute_param_path = (
@@ -177,7 +184,9 @@ def load_context(self, context):
# but we rely on an mlflow function for saving, and it is unaware of kedro
# pipeline structure
mlflow_artifacts_keys = set(context.artifacts.keys())
- kedro_artifacts_keys = set(self.pipeline.inputs() - {self.input_name})
+ kedro_artifacts_keys = set(
+ self.pipeline.inputs() - {self.input_name, self.params_input_name}
+ )
if mlflow_artifacts_keys != kedro_artifacts_keys:
in_artifacts_but_not_inference = (
mlflow_artifacts_keys - kedro_artifacts_keys
@@ -196,7 +205,7 @@ def load_context(self, context):
updated_catalog._datasets[name]._filepath = Path(uri)
self.loaded_catalog.save(name=name, data=updated_catalog.load(name))

- def predict(self, context, model_input):
+ def predict(self, context, model_input, params: Optional[Dict[str, Any]] = None):
# we create an empty hook manager but do NOT register hooks
# because we want this model to be executable outside of a kedro project
hook_manager = _create_hook_manager()
@@ -206,6 +215,9 @@ def predict(self, context, model_input,
data=model_input,
)

+ if self.params_input_name:
+ self.loaded_catalog.save(name=self.params_input_name, data=params)

run_output = self.runner.run(
pipeline=self.pipeline,
catalog=self.loaded_catalog,
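With `params_input_name` wired into `predict`, the logged model can receive extra values at inference time through MLflow's pyfunc `params` argument. A minimal consumer-side sketch, assuming MLflow >= 2.6, a hypothetical run id, and a "threshold" key that the inference pipeline actually reads; depending on the MLflow version, the logged signature may also need to declare a params schema for the values to be forwarded.

import mlflow
import pandas as pd

# hypothetical input data matching the dataset named by `input_name`
input_data = pd.DataFrame({"feature_1": [0.2, 0.7], "feature_2": [1.3, 0.4]})

model = mlflow.pyfunc.load_model("runs:/<run_id>/model")

# `params` is saved into the dataset named by `params_input_name`
# before the inference pipeline runs
predictions = model.predict(input_data, params={"threshold": 0.75})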
44 changes: 32 additions & 12 deletions kedro_mlflow/pipeline/pipeline_ml.py
@@ -43,6 +43,7 @@ def __init__(
input_name: str,
kpm_kwargs: Optional[Dict[str, str]] = None,
log_model_kwargs: Optional[Dict[str, str]] = None,
+ params_input_name: Optional[str] = None,
):
"""Store all necessary information for calling mlflow.log_model in the pipeline.
@@ -56,9 +57,9 @@ def __init__(
stored in mlflow and use the output(s)
of the training pipeline (namely, the model)
to predict the outcome.
- input_name (str, optional): The name of the dataset in
+ input_name (str): The name of the dataset in
the catalog.yml which the model's user must provide
- for prediction (i.e. the data). Defaults to None.
+ for prediction (i.e. the data).
kpm_kwargs:
extra arguments to be passed to `KedroPipelineModel`
when the PipelineML object is automatically saved at the end of a run.
@@ -70,13 +71,15 @@ def __init__(
extra arguments to be passed to `mlflow.pyfunc.log_model`, e.g.:
- "signature" accepts an extra "auto" which automatically infer the signature
based on "input_name" dataset
+ params_input_name (str, optional): The name of the dataset in
+ the catalog.yml which receives the `params` dict passed to the
+ mlflow model's `predict` at inference time. Defaults to None.
"""

super().__init__(nodes, *args, tags=tags)

self.inference = inference
self.input_name = input_name
+ self.params_input_name = params_input_name
# they will be passed to KedroPipelineModel to enable flexibility

kpm_kwargs = kpm_kwargs or {}
@@ -91,7 +94,7 @@ def training(self) -> Pipeline:
return Pipeline(self.nodes)

@property
- def inference(self) -> str:
+ def inference(self) -> Pipeline:
return self._inference

@inference.setter
@@ -114,6 +117,22 @@ def input_name(self, name: str) -> None:
)
self._input_name = name

+ @property
+ def params_input_name(self) -> Optional[str]:
+ return self._params_input_name
+
+ @params_input_name.setter
+ def params_input_name(self, name: Optional[str]) -> None:
+ if name is not None:
+ allowed_names = self.inference.inputs()
+ pp_allowed_names = "\n - ".join(allowed_names)
+ if name not in allowed_names:
+ raise KedroMlflowPipelineMLError(
+ f"params_input_name='{name}' but it must be an input of 'inference'"
+ f", i.e. one of: \n - {pp_allowed_names}"
+ )
+ self._params_input_name = name

def _check_inference(self, inference: Pipeline) -> None:
nb_outputs = len(inference.outputs())
outputs_txt = "\n - ".join(inference.outputs())
@@ -133,7 +152,7 @@ def _check_consistency(self) -> None:

free_inputs_set = (
self.inference.inputs()
- - {self.input_name}
+ - {self.input_name, self.params_input_name}
- self.all_outputs()
- self.inputs()
- inference_parameters # it is allowed to pass parameters: they will be automatically persisted by the hook
@@ -147,7 +166,7 @@ def _check_consistency(self) -> None:
" \nNo free input is allowed."
" Please make sure that 'inference.inputs()' are all"
" in 'training.all_outputs() + training.inputs()'"
"except 'input_name' and parameters which starts with 'params:'."
"except 'input_name', 'params_input_name' and parameters which starts with 'params:'."
)

def _turn_pipeline_to_ml(self, pipeline: Pipeline):
Expand All @@ -157,6 +176,7 @@ def _turn_pipeline_to_ml(self, pipeline: Pipeline):
input_name=self.input_name,
kpm_kwargs=self.kpm_kwargs,
log_model_kwargs=self.log_model_kwargs,
+ params_input_name=self.params_input_name,
)

def only_nodes(self, *node_names: str) -> "Pipeline": # pragma: no cover
@@ -211,13 +231,13 @@ def tag(self, tags: Union[str, Iterable[str]]) -> "PipelineML":

def filter(
self,
- tags: Iterable[str] = None,
- from_nodes: Iterable[str] = None,
- to_nodes: Iterable[str] = None,
- node_names: Iterable[str] = None,
- from_inputs: Iterable[str] = None,
- to_outputs: Iterable[str] = None,
- node_namespace: str = None,
+ tags: Optional[Iterable[str]] = None,
+ from_nodes: Optional[Iterable[str]] = None,
+ to_nodes: Optional[Iterable[str]] = None,
+ node_names: Optional[Iterable[str]] = None,
+ from_inputs: Optional[Iterable[str]] = None,
+ to_outputs: Optional[Iterable[str]] = None,
+ node_namespace: Optional[str] = None,
) -> "Pipeline":
# see from_inputs for an explanation of why we don't call super()
pipeline = self.training.filter(
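The `params_input_name` setter above only accepts a name that is a declared input of the inference pipeline. A minimal sketch of such an inference pipeline, with illustrative function and dataset names and a model assumed to expose `predict_proba`:

from kedro.pipeline import Pipeline, node


def predict_with_params(model, data, inference_params):
    # `inference_params` is the dict forwarded from `predict(..., params=...)`;
    # it may be None when the caller passes no params
    inference_params = inference_params or {}
    threshold = inference_params.get("threshold", 0.5)
    return (model.predict_proba(data)[:, 1] > threshold).astype(int)


inference = Pipeline(
    [
        node(
            predict_with_params,
            inputs=["trained_model", "model_input", "inference_params"],
            outputs="predictions",
        )
    ]
)

# "inference_params" appears in inference.inputs(), so it is a valid
# value for params_input_name and passes the setter check.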
10 changes: 7 additions & 3 deletions kedro_mlflow/pipeline/pipeline_ml_factory.py
@@ -1,14 +1,16 @@
from kedro.pipeline import Pipeline

from kedro_mlflow.pipeline.pipeline_ml import PipelineML
+ from typing import Optional


def pipeline_ml_factory(
training: Pipeline,
inference: Pipeline,
- input_name: str = None,
+ input_name: str,
kpm_kwargs=None,
log_model_kwargs=None,
+ params_input_name: Optional[str] = None,
) -> PipelineML:
"""This function is a helper to create `PipelineML`
object directly from two Kedro `Pipelines` (one of
@@ -23,9 +25,9 @@ def pipeline_ml_factory(
stored in mlflow and use the output(s)
of the training pipeline (namely, the model)
to predict the outcome.
- input_name (str, optional): The name of the dataset in
+ input_name (str): The name of the dataset in
the catalog.yml which the model's user must provide
- for prediction (i.e. the data). Defaults to None.
+ for prediction (i.e. the data).
kpm_kwargs:
extra arguments to be passed to `KedroPipelineModel`
when the PipelineML object is automatically saved at the end of a run.
@@ -37,6 +39,7 @@ def pipeline_ml_factory(
extra arguments to be passed to `mlflow.pyfunc.log_model`
when the PipelineML object is automatically saved at the end of a run.
See mlflow documentation to see all available options: https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.log_model
+ params_input_name (str, optional): The name of the dataset in
+ the catalog.yml which receives the `params` dict passed to the
+ mlflow model's `predict` at inference time. Defaults to None.
Returns:
PipelineML: A `PipelineML` which is automatically
@@ -51,5 +54,6 @@ def pipeline_ml_factory(
input_name=input_name,
kpm_kwargs=kpm_kwargs,
log_model_kwargs=log_model_kwargs,
+ params_input_name=params_input_name,
)
return pipeline
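Putting the pieces together, a hypothetical project registry could wire the new argument as follows; module paths, dataset names and the `register_pipelines` layout are assumptions for illustration, not part of this commit:

from typing import Dict

from kedro.pipeline import Pipeline
from kedro_mlflow.pipeline import pipeline_ml_factory

# hypothetical helpers returning plain Kedro pipelines
from my_project.pipelines import create_inference_pipeline, create_training_pipeline


def register_pipelines() -> Dict[str, Pipeline]:
    training_pipeline = create_training_pipeline()
    inference_pipeline = create_inference_pipeline()

    return {
        "training": pipeline_ml_factory(
            training=training_pipeline,
            inference=inference_pipeline,
            input_name="model_input",
            # catalog entry that receives predict-time params
            params_input_name="inference_params",
            log_model_kwargs={"artifact_path": "model"},
        ),
        "__default__": training_pipeline,
    }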
