Skip to content

Commit

Permalink
Scale input
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 13, 2025
1 parent 3f725ac commit aeeb060
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def _create_budget_variable(self):
].set(budgets_flat)
else:
budgets = budgets_flat.reshape(budgets_shape)

return budgets

Check warning on line 147 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L146-L147

Added lines #L146 - L147 were not covered by tests

def set_constraints(self, constraints, default=None):
Expand Down Expand Up @@ -209,10 +208,13 @@ def extract_response_distribution(
if not (isinstance(budgets, pt.TensorVariable)): # and budgets.type.ndim == 1):
raise ValueError("budgets must be a TensorVariable")

Check warning on line 209 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L208-L209

Added lines #L208 - L209 were not covered by tests

num_periods = self.num_periods
model = self._pymc_model
posterior = self.hmm_model.idata.posterior # type: ignore
max_lag = self.hmm_model.adstock.l_max
num_periods = self.num_periods
channel_scales = self.hmm_model.scaler.input_scales[

Check warning on line 215 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L211-L215

Added lines #L211 - L215 were not covered by tests
self.hmm_model.channel_columns
].values

# Freeze all but channel dims for a more succinct graph
model = freeze_dims_and_data(

Check warning on line 220 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L220

Added line #L220 was not covered by tests
Expand All @@ -227,6 +229,8 @@ def extract_response_distribution(

budgets_tiled_shape = list(tuple(budgets.shape))
budgets_tiled_shape.insert(date_dim_idx, num_periods)

Check warning on line 231 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L230-L231

Added lines #L230 - L231 were not covered by tests
# TODO: If scales become part of the model, we don't need to transform it here
budgets /= channel_scales[:, None]
budgets_tiled = pt.broadcast_to(

Check warning on line 234 in pymc_marketing/mmm/budget_optimizer.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/budget_optimizer.py#L233-L234

Added lines #L233 - L234 were not covered by tests
pt.expand_dims(budgets, date_dim_idx), budgets_tiled_shape
)
Expand Down

0 comments on commit aeeb060

Please sign in to comment.