From 0575fca72e8db40c22d0ccb23a0a26e57b59f38f Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Fri, 21 Jun 2024 09:57:47 +0200 Subject: [PATCH] fix infer_params and remove learning method --- pymdp/jax/agent.py | 47 ++-------------------------------------------- 1 file changed, 2 insertions(+), 45 deletions(-) diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index e3d22e8a..0cb7b694 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -254,51 +254,6 @@ def unique_multiactions(self): size = pymath.prod(self.num_controls) return jnp.unique(self.policies[:, 0], axis=0, size=size, fill_value=-1) - @vmap - def learning(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs): - agent = self - if self.learn_A: - o_vec_seq = jtu.tree_map(lambda o, dim: nn.one_hot(o, dim), outcomes, self.num_obs) - qA = learning.update_obs_likelihood_dirichlet(self.pA, o_vec_seq, beliefs_A, self.A_dependencies, lr=lr_pA) - E_qA = jtu.tree_map(lambda x: maths.dirichlet_expected_value(x), qA) - agent = tree_at(lambda x: (x.A, x.pA), agent, (E_qA, qA)) - - if self.learn_B: - beliefs_B = beliefs_A if beliefs_B is None else beliefs_B - actions_seq = [actions[..., i] for i in range(actions.shape[-1])] # as many elements as there are control factors, where each element is a jnp.ndarray of shape (n_timesteps, ) - assert beliefs_B[0].shape[0] == actions_seq[0].shape[0] + 1 - actions_onehot = jtu.tree_map(lambda a, dim: nn.one_hot(a, dim, axis=-1), actions_seq, self.num_controls) - qB = learning.update_state_likelihood_dirichlet(self.pB, beliefs_B, actions_onehot, self.B_dependencies, lr=lr_pB) - E_qB = jtu.tree_map(lambda x: maths.dirichlet_expected_value(x), qB) - - # if you have updated your beliefs about transitions, you need to re-compute the I matrix used for inductive inferenece - if self.use_inductive and self.H is not None: - I_updated = control.generate_I_matrix(self.H, E_qB, self.inductive_threshold, self.inductive_depth) - else: - I_updated = self.I - - agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated)) - - # if self.learn_C: - # self.qC = learning.update_C(self.C, *args, **kwargs) - # self.C = jtu.tree_map(lambda x: maths.dirichlet_expected_value(x), self.qC) - # if self.learn_D: - # self.qD = learning.update_D(self.D, *args, **kwargs) - # self.D = jtu.tree_map(lambda x: maths.dirichlet_expected_value(x), self.qD) - # if self.learn_E: - # self.qE = learning.update_E(self.E, *args, **kwargs) - # self.E = maths.dirichlet_expected_value(self.qE) - - # do stuff - # variables = ... - # parameters = ... - # varibles = {'A': jnp.ones(5)} - - # agent = tree_at(lambda x: (x.A, x.pA, x.B, x.pB, x.I), self, (E_qA, qA, E_qB, qB, I_updated)) - - return agent - - @vmap def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs): agent = self @@ -344,6 +299,8 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1 I_updated = self.I agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated)) + + return agent @vmap def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None):