Skip to content

Commit

Permalink
remove vmap decorator from sample actions
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jul 5, 2024
1 parent 7d05550 commit 382ac75
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pymdp/jax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,12 @@ def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_h
def update_empirical_prior(self, action, qs):
# return the empirical prior and the history of posterior beliefs (filtering distributions) held about hidden states at times 1, 2, ..., t

qs_last = jtu.tree_map( lambda x: x[:, -1], qs)
# this computation of the predictive prior is correct only for fully factorised Bs.
if self.inference_algo in ['mmp', 'vmp']:
# for the 'mmp' and 'vmp' algorithms we have to use D as the prior parameter for infer_states
pred = self.D
else:
qs_last = jtu.tree_map( lambda x: x[:, -1], qs)
propagate_beliefs = partial(control.compute_expected_state, B_dependencies=self.B_dependencies)
pred = vmap(propagate_beliefs)(qs_last, self.B, action)

Expand Down Expand Up @@ -420,7 +420,6 @@ def infer_policies(self, qs: List):

return q_pi, G

@vmap
def multiaction_probabilities(self, q_pi: Array):
"""
Compute probabilities of unique multi-actions from the posterior over policies.
Expand All @@ -437,7 +436,8 @@ def multiaction_probabilities(self, q_pi: Array):
"""

if self.sampling_mode == "marginal":
marginals = control.get_marginals(q_pi, self.policies, self.num_controls)
get_marginals = partial(control.get_marginals, policies=self.policies, num_controls=self.num_controls)
marginals = get_marginals(q_pi)
outer = lambda a, b: jnp.outer(a, b).reshape(-1)
marginals = jtu.tree_reduce(outer, marginals)

Expand All @@ -446,7 +446,8 @@ def multiaction_probabilities(self, q_pi: Array):
self.policies[:, 0] == jnp.expand_dims(self.unique_multiactions, -2),
-1
)
marginals = jnp.where(locs, q_pi, 0.).sum(-1)
get_marginals = lambda x: jnp.where(locs, x, 0.).sum(-1)
marginals = vmap(get_marginals)(q_pi)

return marginals

Expand Down

0 comments on commit 382ac75

Please sign in to comment.