Skip to content

Commit

Permalink
Add test case with poisson and zeros
Browse files Browse the repository at this point in the history
Added a test case with poisson log-likelihood and zero valued data with
the `zero_to_one` flag set to `True`.
  • Loading branch information
TimothyWillard committed Dec 10, 2024
1 parent 8e5baa3 commit 10e5d98
Showing 1 changed file with 51 additions and 2 deletions.
53 changes: 51 additions & 2 deletions flepimop/gempyor_pkg/tests/statistics/test_statistic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,48 @@ def simple_valid_factory_with_pois() -> MockStatisticInput:
)


def simple_valid_factory_with_pois_with_some_zeros() -> MockStatisticInput:
mock_input = simple_valid_factory_with_pois()

mock_input.config["zero_to_one"] = True

mock_input.model_data["incidH"].loc[
{
"date": mock_input.model_data.coords["date"][0],
"subpop": mock_input.model_data.coords["subpop"][0],
}
] = 0

mock_input.gt_data["incidD"].loc[
{
"date": mock_input.gt_data.coords["date"][0],
"subpop": mock_input.gt_data.coords["subpop"][0],
}
] = 0

mock_input.model_data["incidH"].loc[
{
"date": mock_input.model_data.coords["date"][1],
"subpop": mock_input.model_data.coords["subpop"][1],
}
] = 0
mock_input.gt_data["incidH"].loc[
{
"date": mock_input.gt_data.coords["date"][1],
"subpop": mock_input.gt_data.coords["subpop"][1],
}
] = 0

return mock_input


all_valid_factories = [
(simple_valid_factory),
(simple_valid_resample_factory),
(simple_valid_resample_factory),
(simple_valid_resample_and_scale_factory),
(simple_valid_factory_with_pois),
(simple_valid_factory_with_pois_with_some_zeros),
]


Expand Down Expand Up @@ -549,8 +585,21 @@ def test_llik(self, factory: Callable[[], MockStatisticInput]) -> None:
assert np.allclose(
log_likelihood.values,
scipy.stats.poisson.logpmf(
mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
mock_inputs.model_data[mock_inputs.config["data_var"]].values,
np.where(
mock_inputs.config.get("zero_to_one", False)
& (mock_inputs.gt_data[mock_inputs.config["data_var"]].values == 0),
1,
mock_inputs.gt_data[mock_inputs.config["data_var"]].values,
),
np.where(
mock_inputs.config.get("zero_to_one", False)
& (
mock_inputs.model_data[mock_inputs.config["data_var"]].values
== 0
),
1,
mock_inputs.model_data[mock_inputs.config["data_var"]].values,
),
),
)
elif dist_name in {"norm", "norm_cov"}:
Expand Down

0 comments on commit 10e5d98

Please sign in to comment.