Skip to content

Commit

Permalink
Handle optimization subset and automatically set predictors
Browse files: browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 9, 2025
1 parent f2d7704 commit 02ce740
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 48 deletions.
108 changes: 79 additions & 29 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
"""Budget optimization module."""

import warnings
from typing import Any, ClassVar, Sequence
from collections.abc import Sequence
from typing import Any, ClassVar

import numpy as np
import pytensor.tensor as pt
from pydantic import BaseModel, ConfigDict, Field
from pymc import inputvars
from pymc.logprob.utils import rvs_in_graph
from pymc.model.transform.optimization import freeze_dims_and_data
from pytensor import clone_replace, function
from pytensor.graph import rewrite_graph, vectorize_graph
from scipy.optimize import minimize
from xarray import DataArray

from pymc_marketing.mmm.mmm import MMM
from pymc_marketing.mmm.components.adstock import AdstockTransformation
from pymc_marketing.mmm.components.saturation import SaturationTransformation
from pymc_marketing.mmm.constraints import (
build_constraint,
build_default_sum_constraint,
Expand Down Expand Up @@ -71,7 +71,7 @@ class BudgetOptimizer(BaseModel):
description="The number of time units at time granularity which the budget is to be allocated.",
)

hmm_model: MMM = Field(
hmm_model: Any = Field(
...,
description="The marketing mix model to optimize.",
arbitrary_types_allowed=True,
Expand All @@ -97,6 +97,11 @@ class BudgetOptimizer(BaseModel):
default=True,
description="Whether to set the default sum constraint to the optimizer.",
)
opt_mask: DataArray | None = Field(
...,
description="Mask defining a subset of budgets that should be optimized. "
"Non optimized budgets are fixed to 0.",
)

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -107,14 +112,43 @@ class BudgetOptimizer(BaseModel):

def __init__(self, **data):
super().__init__(**data)
self._num_channels = len(self.hmm_model.model.coords["channel"])
self._pymc_model = self.hmm_model._set_predictors_for_optimization(

Check warning on line 115 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L115

Added line #L115 was not covered by tests
self.num_periods
)
self._coords = self._pymc_model._coords

Check warning on line 118 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L118

Added line #L118 was not covered by tests
self._compiled_functions = {}
self._compile_objective_and_grad()
self._constraints = {}
self.set_constraints(default=self.default_constraints, constraints=self.custom_constraints)
self.set_constraints(

Check warning on line 122 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L121-L122

Added lines #L121 - L122 were not covered by tests
default=self.default_constraints, constraints=self.custom_constraints
)

def _create_budget_variable(self):
model = self._pymc_model
coords = self._coords

Check warning on line 128 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L127-L128

Added lines #L127 - L128 were not covered by tests

budgets_shape = [

Check warning on line 130 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L130

Added line #L130 was not covered by tests
len(coords[dim])
for dim in model.named_vars_to_dims["channel_data"]
if dim != "date"
]
if self.opt_mask is not None:
size_budgets = self.opt_mask.sum().item()

Check warning on line 136 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L135-L136

Added lines #L135 - L136 were not covered by tests
else:
size_budgets = np.prod(budgets_shape)
budgets_flat = pt.tensor("budgets_flat", shape=(size_budgets,))

Check warning on line 139 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L138-L139

Added lines #L138 - L139 were not covered by tests

if self.opt_mask is not None:
budgets = pt.zeros(budgets_shape)[

Check warning on line 142 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L141-L142

Added lines #L141 - L142 were not covered by tests
np.asarray(self.opt_mask).astype(bool)
].set(budgets_flat)
else:
budgets = budgets_flat.reshape(budgets_shape)

Check warning on line 146 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L146

Added line #L146 was not covered by tests

return budgets

Check warning on line 148 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L148

Added line #L148 was not covered by tests

def set_constraints(self, constraints, default=None):
""" set constraints """
"""Set constraints"""
self._constraints = {}
if default is None:
default = False if constraints else True

Check warning on line 154 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L152-L154

Added lines #L152 - L154 were not covered by tests
Expand All @@ -126,28 +160,31 @@ def set_constraints(self, constraints, default=None):
constraint_type=c.get("constraint_type", "eq"),
)
self._constraints[c["key"]] = new_constraint

Check warning on line 162 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L162

Added line #L162 was not covered by tests

if default:
self._constraints["default"] = build_default_sum_constraint("default")

Check warning on line 165 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L164-L165

Added lines #L164 - L165 were not covered by tests

self._compiled_constraints = compile_constraints_for_scipy(constraints=self._constraints, optimizer=self)
budgets = self._create_budget_variable()
self._compiled_constraints = compile_constraints_for_scipy(

Check warning on line 168 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L167-L168

Added lines #L167 - L168 were not covered by tests
budgets=budgets, constraints=self._constraints, optimizer=self
)

def _compile_objective_and_grad(self):
"""Compile the objective function and its gradient using symbolic computation."""
budgets_sym = pt.vector("budgets", shape=(self._num_channels,))

response_distribution = self.extract_response_distribution(budgets=budgets_sym)
budgets = self._create_budget_variable()
response_distribution = self.extract_response_distribution(budgets=budgets)

Check warning on line 175 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L174-L175

Added lines #L174 - L175 were not covered by tests

objective_value = -self.utility_function(
samples=response_distribution, budgets=budgets_sym
samples=response_distribution, budgets=budgets
)

# Compute gradient symbolically
grad_obj = pt.grad(objective_value, budgets_sym)
[budgets_flat] = inputvars([budgets])
grad_obj = pt.grad(objective_value, budgets_flat)

Check warning on line 183 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L182-L183

Added lines #L182 - L183 were not covered by tests

# Compile the functions
utility_func = function([budgets_sym], objective_value)
grad_func = function([budgets_sym], grad_obj)
utility_func = function([budgets_flat], objective_value)
grad_func = function([budgets_flat], grad_obj)

Check warning on line 187 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L186-L187

Added lines #L186 - L187 were not covered by tests

# Cache the compiled functions
self._compiled_functions[self.utility_function] = {
Expand All @@ -169,27 +206,38 @@ def extract_response_distribution(
self, budgets: pt.TensorVariable
) -> pt.TensorVariable:
"""Extract the response graph, conditioned on the posterior draws and a placeholder budget variable."""
if not (isinstance(budgets, pt.TensorVariable) and budgets.type.ndim == 1):
raise ValueError("budgets must be a 1D TensorVariable")
if not (isinstance(budgets, pt.TensorVariable)): # and budgets.type.ndim == 1):
raise ValueError("budgets must be a TensorVariable")

Check warning on line 210 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L209-L210

Added lines #L209 - L210 were not covered by tests

model = self.hmm_model.model
model = self._pymc_model
posterior = self.hmm_model.idata.posterior # type: ignore
max_lag = self.hmm_model.adstock.l_max
num_periods = self.num_periods

Check warning on line 215 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L212-L215

Added lines #L212 - L215 were not covered by tests

# Freeze all but channel dims for a more succinct graph
# model = freeze_dims_and_data(
# model, data=[], dims=[dim for dim in model.coords if dim != "date"]
# )
model = freeze_dims_and_data(

Check warning on line 218 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L218

Added line #L218 was not covered by tests
model, data=[], dims=[dim for dim in self._coords if dim != "date"]
)

response_variable = model[self.response_variable]

Check warning on line 222 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L222

Added line #L222 was not covered by tests

# Replicate the budget over num_periods and append zeros to also quantify carry-over effects
n_channels = len(model.coords["channel"])
budgets_tiled = pt.broadcast_to(budgets, (num_periods, n_channels))
budgets_full = pt.zeros((num_periods + max_lag, n_channels))
channel_data_dims = model.named_vars_to_dims["channel_data"]
date_dim_idx = list(channel_data_dims).index("date")

Check warning on line 226 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L225-L226

Added lines #L225 - L226 were not covered by tests

budgets_tiled_shape = list(tuple(budgets.shape))
budgets_tiled_shape.insert(date_dim_idx, num_periods)
budgets_tiled = pt.broadcast_to(

Check warning on line 230 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L228-L230

Added lines #L228 - L230 were not covered by tests
pt.expand_dims(budgets, date_dim_idx), budgets_tiled_shape
)
# print(f"{budgets_tiled.type=}, {budgets_tiled_shape=}, {channel_data_dims=}")

budget_full_shape = list(tuple(budgets.shape))
budget_full_shape.insert(date_dim_idx, num_periods + max_lag)
budgets_full = pt.zeros(budget_full_shape)

Check warning on line 237 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L235-L237

Added lines #L235 - L237 were not covered by tests
# budgets_full = budgets_full[:num_periods, :].set(budgets_tiled)
budgets_full = pt.set_subtensor(budgets_full[:num_periods, :], budgets_tiled)
set_idxs = (*((slice(None),) * date_dim_idx), slice(None, num_periods))
budgets_full = pt.set_subtensor(budgets_full[set_idxs], budgets_tiled)
budgets_full.name = "budgets_full"

Check warning on line 241 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L239-L241

Added lines #L239 - L241 were not covered by tests

# Replace model free_RVs by placeholder variables
Expand Down Expand Up @@ -306,7 +354,9 @@ def allocate_budget(
}
)

initial_guess = np.ones(self._num_channels) * total_budget / self._num_channels
[budgets_size] = inputvars(self._create_budget_variable())[0].type.shape

Check warning on line 357 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L357

Added line #L357 was not covered by tests

initial_guess = np.ones(budgets_size) * total_budget / budgets_size

Check warning on line 359 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L359

Added line #L359 was not covered by tests
bounds = [
(
(budget_bounds[channel][0], budget_bounds[channel][1])
Expand All @@ -325,7 +375,7 @@ def allocate_budget(
result = minimize(
fun=self._objective,
x0=initial_guess,
bounds=bounds,
# bounds=bounds,
constraints=constraints_for_scipy,
jac=self._gradient,
**minimize_kwargs,
Expand Down
55 changes: 39 additions & 16 deletions pymc_marketing/mmm/constraints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
# Copyright 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# constraints.py

import pytensor.tensor as pt
from pymc.pytensorf import inputvars, rewrite_pregrad
from pytensor import function
from pymc.pytensorf import rewrite_pregrad


def auto_jacobian(
constraint_fun,
Expand All @@ -11,11 +25,15 @@ def auto_jacobian(
Given a symbolic constraint function constraint_fun(budgets_sym, total_budget_sym, optimizer),
return a symbolic jacobian function that depends on the same variables.
"""

def _jac(budgets_sym, total_budget_sym, optimizer):
_fun = constraint_fun(budgets_sym, total_budget_sym, optimizer)
return pt.grad(rewrite_pregrad(_fun), budgets_sym)
[budgets_flat] = inputvars([budgets_sym])
return pt.grad(rewrite_pregrad(_fun), budgets_flat)

Check warning on line 32 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L29-L32

Added lines #L29 - L32 were not covered by tests

return _jac

Check warning on line 34 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L34

Added line #L34 was not covered by tests


def build_constraint(
key: str,
constraint_type: str,
Expand Down Expand Up @@ -43,10 +61,12 @@ def build_constraint(
"sym_jac": constraint_jac,
}


def build_default_sum_constraint(key: str = "default"):
"""
Returns a constraint dict that enforces sum(budgets) == total_budget.
Return a constraint dict that enforces sum(budgets) == total_budget.
"""

def _constraint_fun(budgets_sym, total_budget_sym, optimizer):
return pt.sum(budgets_sym) - total_budget_sym

Check warning on line 71 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L70-L71

Added lines #L70 - L71 were not covered by tests

Expand All @@ -57,12 +77,13 @@ def _constraint_fun(budgets_sym, total_budget_sym, optimizer):
constraint_jac=None,
)

def compile_constraints_for_scipy(constraints, optimizer):
""" compile constraints for scipy """

def compile_constraints_for_scipy(budgets, constraints, optimizer):
"""Compile constraints for scipy."""
compiled_constraints = []

Check warning on line 83 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L83

Added line #L83 was not covered by tests

budgets_sym = pt.vector("budgets")
total_budget_sym = pt.scalar("total_budget")
[budgets_flat] = inputvars([budgets])
total_budget = pt.scalar("total_budget")

Check warning on line 86 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L85-L86

Added lines #L85 - L86 were not covered by tests

for c in constraints.values() if isinstance(constraints, dict) else constraints:
ctype = c["type"]
Expand All @@ -71,19 +92,21 @@ def compile_constraints_for_scipy(constraints, optimizer):

# Compile symbolic => python callables
compiled_fun = function(

Check warning on line 94 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L94

Added line #L94 was not covered by tests
inputs=[budgets_sym, total_budget_sym],
outputs=sym_fun(budgets_sym, total_budget_sym, optimizer),
inputs=[budgets_flat, total_budget],
outputs=sym_fun(budgets, total_budget, optimizer),
on_unused_input="ignore",
)
compiled_jac = function(

Check warning on line 99 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L99

Added line #L99 was not covered by tests
inputs=[budgets_sym, total_budget_sym],
outputs=sym_jac(budgets_sym, total_budget_sym, optimizer),
inputs=[budgets_flat, total_budget],
outputs=sym_jac(budgets, total_budget, optimizer),
on_unused_input="ignore",
)

compiled_constraints.append({
"type": ctype,
"fun": compiled_fun,
"jac": compiled_jac,
})
compiled_constraints.append(

Check warning on line 105 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L105

Added line #L105 was not covered by tests
{
"type": ctype,
"fun": compiled_fun,
"jac": compiled_jac,
}
)
return compiled_constraints

Check warning on line 112 in pymc_marketing/mmm/constraints.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/constraints.py#L112

Added line #L112 was not covered by tests
21 changes: 18 additions & 3 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import json
import logging
import warnings
from typing import Annotated, Any, Literal, Sequence
from collections.abc import Sequence
from typing import Annotated, Any, Literal

import arviz as az
import matplotlib.pyplot as plt
Expand All @@ -32,7 +33,6 @@

from pymc_marketing.hsgp_kwargs import HSGPKwargs
from pymc_marketing.mmm.base import BaseValidateMMM

from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
adstock_from_dict,
Expand Down Expand Up @@ -2232,6 +2232,22 @@ def sample_response_distribution(
progressbar=False,
).merge(constant_data)

def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
"""Return the respective PyMC model with any predictors set for optimization."""
model = self.model

Check warning on line 2237 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L2237

Added line #L2237 was not covered by tests
# Models with HSGP have a time_index data variable
if "time_index" in model.named_vars:
model = model.copy()
start_date = model["time_index"].get_value(borrow=True)[-1]
training_dates = (

Check warning on line 2242 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L2239-L2242

Added lines #L2239 - L2242 were not covered by tests
np.arange(num_periods + self.adstock.l_max) + start_date + 1
)
# Consider using pm.set_data, but then we need new coordinates
model["time_index"].set_value(

Check warning on line 2246 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L2246

Added line #L2246 was not covered by tests
training_dates.astype(model["time_index"].type.dtype)
)
return model

Check warning on line 2249 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L2249

Added line #L2249 was not covered by tests

def optimize_budget(
self,
budget: float | int,
Expand Down Expand Up @@ -2293,7 +2309,6 @@ def optimize_budget(
ValueError
If the noise level is not a float.
"""

from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer

Check warning on line 2312 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L2312

Added line #L2312 was not covered by tests

allocator = BudgetOptimizer(
Expand Down

0 comments on commit 02ce740

Please sign in to comment.