Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get model dimensions from labels #126

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ __pycache__
.ipynb_checkpoints/
.pytest_cache
env/
pymdp.egg-info
pymdp.egg-info
inferactively_pymdp.egg-info
69 changes: 43 additions & 26 deletions pymdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@

EPS_VAL = 1e-16 # global constant for use in norm_dist()

class Dimensions(object):
"""
The Dimensions class stores all data related to the size and shape of a model.
"""
def __init__(
self,
num_observations=None,
num_observation_modalities=0,
num_states=None,
num_state_factors=0,
num_controls=None,
num_control_factors=0,
):
self.num_observations=num_observations
self.num_observation_modalities=num_observation_modalities
self.num_states=num_states
self.num_state_factors=num_state_factors
self.num_controls=num_controls
self.num_control_factors=num_control_factors


def sample(probabilities):
probabilities = probabilities.squeeze() if len(probabilities) > 1 else probabilities
sample_onehot = np.random.multinomial(1, probabilities)
Expand Down Expand Up @@ -196,22 +217,22 @@ def get_model_dimensions(A=None, B=None):
def get_model_dimensions_from_labels(model_labels):

modalities = model_labels['observations']
num_modalities = len(modalities.keys())
num_obs = [len(modalities[modality]) for modality in modalities.keys()]

factors = model_labels['states']
num_factors = len(factors.keys())
num_states = [len(factors[factor]) for factor in factors.keys()]

if 'actions' in model_labels.keys():
res = Dimensions(
num_observations=[len(modalities[modality]) for modality in modalities.keys()],
num_observation_modalities=len(modalities.keys()),
num_states=[len(factors[factor]) for factor in factors.keys()],
num_state_factors=len(factors.keys()),
)

if 'actions' in model_labels.keys():
controls = model_labels['actions']
num_control_fac = len(controls.keys())
num_controls = [len(controls[cfac]) for cfac in controls.keys()]
res.num_controls=[len(controls[cfac]) for cfac in controls.keys()]
res.num_control_factors=len(controls.keys())

return res

return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac
else:
return num_obs, num_modalities, num_states, num_factors


def norm_dist(dist):
Expand Down Expand Up @@ -449,21 +470,18 @@ def construct_full_a(A_reduced, original_factor_idx, num_states):

def create_A_matrix_stub(model_labels):

num_obs, _, num_states, _= get_model_dimensions_from_labels(model_labels)
dimensions = get_model_dimensions_from_labels(model_labels)

obs_labels, state_labels = model_labels['observations'], model_labels['states']

state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys()))
num_state_combos = np.prod(num_states)
# num_rows = (np.array(num_obs) * num_state_combos).sum()
num_rows = sum(num_obs)
num_rows = sum(dimensions.num_observations)

cell_values = np.zeros((num_rows, len(state_combinations)))

obs_combinations = []
for modality in obs_labels.keys():
levels_to_combine = [[modality]] + [obs_labels[modality]]
# obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine))
obs_combinations += list(itertools.product(*levels_to_combine))


Expand All @@ -475,7 +493,7 @@ def create_A_matrix_stub(model_labels):

def create_B_matrix_stubs(model_labels):

_, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels)
dimensions = get_model_dimensions_from_labels(model_labels)

state_labels = model_labels['states']
action_labels = model_labels['actions']
Expand All @@ -489,9 +507,9 @@ def create_B_matrix_stubs(model_labels):

prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]])

num_state_action_combos = num_states[f_idx] * num_controls[f_idx]
num_state_action_combos = dimensions.num_states[f_idx] * dimensions.num_controls[f_idx]

num_rows = num_states[f_idx]
num_rows = dimensions.num_states[f_idx]

cell_values = np.zeros((num_rows, num_state_action_combos))

Expand Down Expand Up @@ -544,13 +562,12 @@ def convert_A_stub_to_ndarray(A_stub, model_labels):
This function converts a multi-index pandas dataframe `A_stub` into an object array of different
A matrices, one per observation modality.
"""
dimensions = get_model_dimensions_from_labels(model_labels)

num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels)

A = obj_array(num_modalities)
A = obj_array(dimensions.num_observation_modalities)

for g, modality_name in enumerate(model_labels['observations'].keys()):
A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states)
A[g] = A_stub.loc[modality_name].to_numpy().reshape(dimensions.num_observations[g], *dimensions.num_states)
assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n'

return A
Expand All @@ -561,13 +578,13 @@ def convert_B_stubs_to_ndarray(B_stubs, model_labels):
of different B matrices, one per hidden state factor
"""

_, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels)
dimensions = get_model_dimensions_from_labels(model_labels)

B = obj_array(num_factors)
B = obj_array(dimensions.num_control_factors)

for f, factor_name in enumerate(B_stubs.keys()):

B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f])
B[f] = B_stubs[factor_name].to_numpy().reshape(dimensions.num_states[f], dimensions.num_states[f], dimensions.num_controls[f])
assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n'

return B
Expand Down
62 changes: 56 additions & 6 deletions test/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
import unittest
from pathlib import Path
import shutil
import tempfile

import numpy as np
import itertools
import pandas as pd
from pandas.testing import assert_frame_equal

from pymdp.utils import create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices
from pymdp.utils import Dimensions, get_model_dimensions_from_labels, create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices

tmp_path = Path('tmp_dir')

Expand All @@ -18,11 +14,62 @@

class TestWrappers(unittest.TestCase):

def test_get_model_dimensions_from_labels(self):
"""
Tests model dimension extraction from labels including observations, states and actions.
"""
model_labels = {
"observations": {
"species_observation": [
"absent",
"present",
],
"budget_observation": [
"high",
"medium",
"low",
],
},
"states": {
"species_state": [
"extant",
"extinct",
],
},
"actions": {
"conservation_action": [
"manage",
"survey",
"stop",
],
},
}

want = Dimensions(
num_observations=[2, 3],
num_observation_modalities=2,
num_states=[2],
num_state_factors=1,
num_controls=[3],
num_control_factors=1,
)

got = get_model_dimensions_from_labels(model_labels)

self.assertEqual(want.num_observations, got.num_observations)
self.assertEqual(want.num_observation_modalities, got.num_observation_modalities)
self.assertEqual(want.num_states, got.num_states)
self.assertEqual(want.num_state_factors, got.num_state_factors)
self.assertEqual(want.num_controls, got.num_controls)
self.assertEqual(want.num_control_factors, got.num_control_factors)

def test_A_matrix_stub(self):
"""
This tests the construction of a 2-modality, 2-hidden state factor pandas MultiIndex dataframe using
the `model_labels` dictionary, which contains the modality- and factor-specific levels, labeled with string
identifiers
identifiers.

Note: actions are ignored when creating an A matrix stub
"""

model_labels = {
Expand All @@ -41,6 +88,9 @@ def test_A_matrix_stub(self):
"weather_state": ["raining", "clear"],
"sprinkler_state": ["on", "off"],
},
"actions": {
"actions": ["something", "nothing"],
}
}

num_hidden_state_factors = len(model_labels["states"])
Expand Down
Loading