Skip to content

Commit

Permalink
Fix random seed handling for "vi" (#869)
Browse files Browse the repository at this point in the history
* Fix random seed for "vi"

Random seed was not properly passed when `inference_method="vi"`. This is now resolved.

* Relaxed boundary values for beta regression

---------

Co-authored-by: Boje Deforce <[email protected]>
  • Loading branch information
B-Deforce and Boje Deforce authored Jan 14, 2025
1 parent 089d8e9 commit 2a4f7e1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def run(
**kwargs,
)
elif inference_method in self.pymc_methods["vi"]:
result = self._run_vi(**kwargs)
result = self._run_vi(random_seed, **kwargs)
elif inference_method == "laplace":
result = self._run_laplace(draws, omit_offsets, include_response_params)
else:
Expand Down Expand Up @@ -382,9 +382,9 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro

return idata

def _run_vi(self, **kwargs):
def _run_vi(self, random_seed, **kwargs):
with self.model:
self.vi_approx = pm.fit(**kwargs)
self.vi_approx = pm.fit(random_seed=random_seed, **kwargs)
return self.vi_approx

def _run_laplace(self, draws, omit_offsets, include_response_params):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,16 +824,16 @@ def test_beta_regression(self, gasoline_data):
model.predict(idata, kind="response")

assert (0 < idata.posterior["mu"]).all() & (idata.posterior["mu"] < 1).all()
assert (0 < idata.posterior_predictive["yield"]).all() & (
idata.posterior_predictive["yield"] < 1
assert (0 <= idata.posterior_predictive["yield"]).all() & (
idata.posterior_predictive["yield"] <= 1
).all()

model.predict(idata, kind="response_params", data=gasoline_data.iloc[:20, :])
model.predict(idata, kind="response", data=gasoline_data.iloc[:20, :])

assert (0 < idata.posterior["mu"]).all() & (idata.posterior["mu"] < 1).all()
assert (0 < idata.posterior_predictive["yield"]).all() & (
idata.posterior_predictive["yield"] < 1
assert (0 <= idata.posterior_predictive["yield"]).all() & (
idata.posterior_predictive["yield"] <= 1
).all()


Expand Down

0 comments on commit 2a4f7e1

Please sign in to comment.