From b59cfcd2e0974ce6ce0e6da2a27f7b07726eafd6 Mon Sep 17 00:00:00 2001 From: Arun-Niranjan Date: Sat, 30 Dec 2023 16:56:01 +0000 Subject: [PATCH] Refactor get model dimensions from labels --- .gitignore | 3 +- pymdp/utils.py | 69 +++++++++++++++++++++++++++++----------------- test/test_utils.py | 58 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 27 deletions(-) create mode 100644 test/test_utils.py diff --git a/.gitignore b/.gitignore index 778d69dd..5f24acf9 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ __pycache__ .ipynb_checkpoints/ .pytest_cache env/ -pymdp.egg-info \ No newline at end of file +pymdp.egg-info +inferactively_pymdp.egg-info diff --git a/pymdp/utils.py b/pymdp/utils.py index 60bc41e0..bd985cf5 100644 --- a/pymdp/utils.py +++ b/pymdp/utils.py @@ -16,6 +16,27 @@ EPS_VAL = 1e-16 # global constant for use in norm_dist() +class Dimensions(object): + """ + The Dimensions class stores all data related to the size and shape of a model. + """ + def __init__( + self, + num_observations=None, + num_observation_modalities=0, + num_states=None, + num_state_factors=0, + num_controls=None, + num_control_factors=0, + ): + self.num_observations=num_observations + self.num_observation_modalities=num_observation_modalities + self.num_states=num_states + self.num_state_factors=num_state_factors + self.num_controls=num_controls + self.num_control_factors=num_control_factors + + def sample(probabilities): probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities sample_onehot = np.random.multinomial(1, probabilities) @@ -211,22 +232,22 @@ def get_model_dimensions(A=None, B=None, factorized=False): def get_model_dimensions_from_labels(model_labels): modalities = model_labels['observations'] - num_modalities = len(modalities.keys()) - num_obs = [len(modalities[modality]) for modality in modalities.keys()] - factors = model_labels['states'] - num_factors = len(factors.keys()) - num_states = [len(factors[factor]) for factor in factors.keys()] - if 'actions' in model_labels.keys(): + res = Dimensions( + num_observations=[len(modalities[modality]) for modality in modalities.keys()], + num_observation_modalities=len(modalities.keys()), + num_states=[len(factors[factor]) for factor in factors.keys()], + num_state_factors=len(factors.keys()), + ) + if 'actions' in model_labels.keys(): controls = model_labels['actions'] - num_control_fac = len(controls.keys()) - num_controls = [len(controls[cfac]) for cfac in controls.keys()] + res.num_controls=[len(controls[cfac]) for cfac in controls.keys()] + res.num_control_factors=len(controls.keys()) + + return res - return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac - else: - return num_obs, num_modalities, num_states, num_factors def norm_dist(dist): @@ -464,21 +485,18 @@ def construct_full_a(A_reduced, original_factor_idx, num_states): def create_A_matrix_stub(model_labels): - num_obs, _, num_states, _= get_model_dimensions_from_labels(model_labels) + dimensions = get_model_dimensions_from_labels(model_labels) obs_labels, state_labels = model_labels['observations'], model_labels['states'] state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys())) - num_state_combos = np.prod(num_states) - # num_rows = (np.array(num_obs) * num_state_combos).sum() - num_rows = sum(num_obs) + num_rows = sum(dimensions.num_observations) cell_values = np.zeros((num_rows, len(state_combinations))) obs_combinations = [] for modality in obs_labels.keys(): levels_to_combine = [[modality]] + [obs_labels[modality]] - # obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine)) obs_combinations += list(itertools.product(*levels_to_combine)) @@ -490,7 +508,7 @@ def create_A_matrix_stub(model_labels): def create_B_matrix_stubs(model_labels): - _, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels) + dimensions = get_model_dimensions_from_labels(model_labels) state_labels = model_labels['states'] action_labels = model_labels['actions'] @@ -504,9 +522,9 @@ def create_B_matrix_stubs(model_labels): prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]]) - num_state_action_combos = num_states[f_idx] * num_controls[f_idx] + num_state_action_combos = dimensions.num_states[f_idx] * dimensions.num_controls[f_idx] - num_rows = num_states[f_idx] + num_rows = dimensions.num_states[f_idx] cell_values = np.zeros((num_rows, num_state_action_combos)) @@ -559,13 +577,12 @@ def convert_A_stub_to_ndarray(A_stub, model_labels): This function converts a multi-index pandas dataframe `A_stub` into an object array of different A matrices, one per observation modality. """ + dimensions = get_model_dimensions_from_labels(model_labels) - num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels) - - A = obj_array(num_modalities) + A = obj_array(dimensions.num_observation_modalities) for g, modality_name in enumerate(model_labels['observations'].keys()): - A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states) + A[g] = A_stub.loc[modality_name].to_numpy().reshape(dimensions.num_observations[g], *dimensions.num_states) assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n' return A @@ -576,13 +593,13 @@ def convert_B_stubs_to_ndarray(B_stubs, model_labels): of different B matrices, one per hidden state factor """ - _, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels) + dimensions = get_model_dimensions_from_labels(model_labels) - B = obj_array(num_factors) + B = obj_array(dimensions.num_control_factors) for f, factor_name in enumerate(B_stubs.keys()): - B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f]) + B[f] = B_stubs[factor_name].to_numpy().reshape(dimensions.num_states[f], dimensions.num_states[f], dimensions.num_controls[f]) assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n' return B diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..0a1ed066 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,58 @@ + +import unittest + +from pymdp.utils import get_model_dimensions_from_labels, Dimensions + +class TestUtils(unittest.TestCase): + def test_get_model_dimensions_from_labels(self): + """ + Tests model dimension extraction from labels including observations, states and actions. + """ + model_labels = { + "observations": { + "species_observation": [ + "absent", + "present", + ], + "budget_observation": [ + "high", + "medium", + "low", + ], + }, + "states": { + "species_state": [ + "extant", + "extinct", + ], + }, + "actions": { + "conservation_action": [ + "manage", + "survey", + "stop", + ], + }, + } + + want = Dimensions( + num_observations=[2, 3], + num_observation_modalities=2, + num_states=[2], + num_state_factors=1, + num_controls=[3], + num_control_factors=1, + ) + + got = get_model_dimensions_from_labels(model_labels) + + self.assertEqual(want.num_observations, got.num_observations) + self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) + self.assertEqual(want.num_states, got.num_states) + self.assertEqual(want.num_state_factors, got.num_state_factors) + self.assertEqual(want.num_controls, got.num_controls) + self.assertEqual(want.num_control_factors, got.num_control_factors) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file