Skip to content

Commit

Permalink
fix infer_params and remove learning method
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jun 21, 2024
1 parent 10a204f commit 0575fca
Showing 1 changed file with 2 additions and 45 deletions.
47 changes: 2 additions & 45 deletions pymdp/jax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,51 +254,6 @@ def unique_multiactions(self):
size = pymath.prod(self.num_controls)
return jnp.unique(self.policies[:, 0], axis=0, size=size, fill_value=-1)

@vmap
def learning(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs):
    """Update the Dirichlet parameters of the A and B models and return the resulting agent.

    Because of the ``@vmap`` decorator, the body is mapped over the leading
    axis of ``self`` and of every argument passed in.

    Args:
        beliefs_A: beliefs forwarded to
            ``learning.update_obs_likelihood_dirichlet`` for the ``pA`` update.
        outcomes: pytree of observation indices; each leaf is one-hot encoded
            against the matching entry of ``self.num_obs``.
        actions: array whose last axis indexes the control factors.
        beliefs_B: beliefs forwarded to
            ``learning.update_state_likelihood_dirichlet`` for the ``pB``
            update. Falls back to ``beliefs_A`` when ``None``.
        lr_pA: value passed as ``lr`` to the ``pA`` update.
        lr_pB: value passed as ``lr`` to the ``pB`` update.
        **kwargs: accepted but not used by this method.

    Returns:
        The result of ``tree_at`` with ``(A, pA)`` replaced when
        ``self.learn_A`` is set and ``(B, pB, I)`` replaced when
        ``self.learn_B`` is set; ``self`` unchanged when neither is set.
        NOTE(review): presumably ``tree_at`` returns a modified copy and
        leaves ``self`` untouched -- confirm against its documentation.
    """
    updated_agent = self

    if self.learn_A:
        obs_one_hot = jtu.tree_map(nn.one_hot, outcomes, self.num_obs)
        posterior_pA = learning.update_obs_likelihood_dirichlet(
            self.pA, obs_one_hot, beliefs_A, self.A_dependencies, lr=lr_pA
        )
        expected_A = jtu.tree_map(maths.dirichlet_expected_value, posterior_pA)
        updated_agent = tree_at(
            lambda ag: (ag.A, ag.pA), updated_agent, (expected_A, posterior_pA)
        )

    if self.learn_B:
        if beliefs_B is None:
            beliefs_B = beliefs_A

        # One array per control factor, sliced from the last axis of `actions`
        # (per the original note, each slice has shape (n_timesteps,)).
        num_factors = actions.shape[-1]
        per_factor_actions = [actions[..., f] for f in range(num_factors)]

        # The beliefs must hold exactly one more entry along their leading
        # axis than the action sequence does.
        assert beliefs_B[0].shape[0] == per_factor_actions[0].shape[0] + 1

        actions_one_hot = jtu.tree_map(
            lambda a, dim: nn.one_hot(a, dim, axis=-1),
            per_factor_actions,
            self.num_controls,
        )
        posterior_pB = learning.update_state_likelihood_dirichlet(
            self.pB, beliefs_B, actions_one_hot, self.B_dependencies, lr=lr_pB
        )
        expected_B = jtu.tree_map(maths.dirichlet_expected_value, posterior_pB)

        # Updated transition beliefs invalidate the I matrix used for
        # inductive inference, so rebuild it whenever induction is active.
        new_I = (
            control.generate_I_matrix(
                self.H, expected_B, self.inductive_threshold, self.inductive_depth
            )
            if self.use_inductive and self.H is not None
            else self.I
        )

        updated_agent = tree_at(
            lambda ag: (ag.B, ag.pB, ag.I),
            updated_agent,
            (expected_B, posterior_pB, new_I),
        )

    # TODO: learning of C, D and E is not implemented here.

    return updated_agent


@vmap
def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs):
agent = self
Expand Down Expand Up @@ -344,6 +299,8 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1
I_updated = self.I

agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated))

return agent

@vmap
def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None):
Expand Down

0 comments on commit 0575fca

Please sign in to comment.