From 420bcbc0a0f966c585ed3cfde44393d052329800 Mon Sep 17 00:00:00 2001 From: Arun-Niranjan Date: Sat, 30 Dec 2023 17:03:50 +0000 Subject: [PATCH] Move unit test and update existing test for creating A matrix stub --- test/test_utils.py | 58 ---------------------------------------- test/test_wrappers.py | 62 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 64 deletions(-) delete mode 100644 test/test_utils.py diff --git a/test/test_utils.py b/test/test_utils.py deleted file mode 100644 index 0a1ed066..00000000 --- a/test/test_utils.py +++ /dev/null @@ -1,58 +0,0 @@ - -import unittest - -from pymdp.utils import get_model_dimensions_from_labels, Dimensions - -class TestUtils(unittest.TestCase): - def test_get_model_dimensions_from_labels(self): - """ - Tests model dimension extraction from labels including observations, states and actions. - """ - model_labels = { - "observations": { - "species_observation": [ - "absent", - "present", - ], - "budget_observation": [ - "high", - "medium", - "low", - ], - }, - "states": { - "species_state": [ - "extant", - "extinct", - ], - }, - "actions": { - "conservation_action": [ - "manage", - "survey", - "stop", - ], - }, - } - - want = Dimensions( - num_observations=[2, 3], - num_observation_modalities=2, - num_states=[2], - num_state_factors=1, - num_controls=[3], - num_control_factors=1, - ) - - got = get_model_dimensions_from_labels(model_labels) - - self.assertEqual(want.num_observations, got.num_observations) - self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) - self.assertEqual(want.num_states, got.num_states) - self.assertEqual(want.num_state_factors, got.num_state_factors) - self.assertEqual(want.num_controls, got.num_controls) - self.assertEqual(want.num_control_factors, got.num_control_factors) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/test/test_wrappers.py b/test/test_wrappers.py index db25c984..254d902b 100644 --- a/test/test_wrappers.py +++ b/test/test_wrappers.py @@ -1,15 +1,11 @@ import os import unittest from pathlib import Path -import shutil -import tempfile -import numpy as np -import itertools import pandas as pd from pandas.testing import assert_frame_equal -from pymdp.utils import create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices +from pymdp.utils import Dimensions, get_model_dimensions_from_labels, create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices tmp_path = Path('tmp_dir') @@ -18,11 +14,62 @@ class TestWrappers(unittest.TestCase): + def test_get_model_dimensions_from_labels(self): + """ + Tests model dimension extraction from labels including observations, states and actions. + """ + model_labels = { + "observations": { + "species_observation": [ + "absent", + "present", + ], + "budget_observation": [ + "high", + "medium", + "low", + ], + }, + "states": { + "species_state": [ + "extant", + "extinct", + ], + }, + "actions": { + "conservation_action": [ + "manage", + "survey", + "stop", + ], + }, + } + + want = Dimensions( + num_observations=[2, 3], + num_observation_modalities=2, + num_states=[2], + num_state_factors=1, + num_controls=[3], + num_control_factors=1, + ) + + got = get_model_dimensions_from_labels(model_labels) + + self.assertEqual(want.num_observations, got.num_observations) + self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) + self.assertEqual(want.num_states, got.num_states) + self.assertEqual(want.num_state_factors, got.num_state_factors) + self.assertEqual(want.num_controls, got.num_controls) + self.assertEqual(want.num_control_factors, got.num_control_factors) + def test_A_matrix_stub(self): """ This tests the construction of a 2-modality, 2-hidden state factor pandas MultiIndex dataframe using the `model_labels` dictionary, which contains the modality- and factor-specific levels, labeled with string - identifiers + identifiers. + + Note: actions are ignored when creating an A matrix stub """ model_labels = { @@ -41,6 +88,9 @@ def test_A_matrix_stub(self): "weather_state": ["raining", "clear"], "sprinkler_state": ["on", "off"], }, + "actions": { + "actions": ["something", "nothing"], + } } num_hidden_state_factors = len(model_labels["states"])