Skip to content

Commit

Permalink
fixed testing
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jul 12, 2024
1 parent 6d5b7f9 commit 897c670
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/test_learning_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def test_update_observation_likelihood_fullyconnected(self):
qA_np_test = update_pA_numpy(pA_np, A_np, obs_np, qs_np, lr=l_rate)

pA_jax = jtu.tree_map(lambda x: jnp.array(x), list(pA_np))
A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np))
obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np))
qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np))

qA_jax_test, E_qA_jax_test = update_pA_jax(
pA_jax,
A_jax,
obs_jax,
qs_jax,
A_dependencies=A_dependencies,
Expand Down Expand Up @@ -126,11 +128,13 @@ def test_update_observation_likelihood_factorized(self):
qA_np_test = update_pA_numpy_factorized(pA_np, A_np, obs_np, qs_np, A_dependencies, lr=l_rate)

pA_jax = jtu.tree_map(lambda x: jnp.array(x), list(pA_np))
A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np))
obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np))
qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np))

qA_jax_test, E_qA_jax_test = update_pA_jax(
pA_jax,
A_jax,
obs_jax,
qs_jax,
A_dependencies=A_dependencies,
Expand Down

0 comments on commit 897c670

Please sign in to comment.