Skip to content

Commit

Permalink
remove sample_posterior_predictive patch and test prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jan 20, 2025
1 parent 5e83235 commit 9ccb221
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
19 changes: 0 additions & 19 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,27 +1057,8 @@ def new_fit(self, *args, **kwargs):

return new_fit

def patch_mmm_sample_posterior_predictive(
sample_posterior_predictive: Callable,
) -> Callable:
@wraps(sample_posterior_predictive)
def new_sample_posterior_predictive(self, *args, **kwargs):
posterior_preds = sample_posterior_predictive(self, *args, **kwargs)

log_mmm_evaluation_metrics(
y_true=self.y,
y_pred=posterior_preds[self.output_var],
)

return posterior_preds

return new_sample_posterior_predictive

if log_mmm:
MMM.fit = patch_mmm_fit(MMM.fit)
MMM.sample_posterior_predictive = patch_mmm_sample_posterior_predictive(
MMM.sample_posterior_predictive
)

def patch_clv_fit(fit):
@wraps(fit)
Expand Down
15 changes: 11 additions & 4 deletions tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,18 +583,23 @@ def test_log_mmm_evaluation_metrics() -> None:
"""Test logging of summary metrics to MLflow."""
y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([[1.1, 2.1, 3.1]]).T
custom_metrics = ["r_squared", "rmse", "mae", "mape", "nrmse", "nmae"]
custom_metrics = ["r_squared", "rmse"]

prefix: str = "in-sample"
with mlflow.start_run() as run:
log_mmm_evaluation_metrics(
y_true, y_pred, metrics_to_calculate=custom_metrics, hdi_prob=0.94
y_true,
y_pred,
metrics_to_calculate=custom_metrics,
hdi_prob=0.94,
prefix=prefix,
)

run_id = run.info.run_id
run_data = get_run_data(run_id)

# Check that metrics are logged with expected prefixes and suffixes
metric_prefixes = {"r_squared", "rmse", "mae", "mape", "nrmse", "nmae"}
metric_prefixes = {"r_squared", "rmse"}
metric_suffixes = {
"mean",
"median",
Expand All @@ -605,7 +610,9 @@ def test_log_mmm_evaluation_metrics() -> None:
"94_hdi_upper",
}
expected_metrics = {
f"{prefix}_{suffix}" for prefix in metric_prefixes for suffix in metric_suffixes
f"{prefix}_{metric_prefix}_{metrix_suffix}"
for metric_prefix in metric_prefixes
for metrix_suffix in metric_suffixes
}
assert set(run_data.metrics.keys()) == expected_metrics

Expand Down

0 comments on commit 9ccb221

Please sign in to comment.