Skip to content

Commit

Permalink
Refactor get model dimensions from labels
Browse files Browse the repository at this point in the history
  • Loading branch information
Arun-Niranjan committed Dec 30, 2023
1 parent 5a0bf1a commit 357cc0d
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 27 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ __pycache__
.ipynb_checkpoints/
.pytest_cache
env/
pymdp.egg-info
pymdp.egg-info
inferactively_pymdp.egg-info
69 changes: 43 additions & 26 deletions pymdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@

EPS_VAL = 1e-16 # global constant for use in norm_dist()

class Dimensions(object):
"""
The Dimensions class stores all data related to the size and shape of a model.
"""
def __init__(
self,
num_observations=None,
num_observation_modalities=0,
num_states=None,
num_state_factors=0,
num_controls=None,
num_control_factors=0,
):
self.num_observations=num_observations
self.num_observation_modalities=num_observation_modalities
self.num_states=num_states
self.num_state_factors=num_state_factors
self.num_controls=num_controls
self.num_control_factors=num_control_factors


def sample(probabilities):
probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities
sample_onehot = np.random.multinomial(1, probabilities)
Expand Down Expand Up @@ -196,22 +217,22 @@ def get_model_dimensions(A=None, B=None):
def get_model_dimensions_from_labels(model_labels):

modalities = model_labels['observations']
num_modalities = len(modalities.keys())
num_obs = [len(modalities[modality]) for modality in modalities.keys()]

factors = model_labels['states']
num_factors = len(factors.keys())
num_states = [len(factors[factor]) for factor in factors.keys()]

if 'actions' in model_labels.keys():
res = Dimensions(
num_observations=[len(modalities[modality]) for modality in modalities.keys()],
num_observation_modalities=len(modalities.keys()),
num_states=[len(factors[factor]) for factor in factors.keys()],
num_state_factors=len(factors.keys()),
)

if 'actions' in model_labels.keys():
controls = model_labels['actions']
num_control_fac = len(controls.keys())
num_controls = [len(controls[cfac]) for cfac in controls.keys()]
res.num_controls=[len(controls[cfac]) for cfac in controls.keys()]
res.num_control_factors=len(controls.keys())

return res

return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac
else:
return num_obs, num_modalities, num_states, num_factors


def norm_dist(dist):
Expand Down Expand Up @@ -449,21 +470,18 @@ def construct_full_a(A_reduced, original_factor_idx, num_states):

def create_A_matrix_stub(model_labels):

num_obs, _, num_states, _= get_model_dimensions_from_labels(model_labels)
dimensions = get_model_dimensions_from_labels(model_labels)

obs_labels, state_labels = model_labels['observations'], model_labels['states']

state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys()))
num_state_combos = np.prod(num_states)
# num_rows = (np.array(num_obs) * num_state_combos).sum()
num_rows = sum(num_obs)
num_rows = sum(dimensions.num_observations)

cell_values = np.zeros((num_rows, len(state_combinations)))

obs_combinations = []
for modality in obs_labels.keys():
levels_to_combine = [[modality]] + [obs_labels[modality]]
# obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine))
obs_combinations += list(itertools.product(*levels_to_combine))


Expand All @@ -475,7 +493,7 @@ def create_A_matrix_stub(model_labels):

def create_B_matrix_stubs(model_labels):

_, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels)
dimensions = get_model_dimensions_from_labels(model_labels)

state_labels = model_labels['states']
action_labels = model_labels['actions']
Expand All @@ -489,9 +507,9 @@ def create_B_matrix_stubs(model_labels):

prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]])

num_state_action_combos = num_states[f_idx] * num_controls[f_idx]
num_state_action_combos = dimensions.num_states[f_idx] * dimensions.num_controls[f_idx]

num_rows = num_states[f_idx]
num_rows = dimensions.num_states[f_idx]

cell_values = np.zeros((num_rows, num_state_action_combos))

Expand Down Expand Up @@ -544,13 +562,12 @@ def convert_A_stub_to_ndarray(A_stub, model_labels):
This function converts a multi-index pandas dataframe `A_stub` into an object array of different
A matrices, one per observation modality.
"""
dimensions = get_model_dimensions_from_labels(model_labels)

num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels)

A = obj_array(num_modalities)
A = obj_array(dimensions.num_observation_modalities)

for g, modality_name in enumerate(model_labels['observations'].keys()):
A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states)
A[g] = A_stub.loc[modality_name].to_numpy().reshape(dimensions.num_observations[g], *dimensions.num_states)
assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n'

return A
Expand All @@ -561,13 +578,13 @@ def convert_B_stubs_to_ndarray(B_stubs, model_labels):
of different B matrices, one per hidden state factor
"""

_, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels)
dimensions = get_model_dimensions_from_labels(model_labels)

B = obj_array(num_factors)
B = obj_array(dimensions.num_control_factors)

for f, factor_name in enumerate(B_stubs.keys()):

B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f])
B[f] = B_stubs[factor_name].to_numpy().reshape(dimensions.num_states[f], dimensions.num_states[f], dimensions.num_controls[f])
assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n'

return B
Expand Down
58 changes: 58 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

import unittest

from pymdp.utils import get_model_dimensions_from_labels, Dimensions

class TestUtils(unittest.TestCase):
def test_get_model_dimensions_from_labels(self):
"""
Tests model dimension extraction from labels including observations, states and actions.
"""
model_labels = {
"observations": {
"species_observation": [
"absent",
"present",
],
"budget_observation": [
"high",
"medium",
"low",
],
},
"states": {
"species_state": [
"extant",
"extinct",
],
},
"actions": {
"conservation_action": [
"manage",
"survey",
"stop",
],
},
}

want = Dimensions(
num_observations=[2, 3],
num_observation_modalities=2,
num_states=[2],
num_state_factors=1,
num_controls=[3],
num_control_factors=1,
)

got = get_model_dimensions_from_labels(model_labels)

self.assertEqual(want.num_observations, got.num_observations)
self.assertEqual(want.num_observation_modalities, got.num_observation_modalities)
self.assertEqual(want.num_states, got.num_states)
self.assertEqual(want.num_state_factors, got.num_state_factors)
self.assertEqual(want.num_controls, got.num_controls)
self.assertEqual(want.num_control_factors, got.num_control_factors)


if __name__ == "__main__":
unittest.main()

0 comments on commit 357cc0d

Please sign in to comment.