Skip to content

Commit

Permalink
Add params to custom class (#34)
Browse files Browse the repository at this point in the history
* Add params to FastTextWrapper.predict

* Change predict script in applications
  • Loading branch information
tomseimandi authored Nov 23, 2023
1 parent 20aa835 commit ef8482e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 24 deletions.
7 changes: 1 addition & 6 deletions slides/en/applications/_application2.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,7 @@ model = mlflow.pyfunc.load_model(
list_libs = ["vendeur d'huitres", "boulanger"]
test_data = {
"query": list_libs,
"k": 1
}
results = model.predict(test_data)
results = model.predict(list_libs, params={"k": 1})
print(results)
```
Expand Down
7 changes: 1 addition & 6 deletions slides/fr/applications/_application2.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,7 @@ model = mlflow.pyfunc.load_model(
list_libs = ["vendeur d'huitres", "boulanger"]
test_data = {
"query": list_libs,
"k": 1
}
results = model.predict(test_data)
results = model.predict(list_libs, params={"k": 1})
print(results)
```
</details>
Expand Down
26 changes: 15 additions & 11 deletions src/fasttext_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
FastText wrapper for MLflow.
"""
from typing import Tuple, Dict
from typing import Tuple, Optional, Dict, Any, List
import fasttext
import mlflow
import pandas as pd
Expand Down Expand Up @@ -39,31 +39,35 @@ def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
self.model = fasttext.load_model(context.artifacts["model_path"])

def predict(
self, context: mlflow.pyfunc.PythonModelContext, model_input: Dict
self,
context: mlflow.pyfunc.PythonModelContext,
model_input: List,
params: Optional[Dict[str, Any]] = None
) -> Tuple:
"""
Predicts the k most likely codes to a query.
Predicts the most likely codes for a list of texts.
Args:
context (mlflow.pyfunc.PythonModelContext): The MLflow model
context.
model_input (dict): A dictionary containing the input data for the
model. It should have the following keys:
- 'query': A dictionary containing the query features.
- 'k': An integer representing the number of predicted codes to
return.
model_input (List): A list of text observations.
params (Optional[Dict[str, Any]]): Additional parameters to
pass to the model for inference.
Returns:
A tuple containing the k most likely codes to the query.
"""
df = self.preprocessor.clean_text(
pd.DataFrame(model_input["query"], columns=[TEXT_FEATURE]),
pd.DataFrame(model_input, columns=[TEXT_FEATURE]),
text_feature=TEXT_FEATURE,
)

texts = df.apply(self._format_item, axis=1).to_list()

predictions = self.model.predict(texts, k=model_input["k"])
predictions = self.model.predict(
texts,
**params
)

predictions_formatted = {
i: {
Expand All @@ -72,7 +76,7 @@ def predict(
"nace": predictions[0][i][rank_pred].replace(LABEL_PREFIX, ""),
"probability": float(predictions[1][i][rank_pred]),
}
for rank_pred in range(model_input["k"])
for rank_pred in range(params["k"])
}
for i in range(len(predictions[0]))
}
Expand Down
2 changes: 1 addition & 1 deletion src/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def clean_text(self, df: pd.DataFrame, text_feature: str) -> pd.DataFrame:
df (pd.DataFrame): Clean DataFrame.
"""
df = df.copy()

# Fix encoding
df[text_feature] = df[text_feature].map(unidecode.unidecode)

Expand Down
9 changes: 9 additions & 0 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def train(
"train_data": training_data_path,
}

inference_params = {
"k": 1,
}
# Infer the signature including parameters
signature = mlflow.models.infer_signature(
params=inference_params,
)

mlflow.pyfunc.log_model(
artifact_path=run_name,
python_model=FastTextWrapper(),
Expand All @@ -107,6 +115,7 @@ def train(
"src/constants.py",
],
artifacts=artifacts,
signature=signature
)

# Log parameters
Expand Down

0 comments on commit ef8482e

Please sign in to comment.