Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Enable to override params at predict time in KedroPipeline… #612

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,18 @@ def after_pipeline_run(
if isinstance(model_signature, str):
if model_signature == "auto":
input_data = catalog.load(pipeline.input_name)
model_signature = infer_signature(model_input=input_data)

# all pipeline params will be overridable at predict time: https://mlflow.org/docs/latest/model/signatures.html#model-signatures-with-inference-params
# I add the special "runner" parameter to be able to choose it at runtime
pipeline_params = {
ds_name[7:]: catalog.load(ds_name)
for ds_name in pipeline.inference.inputs()
if ds_name.startswith("params:")
} | {"runner": "SequentialRunner"}
model_signature = infer_signature(
model_input=input_data,
params=pipeline_params,
)

mlflow.pyfunc.log_model(
python_model=kedro_pipeline_model,
Expand Down
25 changes: 23 additions & 2 deletions kedro_mlflow/mlflow/kedro_pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,38 @@ def load_context(self, context):
updated_catalog._datasets[name]._filepath = Path(uri)
self.loaded_catalog.save(name=name, data=updated_catalog.load(name))

def predict(self, context, model_input):
def predict(self, context, model_input, params=None):
# we create an empty hook manager but do NOT register hooks
# because we want this model be executable outside of a kedro project

# params can pass
# TODO globals
# TODO runtime
# TODO parameters -> I'd prefer not have them, but it would require catalog to be able to not be fully resolved if we want to pass runtime and globals
# TODO hooks
# TODO runner

params = params or {}
runner_class = params.pop("runner", "SequentialRunner")
runner = (
self.runner
) # runner="build it dynamically from runner class" or self.runner

hook_manager = _create_hook_manager()
# _register_hooks(hook_manager, predict_params.hooks)

for name, value in params.items():
# no need to check if params are ni the catalog, because mlflow already checks that the params mathc the signature
param = f"params:{name}"
self._logger.info(f"Using {param}={value} for the prediction")
self.loaded_catalog.save(name=param, data=value)

self.loaded_catalog.save(
name=self.input_name,
data=model_input,
)

run_output = self.runner.run(
run_output = runner.run(
pipeline=self.pipeline,
catalog=self.loaded_catalog,
hook_manager=hook_manager,
Expand Down
Loading