From 1958e4dae455bb93d4bd994185cc7c4049b1d35a Mon Sep 17 00:00:00 2001
From: dimarkov <5038100+dimarkov@users.noreply.github.com>
Date: Fri, 12 Jul 2024 13:59:33 +0200
Subject: [PATCH] allow learning of a single observation modality depending on
 the content of the pA list

---
 pymdp/jax/agent.py    |  7 ++++---
 pymdp/jax/learning.py | 15 ++++++++++-----
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py
index 86ab409b..9ab23a05 100644
--- a/pymdp/jax/agent.py
+++ b/pymdp/jax/agent.py
@@ -104,8 +104,8 @@ def __init__(
         policy_len=1,
         control_fac_idx=None,
         policies=None,
-        gamma=16.0,
-        alpha=16.0,
+        gamma=1.0,
+        alpha=1.0,
         inductive_depth=1,
         inductive_threshold=0.1,
         inductive_epsilon=1e-3,
@@ -156,7 +156,7 @@ def __init__(
             factor_dims = tuple([self.num_states[f] for f in self.A_dependencies[m]])
             assert self.A[m].shape[2:] == factor_dims, f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of A[{m}]..."
             if self.pA != None:
-                assert self.pA[m].shape[2:] == factor_dims, f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of pA[{m}]..."
+                assert self.pA[m].shape[2:] == factor_dims if self.pA[m] is not None else True, f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of pA[{m}]..."
             assert max(self.A_dependencies[m]) <= (self.num_factors - 1), f"Check modality {m} of `A_dependencies` - must be consistent with `num_states` and `num_factors`..."

         # Ensure consistency of B_dependencies with num_states and num_factors
@@ -280,6 +280,7 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1
             lr = jnp.broadcast_to(lr_pA, (self.batch_size,))
             qA, E_qA = vmap(update_A)(
                 self.pA,
+                self.A,
                 outcomes,
                 marginal_beliefs,
                 lr=lr,
diff --git a/pymdp/jax/learning.py b/pymdp/jax/learning.py
index 32ada016..499b94a3 100644
--- a/pymdp/jax/learning.py
+++ b/pymdp/jax/learning.py
@@ -26,10 +26,11 @@ def update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=1.0):
     dfda = vmap(multidimensional_outer)([obs_m] + relevant_factors).sum(axis=0)

     new_pA_m = pA_m + lr * dfda
+    A_m = dirichlet_expected_value(new_pA_m)

-    return new_pA_m, dirichlet_expected_value(new_pA_m)
+    return new_pA_m, A_m

-def update_obs_likelihood_dirichlet(pA, obs, qs, *, A_dependencies, onehot_obs, num_obs, lr):
+def update_obs_likelihood_dirichlet(pA, A, obs, qs, *, A_dependencies, onehot_obs, num_obs, lr):
     """ JAX version of ``pymdp.learning.update_obs_likelihood_dirichlet`` """

     obs_m = lambda o, dim: nn.one_hot(o, dim) if not onehot_obs else o
@@ -39,9 +40,13 @@ def update_obs_likelihood_dirichlet(pA, obs, qs, *, A_dependencies, onehot_obs,
     result = tree_map(update_A_fn, pA, obs, num_obs, A_dependencies)
     qA = []
     E_qA = []
-    for r in result:
-        qA.append(r[0])
-        E_qA.append(r[1])
+    for i, r in enumerate(result):
+        if r is None:
+            qA.append(r)
+            E_qA.append(A[i])
+        else:
+            qA.append(r[0])
+            E_qA.append(r[1])

     return qA, E_qA
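
Usage note: with this change, a None entry in the pA list marks a modality whose likelihood is not learned; qA[m] then stays None and E_qA[m] falls back to the fixed likelihood A[m]. The snippet below is a minimal standalone sketch of that fallback loop, not a call into pymdp; the arrays, shapes, and values are invented purely for illustration.

# Illustrative only: mirrors the new per-modality fallback in
# update_obs_likelihood_dirichlet; shapes and values are made up.
import jax.numpy as jnp

# Fixed likelihoods for two modalities (num_obs x num_states).
A = [jnp.eye(2), jnp.ones((3, 2)) / 3.0]

# Pretend modality 0 was learned (updated Dirichlet parameters and their
# expected value), while modality 1 had pA[1] = None and produced no result.
result = [(jnp.full((2, 2), 1.5), jnp.full((2, 2), 0.5)), None]

qA, E_qA = [], []
for i, r in enumerate(result):
    if r is None:
        qA.append(r)         # keep the None placeholder for this modality
        E_qA.append(A[i])    # fall back to the fixed likelihood
    else:
        qA.append(r[0])      # updated Dirichlet parameters
        E_qA.append(r[1])    # their expected value

print([q is None for q in qA])  # [False, True]

Keeping the None placeholder in qA preserves per-modality indexing, so qA and E_qA stay aligned with A even when only a subset of modalities is learned.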