Skip to content

Commit

Permalink
changed infer_state api
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jun 17, 2024
1 parent 475dc51 commit 8ee1426
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pymdp/jax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1
agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated))

@vmap
def infer_states(self, observations, past_actions, empirical_prior, qs_hist, mask=None):
def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None):
"""
Update approximate posterior over hidden states by solving variational inference problem, given an observation.
Expand Down

0 comments on commit 8ee1426

Please sign in to comment.