Skip to content

Commit

Permalink
fix bootstrap keyword bug
Browse files Browse the repository at this point in the history
  • Loading branch information
alanlujan91 committed Mar 4, 2024
1 parent d7d53ec commit ac3afa2
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
6 changes: 4 additions & 2 deletions src/estimagic/estimation/msm_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_moments_cov(
moment_kwargs (dict): Additional keyword arguments for calculate_moments.
bootstrap_kwargs (dict): Additional keyword arguments that govern the
bootstrapping. Allowed arguments are "n_draws", "seed", "n_cores",
"batch_evaluator", "cluster" and "error_handling". For details see the
"batch_evaluator", "cluster_by" and "error_handling". For details see the
bootstrap function.
Returns:
Expand All @@ -39,8 +39,10 @@ def get_moments_cov(
"n_draws",
"seed",
"batch_evaluator",
"cluster",
"cluster_by",
"error_handling",
"existing_result",
"outcome_kwargs",
}
problematic = set(bootstrap_kwargs).difference(valid_bs_kwargs)
if problematic:
Expand Down
2 changes: 1 addition & 1 deletion src/estimagic/inference/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def bootstrap(
data (pd.DataFrame): Dataset.
existing_result (BootstrapResult): An existing BootstrapResult
object from a previous call of bootstrap(). Default is None.
outcome_kwargs (dict): Additional keyword arguments for outco me.
outcome_kwargs (dict): Additional keyword arguments for outcome.
n_draws (int): Number of bootstrap samples to draw.
If len(existing_outcomes) >= n_draws, a random subset of existing_outcomes
is used.
Expand Down
5 changes: 3 additions & 2 deletions tests/estimation/test_msm_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,19 @@ def test_get_moments_cov_passes_bootstrap_kwargs_to_bootstrap():
rng = get_rng(1234)
data = rng.normal(scale=[10, 5, 1], size=(100, 3))
data = pd.DataFrame(data=data)
data["cluster"] = np.random.choice([1, 2, 3], size=100)

def calc_moments(data, keys):
means = data.mean()
means.index = keys
return means.to_dict()

moment_kwargs = {"keys": ["a", "b", "c"]}
moment_kwargs = {"keys": ["a", "b", "c", "cluster"]}

with pytest.raises(ValueError, match="a must be a positive integer unless no"):
get_moments_cov(
data=data,
calculate_moments=calc_moments,
moment_kwargs=moment_kwargs,
bootstrap_kwargs={"n_draws": -1},
bootstrap_kwargs={"n_draws": -1, "cluster_by": "cluster"},
)

0 comments on commit ac3afa2

Please sign in to comment.