diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index 8f814203..e3d22e8a 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -346,7 +346,7 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1 agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated)) @vmap - def infer_states(self, observations, past_actions, empirical_prior, qs_hist, mask=None): + def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None): """ Update approximate posterior over hidden states by solving variational inference problem, given an observation.