diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py
index 28554963..86ab409b 100644
--- a/pymdp/jax/agent.py
+++ b/pymdp/jax/agent.py
@@ -367,12 +367,12 @@ def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_h
 
     def update_empirical_prior(self, action, qs):
         # return empirical_prior, and the history of posterior beliefs (filtering distributions) held about hidden states at times 1, 2 ... t
-        qs_last = jtu.tree_map( lambda x: x[:, -1], qs)
         # this computation of the predictive prior is correct only for fully factorised Bs.
         if self.inference_algo in ['mmp', 'vmp']:
             # in the case of the 'mmp' or 'vmp' we have to use D as prior parameter for infer states
             pred = self.D
         else:
+            qs_last = jtu.tree_map( lambda x: x[:, -1], qs)
             propagate_beliefs = partial(control.compute_expected_state, B_dependencies=self.B_dependencies)
             pred = vmap(propagate_beliefs)(qs_last, self.B, action)
 
@@ -420,7 +420,6 @@ def infer_policies(self, qs: List):
 
         return q_pi, G
 
-    @vmap
     def multiaction_probabilities(self, q_pi: Array):
         """
         Compute probabilities of unique multi-actions from the posterior over policies.
@@ -437,7 +436,8 @@ def multiaction_probabilities(self, q_pi: Array):
         """
 
         if self.sampling_mode == "marginal":
-            marginals = control.get_marginals(q_pi, self.policies, self.num_controls)
+            get_marginals = partial(control.get_marginals, policies=self.policies, num_controls=self.num_controls)
+            marginals = get_marginals(q_pi)
             outer = lambda a, b: jnp.outer(a, b).reshape(-1)
             marginals = jtu.tree_reduce(outer, marginals)
 
@@ -446,7 +446,8 @@ def multiaction_probabilities(self, q_pi: Array):
                 self.policies[:, 0] == jnp.expand_dims(self.unique_multiactions, -2),
                 -1
             )
-            marginals = jnp.where(locs, q_pi, 0.).sum(-1)
+            get_marginals = lambda x: jnp.where(locs, x, 0.).sum(-1)
+            marginals = vmap(get_marginals)(q_pi)
 
         return marginals
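
For context, the last hunk moves batching from an `@vmap` method decorator to an explicit `vmap` over the posterior `q_pi`, closing over `locs` instead of batching it. Below is a minimal, self-contained sketch of that batching pattern; the shapes and values (2 batched posteriors, 4 policies, 3 unique multi-actions) are made up for illustration and are not taken from the PR:

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical example: 3 unique multi-actions, 4 policies, and a batch of 2
# posteriors over policies. locs[i, j] is True when policy j's first action
# equals unique multi-action i (mirroring the `locs` built in the diff).
locs = jnp.array([[True,  False, False, True ],
                  [False, True,  False, False],
                  [False, False, True,  False]])   # (3 multi-actions, 4 policies)
q_pi = jnp.array([[0.1,  0.2,  0.3,  0.4 ],
                  [0.25, 0.25, 0.25, 0.25]])       # (2 batch, 4 policies)

# Per-example reduction: zero out policies that don't share the multi-action,
# then sum over the policy axis. vmap maps this over the leading batch axis of
# q_pi, while the closed-over `locs` is reused for every batch element.
get_marginals = lambda x: jnp.where(locs, x, 0.).sum(-1)
marginals = vmap(get_marginals)(q_pi)              # (2 batch, 3 multi-actions)

print(marginals)  # [[0.5 0.2 0.3], [0.5 0.25 0.25]]
```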