From 264ee7de2fb46a2a46634a343ca08d98101a8900 Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Fri, 21 Jun 2024 10:49:44 +0200 Subject: [PATCH] fix passing of predictive distribution inside update_empirical_prior method for mmp and vmp algos --- pymdp/jax/agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index 0cb7b694..79aac96d 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -357,7 +357,11 @@ def update_empirical_prior(self, action, qs): qs_last = jtu.tree_map( lambda x: x[-1], qs) # this computation of the predictive prior is correct only for fully factorised Bs. - pred = control.compute_expected_state(qs_last, self.B, action, B_dependencies=self.B_dependencies) + if self.inference_algo in ['mmp', 'vmp']: + # in the case of the 'mmp' or 'vmp' we have to use D as prior parameter for infer states + pred = self.D + else: + pred = control.compute_expected_state(qs_last, self.B, action, B_dependencies=self.B_dependencies) return (pred, qs)