Skip to content

Commit

Permalink
Optimize datatypes of network states (#5)
Browse files Browse the repository at this point in the history
* Optimize datatypes of network states

When simulating without a collective variable, network states were always saved as int64 or float64, which used an unnecessary amount of memory.
The dtype is now determined by numpy.min_scalar_type(num_opinions - 1), i.e. the smallest dtype that can hold the largest opinion label.
  • Loading branch information
manuelosos authored Sep 2, 2024
1 parent 34c743a commit 7f0e4f1
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 3 deletions.
6 changes: 4 additions & 2 deletions sponet/cnvm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ def simulate(
if self.params.network_generator is not None:
self.update_network()

opinion_dtype = np.min_scalar_type(self.params.num_opinions-1)

if x_init is None:
x_init = rng.choice(
np.arange(self.params.num_opinions), size=self.params.num_agents
)
x = np.copy(x_init).astype(int)
x = np.copy(x_init).astype(opinion_dtype)

t_delta = 0 if len_output is None else t_max / (len_output - 1)

Expand Down Expand Up @@ -142,7 +144,7 @@ def simulate(
)

t_traj = np.array(t_traj)
x_traj = np.array(x_traj, dtype=int)
x_traj = np.array(x_traj, dtype=opinion_dtype)
if len_output is None:
# remove duplicate subsequent states
mask = mask_subsequent_duplicates(x_traj)
Expand Down
6 changes: 5 additions & 1 deletion sponet/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,12 @@ def _sample_many_runs_subprocess(
model = model_type(params)

if collective_variable is None:

opinion_dtype = np.min_scalar_type(params.num_opinions-1)

x_out = np.zeros(
(num_initial_states, num_runs, num_timesteps, model.params.num_agents)
(num_initial_states, num_runs, num_timesteps, model.params.num_agents),
dtype=opinion_dtype
)
else:
x_out = np.zeros(
Expand Down
26 changes: 26 additions & 0 deletions tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,29 @@ def test_parallelization_with_cv(self):
self.num_opinions,
),
)

def test_output_dtype_no_cv(self):
    """Check the dtype of the states returned by sample_many_runs without a CV.

    With ``collective_variable=None`` the returned states are expected to
    use ``np.uint8`` for 2 opinions and ``np.uint16`` for 257 opinions.
    Each case runs in its own ``subTest`` so that a failure names the
    offending ``num_opinions`` and does not hide the remaining cases.
    """
    # (num_opinions, expected dtype of the returned states)
    cases = [(2, np.uint8), (257, np.uint16)]

    for num_opinions, correct_dtype in cases:
        with self.subTest(num_opinions=num_opinions):
            params = CNVMParameters(
                num_opinions=num_opinions,
                num_agents=self.num_agents,
                r=1,
                r_tilde=1,
            )

            # Only the states are inspected; the time points are not needed.
            _, x = sample_many_runs(
                params=params,
                initial_states=self.initial_states,
                t_max=5,
                num_timesteps=self.num_timesteps,
                num_runs=2,
                n_jobs=2,
                collective_variable=None,
            )

            self.assertEqual(correct_dtype, x.dtype)

28 changes: 28 additions & 0 deletions tests/tests_cnvm/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def setUp(self):
r_tilde=self.r_tilde,
)


def test_output(self):
model = CNVM(self.params_complete)
t_max = 100
Expand Down Expand Up @@ -97,3 +98,30 @@ def test_output_concise(self):

for i in range(x.shape[0] - 1):
self.assertFalse(np.allclose(x[i], x[i + 1]))

def test_output_dtype(self):
    """Check the dtype of the trajectory returned by CNVM.simulate.

    The states are expected to use ``np.uint8`` for 2 opinions and
    ``np.uint16`` for 257 opinions. Each case runs in its own ``subTest``
    so that a failure names the offending ``num_opinions`` and does not
    hide the remaining cases.
    """
    # (num_opinions, expected dtype of the returned states)
    cases = [(2, np.uint8), (257, np.uint16)]

    for num_opinions, correct_dtype in cases:
        with self.subTest(num_opinions=num_opinions):
            params = CNVMParameters(
                num_opinions=num_opinions,
                num_agents=self.num_agents,
                r=1,
                r_tilde=1,
            )

            model = CNVM(params)
            t_max = 5
            # Only the states are inspected; the time points are not needed.
            _, x = model.simulate(t_max)

            self.assertEqual(correct_dtype, x.dtype)








0 comments on commit 7f0e4f1

Please sign in to comment.