From 897c670ad9ff05c00abd858d5538b3a05fd296a1 Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Fri, 12 Jul 2024 17:32:22 +0200 Subject: [PATCH] fixed testing --- test/test_learning_jax.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_learning_jax.py b/test/test_learning_jax.py index 60d6bcb3..6b943932 100644 --- a/test/test_learning_jax.py +++ b/test/test_learning_jax.py @@ -64,11 +64,13 @@ def test_update_observation_likelihood_fullyconnected(self): qA_np_test = update_pA_numpy(pA_np, A_np, obs_np, qs_np, lr=l_rate) pA_jax = jtu.tree_map(lambda x: jnp.array(x), list(pA_np)) + A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np)) obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np)) qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np)) qA_jax_test, E_qA_jax_test = update_pA_jax( pA_jax, + A_jax, obs_jax, qs_jax, A_dependencies=A_dependencies, @@ -126,11 +128,13 @@ def test_update_observation_likelihood_factorized(self): qA_np_test = update_pA_numpy_factorized(pA_np, A_np, obs_np, qs_np, A_dependencies, lr=l_rate) pA_jax = jtu.tree_map(lambda x: jnp.array(x), list(pA_np)) + A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np)) obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np)) qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np)) qA_jax_test, E_qA_jax_test = update_pA_jax( pA_jax, + A_jax, obs_jax, qs_jax, A_dependencies=A_dependencies,