From 31060a54bf741e33019c2dc9993292110b89cc8e Mon Sep 17 00:00:00 2001 From: Marco Favorito Date: Wed, 30 Aug 2023 17:21:09 +0200 Subject: [PATCH] lint: fix NPY002 for examples/models/sir/sir_python.py Fixed NPY002 issues in SIR models implemented in examples/models/sir/sir_python.py. Added tests to verify the new seed management works. --- .gitignore | 2 + examples/models/sir/sir_python.py | 8 +-- tests/fixtures/data/test_sir_python.npy | Bin 0 -> 2528 bytes tests/test_examples/test_sir_python.py | 88 ++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 tests/fixtures/data/test_sir_python.npy create mode 100644 tests/test_examples/test_sir_python.py diff --git a/.gitignore b/.gitignore index 94ead7f7..3d49b035 100644 --- a/.gitignore +++ b/.gitignore @@ -362,3 +362,5 @@ Java\ model/ examples/output +!tests/fixtures/data/test_sir_python.npy +!tests/fixtures/data/test_sir_python_w_breaks_python.npy \ No newline at end of file diff --git a/examples/models/sir/sir_python.py b/examples/models/sir/sir_python.py index 7154afdf..676f7460 100644 --- a/examples/models/sir/sir_python.py +++ b/examples/models/sir/sir_python.py @@ -46,12 +46,10 @@ def SIR(theta: NDArray, N: int, seed: int | None) -> NDArray: # noqa: N802, N80 Returns: simulated series """ - np.random.seed(seed=seed) - num_agents = 100000 g = nx.watts_strogatz_graph(num_agents, int(theta[0]), theta[1], seed=theta[5]) - model = ep.SIRModel(g) + model = ep.SIRModel(g, seed=seed) cfg = ModelConfig.Configuration() cfg.add_model_parameter("beta", theta[3]) # infection rate @@ -102,12 +100,10 @@ def SIR_w_breaks( # noqa: N802 Returns: simulated series """ - np.random.seed(seed=seed) - num_agents = 100000 g = nx.watts_strogatz_graph(num_agents, int(theta[0]), theta[1], seed=theta[11]) - model = ep.SIRModel(g) + model = ep.SIRModel(g, seed=seed) cfg = ModelConfig.Configuration() cfg.add_model_parameter("beta", theta[3]) # infection rate diff --git a/tests/fixtures/data/test_sir_python.npy b/tests/fixtures/data/test_sir_python.npy new file mode 100644 index 0000000000000000000000000000000000000000..5825c50b3d91dda1c471078362d77927e30b0f1f GIT binary patch literal 2528 zcmeIxNoW%R6b9hei)iqWn>lzGBs3L*Of5=~Y6m205xa<$Sc07{iEWWKrged$4243F z3UN7j5IP`u@Y2DG;zdXF;EE$C6@(HMt%%FG)#AzIf8$}^AXM?#cX|9@ejdprIo;H@ zqh(JO9iov~PiCk)AKReB>a&ShT#01|^TYYhfsVm^Pp0BqI`Yi)0Qu8niHxAN=UCbFZixIg*9*VpSy$j)}x z=i0<`WOekj-w#eBTjTtG$s4We345RS`8>CmKZZNi->^LCNzPRyb{g(?01* zxE}D$d|$eu!CmRJ5w7;i`|IFVz}Lb{0k7rz0guC#K6$?et_8dr-V*RATo3q4croA+ yxVc~tqi+ZLOX1FfozO1@`hRk_fy(0s`hV+{E9L2z&zX7A7oUUo1bF?&-^nlYK|k66 literal 0 HcmV?d00001 diff --git a/tests/test_examples/test_sir_python.py b/tests/test_examples/test_sir_python.py new file mode 100644 index 00000000..20ca416e --- /dev/null +++ b/tests/test_examples/test_sir_python.py @@ -0,0 +1,88 @@ +# Black-box ABM Calibration Kit (Black-it) +# Copyright (C) 2021-2023 Banca d'Italia +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Test the SIR model implementation in Python.""" + +import numpy as np + +from examples.models.sir.sir_python import SIR, SIR_w_breaks +from tests.conftest import TEST_DIR + + +def test_sir() -> None: + """Test the 'SIR' function in examples/models/sir/sir_python.py.""" + expected_output = np.load(TEST_DIR / "fixtures" / "data" / "test_sir_python.npy") + model_seed = 0 + + lattice_order = 20 + rewire_probability = 0.2 + percentage_infected = 0.05 + beta = 0.2 + gamma = 0.15 + networkx_seed = 0 + theta = [ + lattice_order, + rewire_probability, + percentage_infected, + beta, + gamma, + networkx_seed, + ] + + n = 100 + output = SIR(theta, n, seed=model_seed) + + assert np.isclose(output, expected_output).all() + + +def test_sir_w_breaks() -> None: + """Test the 'SIR_w_breaks' function in examples/models/sir/sir_python.py.""" + expected_output = np.load( + TEST_DIR / "fixtures" / "data" / "test_sir_w_breaks_python.npy", + ) + model_seed = 0 + + lattice_order = 20 + rewire_probability = 0.2 + percentage_infected = 0.05 + beta_1 = 0.2 + gamma_1 = 0.15 + beta_2 = 0.3 + beta_3 = 0.1 + beta_4 = 0.01 + t_break_1 = 10 + t_break_2 = 20 + t_break_3 = 30 + networkx_seed = 0 + theta = [ + lattice_order, + rewire_probability, + percentage_infected, + beta_1, + gamma_1, + beta_2, + beta_3, + beta_4, + t_break_1, + t_break_2, + t_break_3, + networkx_seed, + ] + + n = 100 + output = SIR_w_breaks(theta, n, seed=model_seed) + + assert np.isclose(output, expected_output).all()