Skip to content

Commit

Permalink
Move unit test and update existing test for creating A matrix stub
Browse files Browse the repository at this point in the history
  • Loading branch information
Arun-Niranjan authored and dimarkov committed Jun 6, 2024
1 parent b59cfcd commit 1b50710
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 64 deletions.
58 changes: 0 additions & 58 deletions test/test_utils.py

This file was deleted.

62 changes: 56 additions & 6 deletions test/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
import unittest
from pathlib import Path
import shutil
import tempfile

import numpy as np
import itertools
import pandas as pd
from pandas.testing import assert_frame_equal

from pymdp.utils import create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices
from pymdp.utils import Dimensions, get_model_dimensions_from_labels, create_A_matrix_stub, read_A_matrix, create_B_matrix_stubs, read_B_matrices

tmp_path = Path('tmp_dir')

Expand All @@ -18,11 +14,62 @@

class TestWrappers(unittest.TestCase):

def test_get_model_dimensions_from_labels(self):
"""
Tests model dimension extraction from labels including observations, states and actions.
"""
model_labels = {
"observations": {
"species_observation": [
"absent",
"present",
],
"budget_observation": [
"high",
"medium",
"low",
],
},
"states": {
"species_state": [
"extant",
"extinct",
],
},
"actions": {
"conservation_action": [
"manage",
"survey",
"stop",
],
},
}

want = Dimensions(
num_observations=[2, 3],
num_observation_modalities=2,
num_states=[2],
num_state_factors=1,
num_controls=[3],
num_control_factors=1,
)

got = get_model_dimensions_from_labels(model_labels)

self.assertEqual(want.num_observations, got.num_observations)
self.assertEqual(want.num_observation_modalities, got.num_observation_modalities)
self.assertEqual(want.num_states, got.num_states)
self.assertEqual(want.num_state_factors, got.num_state_factors)
self.assertEqual(want.num_controls, got.num_controls)
self.assertEqual(want.num_control_factors, got.num_control_factors)

def test_A_matrix_stub(self):
"""
This tests the construction of a 2-modality, 2-hidden state factor pandas MultiIndex dataframe using
the `model_labels` dictionary, which contains the modality- and factor-specific levels, labeled with string
identifiers
identifiers.
Note: actions are ignored when creating an A matrix stub
"""

model_labels = {
Expand All @@ -41,6 +88,9 @@ def test_A_matrix_stub(self):
"weather_state": ["raining", "clear"],
"sprinkler_state": ["on", "off"],
},
"actions": {
"actions": ["something", "nothing"],
}
}

num_hidden_state_factors = len(model_labels["states"])
Expand Down

0 comments on commit 1b50710

Please sign in to comment.