diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 744e9d54..f94cc558 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -27,11 +27,22 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest + python -m pip install coverage if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest test + - name: Run tests with pytest, calculating coverage + run: coverage run --source=pymdp -m pytest test/ + - name: Generate coverage HTML report + run: coverage html + # expect actions/upload-artifact@v4 to fail when run locally with `act` + - name: Upload coverage HTML report for pymdp as a build artifact + uses: actions/upload-artifact@v4 + with: + name: pymdp-${{ matrix.python-version }}--coverage-report + path: htmlcov/ + retention-days: 30 + - name: Print coverage report to console + run: coverage report \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5f24acf9..61051a2a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__ .DS_Store .ipynb_checkpoints +.idea .rope* .vscode/ .ipynb_checkpoints/ @@ -9,3 +10,5 @@ __pycache__ env/ pymdp.egg-info inferactively_pymdp.egg-info +htmlcov +.coverage diff --git a/pymdp/jax/task.py b/pymdp/jax/task.py index 5de0315e..0f349d94 100644 --- a/pymdp/jax/task.py +++ b/pymdp/jax/task.py @@ -9,9 +9,9 @@ def select_probs(positions, matrix, dependency_list, actions=None): args = tuple(p for i, p in enumerate(positions) if i in dependency_list) - args += () if actions is None else (actions,) + args = args + (actions,) if actions is not None else args - return matrix[..., *args] + return matrix[(...,) + args] def cat_sample(key, p): a = jnp.arange(p.shape[-1]) diff --git a/test/test_SPM_validation.py b/test/test_SPM_validation.py index ee386378..74366fe4 100644 --- a/test/test_SPM_validation.py +++ b/test/test_SPM_validation.py @@ -1,9 +1,14 @@ import os +import sys import unittest import numpy as np from scipy.io import loadmat +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.agent import Agent from pymdp.utils import to_obj_array, build_xn_vn_array, get_model_dimensions, convert_observation_array from pymdp.maths import dirichlet_log_evidence diff --git a/test/test_agent.py b/test/test_agent.py index 161bca56..1696398d 100644 --- a/test/test_agent.py +++ b/test/test_agent.py @@ -8,10 +8,14 @@ """ import os +import sys import unittest import numpy as np -from copy import deepcopy + +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) from pymdp.agent import Agent from pymdp import utils, maths diff --git a/test/test_agent_jax.py b/test/test_agent_jax.py index ad3d85d8..df01974f 100644 --- a/test/test_agent_jax.py +++ b/test/test_agent_jax.py @@ -6,17 +6,19 @@ """ import os +import sys import unittest -import numpy as np import jax.numpy as jnp from jax import vmap, nn, random -import jax.tree_util as jtu + +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) from pymdp.jax.maths import compute_log_likelihood_single_modality from pymdp.jax.utils import norm_dist from equinox import Module -from typing import Any, List class TestAgentJax(unittest.TestCase): diff --git a/test/test_control.py b/test/test_control.py index 14b09938..a2a05104 100644 --- a/test/test_control.py +++ b/test/test_control.py @@ -6,10 +6,15 @@ """ import os +import sys import unittest import numpy as np +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp import utils, maths from pymdp import control diff --git a/test/test_control_jax.py b/test/test_control_jax.py index 75de6912..bb70e484 100644 --- a/test/test_control_jax.py +++ b/test/test_control_jax.py @@ -6,18 +6,21 @@ """ import os +import sys import unittest -import pytest import numpy as np import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + import pymdp.jax.control as ctl_jax import pymdp.control as ctl_np -from pymdp.jax.maths import factor_dot from pymdp import utils cfg = {"source_key": 0, "num_models": 4} diff --git a/test/test_demos.py b/test/test_demos.py index d29d3eb4..41feb812 100644 --- a/test/test_demos.py +++ b/test/test_demos.py @@ -1,11 +1,16 @@ import unittest import numpy as np +import os +import sys import copy import seaborn as sns import matplotlib.pyplot as plt +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.agent import Agent -from pymdp.utils import plot_beliefs, plot_likelihood from pymdp import utils, maths, default_models from pymdp import control from pymdp.envs import TMazeEnv, TMazeEnvNullOutcome diff --git a/test/test_fpi.py b/test/test_fpi.py index d60f944e..3967f589 100644 --- a/test/test_fpi.py +++ b/test/test_fpi.py @@ -6,10 +6,15 @@ """ import os +import sys import unittest import numpy as np +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp import utils, maths from pymdp.algos import run_vanilla_fpi, run_vanilla_fpi_factorized diff --git a/test/test_inference.py b/test/test_inference.py index 6528ab6d..2df0c1f3 100644 --- a/test/test_inference.py +++ b/test/test_inference.py @@ -6,10 +6,15 @@ """ import os +import sys import unittest import numpy as np +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp import utils, maths from pymdp import inference diff --git a/test/test_inference_jax.py b/test/test_inference_jax.py index e426c870..123ac21a 100644 --- a/test/test_inference_jax.py +++ b/test/test_inference_jax.py @@ -6,14 +6,19 @@ """ import os +import sys import unittest import numpy as np import jax.numpy as jnp +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.jax.algos import run_vanilla_fpi as fpi_jax from pymdp.algos import run_vanilla_fpi as fpi_numpy -from pymdp import utils, maths +from pymdp import utils class TestInferenceJax(unittest.TestCase): diff --git a/test/test_learning.py b/test/test_learning.py index c839704c..fa9e2483 100644 --- a/test/test_learning.py +++ b/test/test_learning.py @@ -1,6 +1,12 @@ +import os, sys import unittest import numpy as np + +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp import utils, maths, learning from copy import deepcopy diff --git a/test/test_learning_jax.py b/test/test_learning_jax.py index cdb3b86c..f134d142 100644 --- a/test/test_learning_jax.py +++ b/test/test_learning_jax.py @@ -6,12 +6,17 @@ """ import os +import sys import unittest import numpy as np import jax.numpy as jnp import jax.tree_util as jtu +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.learning import update_obs_likelihood_dirichlet as update_pA_numpy from pymdp.learning import update_obs_likelihood_dirichlet_factorized as update_pA_numpy_factorized from pymdp.jax.learning import update_obs_likelihood_dirichlet as update_pA_jax diff --git a/test/test_message_passing_jax.py b/test/test_message_passing_jax.py index b27be336..3f893cc3 100644 --- a/test/test_message_passing_jax.py +++ b/test/test_message_passing_jax.py @@ -6,6 +6,7 @@ """ import os +import sys import unittest from functools import partial @@ -15,16 +16,17 @@ from jax import vmap, nn from jax import random as jr +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.jax.algos import run_vanilla_fpi as fpi_jax from pymdp.jax.algos import run_factorized_fpi as fpi_jax_factorized -from pymdp.jax.algos import update_variational_filtering as ovf_jax from pymdp.algos import run_vanilla_fpi as fpi_numpy -from pymdp.algos import run_mmp as mmp_numpy from pymdp.jax.algos import run_mmp as mmp_jax -from pymdp.jax.algos import run_vmp as vmp_jax -from pymdp import utils, maths +from pymdp import utils -from typing import Any, List, Dict +from typing import List, Dict def make_model_configs(source_seed=0, num_models=4) -> Dict: diff --git a/test/test_mmp.py b/test/test_mmp.py index 61ad575c..0619ab67 100644 --- a/test/test_mmp.py +++ b/test/test_mmp.py @@ -8,11 +8,16 @@ """ import os +import sys import unittest import numpy as np from scipy.io import loadmat +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.utils import get_model_dimensions, convert_observation_array from pymdp.algos import run_mmp from pymdp.maths import get_joint_likelihood_seq diff --git a/test/test_utils.py b/test/test_utils.py index 033dd8f6..c916cd47 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -6,11 +6,15 @@ __author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein """ - +import os, sys import unittest import numpy as np +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp import utils class TestUtils(unittest.TestCase): diff --git a/test/test_wrappers.py b/test/test_wrappers.py index cf405e56..fa59d2b5 100644 --- a/test/test_wrappers.py +++ b/test/test_wrappers.py @@ -1,6 +1,11 @@ import os +import sys import unittest -from pathlib import Path + +# import the library directly from local source (rather than relying on the library being installed) +# insert the dependency so it's prioritized over an installed variant +sys.path.insert(0, os.path.abspath('../pymdp')) + from pymdp.utils import Dimensions, get_model_dimensions_from_labels class TestWrappers(unittest.TestCase):