From 8ee142642ee82aef707bb1192cd69edb44b9d82d Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Mon, 17 Jun 2024 15:22:40 +0200 Subject: [PATCH] changed infer_state api --- pymdp/jax/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index 8f814203..e3d22e8a 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -346,7 +346,7 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1 agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated)) @vmap - def infer_states(self, observations, past_actions, empirical_prior, qs_hist, mask=None): + def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None): """ Update approximate posterior over hidden states by solving variational inference problem, given an observation.