From 865ce9ef866d4527f044521baed566a966a46294 Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Mon, 1 Jul 2024 09:45:50 +0200 Subject: [PATCH] fixes for computations with none values for pB and pA --- pymdp/jax/agent.py | 1 - pymdp/jax/control.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index 79aac96d..0d42510a 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -432,7 +432,6 @@ def multiaction_probabilities(self, q_pi: Array): ) marginals = jnp.where(locs, q_pi, 0.).sum(-1) - # assert jnp.isclose(jnp.sum(marginals), 1.) # this fails inside scan return marginals @vmap diff --git a/pymdp/jax/control.py b/pymdp/jax/control.py index b02cafe0..177a5a41 100644 --- a/pymdp/jax/control.py +++ b/pymdp/jax/control.py @@ -342,8 +342,11 @@ def scan_body(carry, t): inductive_value = calc_inductive_value_t(qs_init, qs_next, I, epsilon=inductive_epsilon) if use_inductive else 0. - param_info_gain = calc_pA_info_gain(pA, qo, qs_next, A_dependencies) if use_param_info_gain else 0. - param_info_gain += calc_pB_info_gain(pB, qs_next, qs, B_dependencies, policy_i[t]) if use_param_info_gain else 0. + param_info_gain = 0. + if pA is not None: + param_info_gain += calc_pA_info_gain(pA, qo, qs_next, A_dependencies) if use_param_info_gain else 0. + if pB is not None: + param_info_gain += calc_pB_info_gain(pB, qs_next, qs, B_dependencies, policy_i[t]) if use_param_info_gain else 0. neg_G += info_gain + utility - param_info_gain + inductive_value