Skip to content

Commit

Permalink
allow for learning of single observation modality depending on the co…
Browse files Browse the repository at this point in the history
…ntent of pA list
  • Loading branch information
dimarkov committed Jul 12, 2024
1 parent 382ac75 commit 1958e4d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
7 changes: 4 additions & 3 deletions pymdp/jax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def __init__(
policy_len=1,
control_fac_idx=None,
policies=None,
gamma=16.0,
alpha=16.0,
gamma=1.0,
alpha=1.0,
inductive_depth=1,
inductive_threshold=0.1,
inductive_epsilon=1e-3,
Expand Down Expand Up @@ -156,7 +156,7 @@ def __init__(
factor_dims = tuple([self.num_states[f] for f in self.A_dependencies[m]])
assert self.A[m].shape[2:] == factor_dims, f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of A[{m}]..."
if self.pA != None:
assert self.pA[m].shape[2:] == factor_dims, f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of pA[{m}]..."
assert self.pA[m].shape[2:] == factor_dims if self.pA[m] is not None else True, f"Please input an `A_dependencies` whose {m}-th indices correspond to the hidden state factors that line up with lagging dimensions of pA[{m}]..."
assert max(self.A_dependencies[m]) <= (self.num_factors - 1), f"Check modality {m} of `A_dependencies` - must be consistent with `num_states` and `num_factors`..."

# Ensure consistency of B_dependencies with num_states and num_factors
Expand Down Expand Up @@ -280,6 +280,7 @@ def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1
lr = jnp.broadcast_to(lr_pA, (self.batch_size,))
qA, E_qA = vmap(update_A)(
self.pA,
self.A,
outcomes,
marginal_beliefs,
lr=lr,
Expand Down
15 changes: 10 additions & 5 deletions pymdp/jax/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=1.0):
dfda = vmap(multidimensional_outer)([obs_m] + relevant_factors).sum(axis=0)

new_pA_m = pA_m + lr * dfda
A_m = dirichlet_expected_value(new_pA_m)

return new_pA_m, dirichlet_expected_value(new_pA_m)
return new_pA_m, A_m

def update_obs_likelihood_dirichlet(pA, obs, qs, *, A_dependencies, onehot_obs, num_obs, lr):
def update_obs_likelihood_dirichlet(pA, A, obs, qs, *, A_dependencies, onehot_obs, num_obs, lr):
""" JAX version of ``pymdp.learning.update_obs_likelihood_dirichlet`` """

obs_m = lambda o, dim: nn.one_hot(o, dim) if not onehot_obs else o
Expand All @@ -39,9 +40,13 @@ def update_obs_likelihood_dirichlet(pA, obs, qs, *, A_dependencies, onehot_obs,
result = tree_map(update_A_fn, pA, obs, num_obs, A_dependencies)
qA = []
E_qA = []
for r in result:
qA.append(r[0])
E_qA.append(r[1])
for i, r in enumerate(result):
if r is None:
qA.append(r)
E_qA.append(A[i])
else:
qA.append(r[0])
E_qA.append(r[1])

return qA, E_qA

Expand Down

0 comments on commit 1958e4d

Please sign in to comment.