diff --git a/.gitignore b/.gitignore index 94ead7f7..3d49b035 100644 --- a/.gitignore +++ b/.gitignore @@ -362,3 +362,5 @@ Java\ model/ examples/output +!tests/fixtures/data/test_sir_python.npy +!tests/fixtures/data/test_sir_python_w_breaks_python.npy \ No newline at end of file diff --git a/examples/models/sir/sir_python.py b/examples/models/sir/sir_python.py index 7154afdf..676f7460 100644 --- a/examples/models/sir/sir_python.py +++ b/examples/models/sir/sir_python.py @@ -46,12 +46,10 @@ def SIR(theta: NDArray, N: int, seed: int | None) -> NDArray: # noqa: N802, N80 Returns: simulated series """ - np.random.seed(seed=seed) - num_agents = 100000 g = nx.watts_strogatz_graph(num_agents, int(theta[0]), theta[1], seed=theta[5]) - model = ep.SIRModel(g) + model = ep.SIRModel(g, seed=seed) cfg = ModelConfig.Configuration() cfg.add_model_parameter("beta", theta[3]) # infection rate @@ -102,12 +100,10 @@ def SIR_w_breaks( # noqa: N802 Returns: simulated series """ - np.random.seed(seed=seed) - num_agents = 100000 g = nx.watts_strogatz_graph(num_agents, int(theta[0]), theta[1], seed=theta[11]) - model = ep.SIRModel(g) + model = ep.SIRModel(g, seed=seed) cfg = ModelConfig.Configuration() cfg.add_model_parameter("beta", theta[3]) # infection rate diff --git a/tests/fixtures/data/test_sir_python.npy b/tests/fixtures/data/test_sir_python.npy new file mode 100644 index 00000000..5825c50b Binary files /dev/null and b/tests/fixtures/data/test_sir_python.npy differ diff --git a/tests/test_examples/test_sir_python.py b/tests/test_examples/test_sir_python.py new file mode 100644 index 00000000..20ca416e --- /dev/null +++ b/tests/test_examples/test_sir_python.py @@ -0,0 +1,88 @@ +# Black-box ABM Calibration Kit (Black-it) +# Copyright (C) 2021-2023 Banca d'Italia +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Test the SIR model implementation in Python.""" + +import numpy as np + +from examples.models.sir.sir_python import SIR, SIR_w_breaks +from tests.conftest import TEST_DIR + + +def test_sir() -> None: + """Test the 'SIR' function in examples/models/sir/sir_python.py.""" + expected_output = np.load(TEST_DIR / "fixtures" / "data" / "test_sir_python.npy") + model_seed = 0 + + lattice_order = 20 + rewire_probability = 0.2 + percentage_infected = 0.05 + beta = 0.2 + gamma = 0.15 + networkx_seed = 0 + theta = [ + lattice_order, + rewire_probability, + percentage_infected, + beta, + gamma, + networkx_seed, + ] + + n = 100 + output = SIR(theta, n, seed=model_seed) + + assert np.isclose(output, expected_output).all() + + +def test_sir_w_breaks() -> None: + """Test the 'SIR_w_breaks' function in examples/models/sir/sir_python.py.""" + expected_output = np.load( + TEST_DIR / "fixtures" / "data" / "test_sir_w_breaks_python.npy", + ) + model_seed = 0 + + lattice_order = 20 + rewire_probability = 0.2 + percentage_infected = 0.05 + beta_1 = 0.2 + gamma_1 = 0.15 + beta_2 = 0.3 + beta_3 = 0.1 + beta_4 = 0.01 + t_break_1 = 10 + t_break_2 = 20 + t_break_3 = 30 + networkx_seed = 0 + theta = [ + lattice_order, + rewire_probability, + percentage_infected, + beta_1, + gamma_1, + beta_2, + beta_3, + beta_4, + t_break_1, + t_break_2, + t_break_3, + networkx_seed, + ] + + n = 100 + output = SIR_w_breaks(theta, n, seed=model_seed) + + assert np.isclose(output, expected_output).all()