Merge pull request #54 from infer-actively/E_vector_into_agent
E vector input into `Agent()` constructor
conorheins authored Oct 27, 2021
2 parents bcb7440 + 28816d0 commit c1f40e9
Showing 5 changed files with 2,220 additions and 36 deletions.
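In brief, `E` is a categorical prior over policies. This commit allows it to be passed to the `Agent()` constructor (falling back to a uniform prior when omitted) and threads it through to policy inference in `pymdp/control.py`, where it enters the policy posterior as a log-prior, i.e. `q_pi = softmax(efe * gamma + lnE)`, with each policy's variational free energy `F` additionally subtracted in the marginal-message-passing routine.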
2,161 changes: 2,141 additions & 20 deletions examples/tmaze_learning_demo.ipynb

Large diffs are not rendered by default.

31 changes: 26 additions & 5 deletions pymdp/agent.py
@@ -24,6 +24,7 @@ def __init__(
B,
C=None,
D=None,
E = None,
pA=None,
pB = None,
pD = None,
@@ -76,7 +77,7 @@ def __init__(

assert utils.is_normalized(self.A), "A matrix is not normalized (i.e. A.sum(axis = 0) must all equal 1.0"

# Determine number of modalities and their dimensionaliti
""" Determine number of observation modalities and their respective dimensions """
self.num_obs = [self.A[m].shape[0] for m in range(len(self.A))]
self.num_modalities = len(self.num_obs)

@@ -130,6 +131,7 @@ def __init__(
assert all([n_c == max_action for (n_c, max_action) in zip(self.num_controls, list(np.max(all_policies, axis =0)+1))]), "Maximum number of actions is not consistent with `num_controls`"

""" Construct prior preferences (uniform if not specified) """

if C is not None:
if not isinstance(C, np.ndarray):
raise TypeError(
@@ -144,7 +146,7 @@ def __init__(
else:
self.C = self._construct_C_prior()

# Construct initial beliefs (uniform if not specified)
""" Construct prior over hidden states (uniform if not specified) """

if D is not None:
if not isinstance(D, np.ndarray):
@@ -168,6 +170,20 @@ def __init__(
""" Assigning prior parameters on initial hidden states (pD vectors) """
self.pD = pD

""" Construct prior over policies (uniform if not specified) """

if E is not None:
if not isinstance(E, np.ndarray):
raise TypeError(
'E vector must be a numpy array'
)
self.E = E

assert len(self.E) == len(self.policies), f"Check E vector: length of E must be equal to number of policies: {len(self.policies)}"

else:
self.E = self._construct_E_prior()

self.edge_handling_params = {}
self.edge_handling_params['use_BMA'] = use_BMA # creates a 'D-like' moving prior
self.edge_handling_params['policy_sep_prior'] = policy_sep_prior # carries forward last timesteps posterior, in a policy-conditioned way
@@ -233,6 +249,10 @@ def _construct_num_controls(self):
)

return num_controls

def _construct_E_prior(self):
E = np.ones(len(self.policies)) / len(self.policies)
return E

def reset(self, init_qs=None):

@@ -440,7 +460,8 @@ def infer_policies(self):
self.use_param_info_gain,
self.pA,
self.pB,
self.gamma
E = self.E,
gamma = self.gamma
)
elif self.inference_algo == "MMP":

@@ -458,8 +479,8 @@ def infer_policies(self):
self.latest_belief,
self.pA,
self.pB,
self.F,
E = None,
F = self.F,
E = self.E,
gamma = self.gamma
)

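A minimal usage sketch of the new constructor argument (illustrative, not part of this commit). It assumes the `utils.random_A_matrix` / `utils.random_B_matrix` helpers from `pymdp.utils` and a single hidden-state factor with three actions, so that `policy_len=1` yields exactly three one-step policies:

import numpy as np
from pymdp import utils
from pymdp.agent import Agent

num_obs, num_states, num_controls = [3], [3], [3]
A = utils.random_A_matrix(num_obs, num_states)        # random, normalized likelihood arrays
B = utils.random_B_matrix(num_states, num_controls)   # random, normalized transition arrays

E = np.array([0.05, 0.05, 0.9])     # prior over the three one-step policies, favouring the third
agent = Agent(A=A, B=B, E=E, policy_len=1)

agent.infer_states([0])             # condition on an observation from the single modality
q_pi, efe = agent.infer_policies()  # the ln E term now biases q_pi toward policy 2

Omitting `E` reproduces the old behaviour, since `_construct_E_prior()` builds a uniform prior over `self.policies`.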
19 changes: 12 additions & 7 deletions pymdp/control.py
@@ -63,9 +63,13 @@ def update_posterior_policies_mmp(
efe = np.zeros(num_policies)

if F is None:
F = np.zeros(num_policies)
F = spm_log_single(np.ones(num_policies) / num_policies)

if E is None:
E = np.zeros(num_policies)
lnE = spm_log_single(np.ones(num_policies) / num_policies)
else:
lnE = spm_log_single(E)


for p_idx, policy in enumerate(policies):

@@ -83,9 +87,8 @@
if pB is not None:
efe[p_idx] += calc_pB_info_gain(pB, qs_seq_pi[p_idx], prior, policy)


q_pi = softmax(efe * gamma - F + E)

q_pi = softmax(efe * gamma - F + lnE)

return q_pi, efe


@@ -151,7 +154,9 @@ def update_posterior_policies(
q_pi = np.zeros((n_policies, 1))

if E is None:
E = np.zeros(n_policies)
lnE = spm_log_single(np.ones(n_policies) / n_policies)
else:
lnE = spm_log_single(E)

for idx, policy in enumerate(policies):
qs_pi = get_expected_states(qs, B, policy)
@@ -169,7 +174,7 @@
if pB is not None:
efe[idx] += calc_pB_info_gain(pB, qs_pi, qs, policy)

q_pi = softmax(efe * gamma + E)
q_pi = softmax(efe * gamma + lnE)

return q_pi, efe

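A small numeric sketch (not from the diff) of what the new `lnE` term does in `update_posterior_policies`: the prior enters in log form, so a non-uniform `E` simply shifts the per-policy scores before the softmax, which is the effect the new test in `test/test_control.py` below checks for:

import numpy as np
from pymdp.maths import softmax   # pymdp's softmax; any normalized exponential would do

efe = np.array([0.1, 0.0, 0.0])            # hypothetical per-policy scores (utility favours policy 0)
gamma = 16.0

E_uniform = np.ones(3) / 3
E_biased = np.array([0.05, 0.05, 0.9])     # prior favouring the third policy

print(softmax(gamma * efe + np.log(E_uniform)))  # roughly [0.71, 0.14, 0.14]: policy 0 dominates
print(softmax(gamma * efe + np.log(E_biased)))   # roughly [0.21, 0.04, 0.75]: policy 2 takes most of the mass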
4 changes: 0 additions & 4 deletions test/test_agent.py
@@ -359,10 +359,6 @@ def test_agent_with_D_learning_MMP(self):

self.assertTrue(np.allclose(pD_test[factor], pD_validation[factor]))







if __name__ == "__main__":
41 changes: 41 additions & 0 deletions test/test_control.py
@@ -1313,6 +1313,47 @@ def test_sample_action(self):

chosen_action = control.sample_action(q_pi, policies, num_controls, action_selection="deterministic")
self.assertEqual(int(chosen_action[0]), 1)

def test_update_posterior_policies_withE_vector(self):
"""
Test update posterior policies in the case that there is a prior over policies
"""

""" Construct an explicit example where policy 0 is preferred based on utility,
but action 2 also gets a bump in probability because of prior over policies
"""
num_obs = [3]
num_states = [3]
num_controls = [3]

qs = utils.to_arr_of_arr(utils.onehot(0, num_states[0]))
A = utils.to_arr_of_arr(np.eye(num_obs[0]))
B = utils.construct_controllable_B(num_states, num_controls)

C = utils.to_arr_of_arr(np.array([1.5, 1.0, 1.0]))

D = utils.to_arr_of_arr(utils.onehot(0, num_states[0]))
E = np.array([0.05, 0.05, 0.9])

policies = control.construct_policies(num_states, num_controls, policy_len=1)

q_pi, efe = control.update_posterior_policies(
qs,
A,
B,
C,
policies,
use_utility = True,
use_states_info_gain = False,
use_param_info_gain = False,
pA=None,
pB=None,
E = E,
gamma=16.0
)

self.assertGreater(q_pi[0], q_pi[1])
self.assertGreater(q_pi[2], q_pi[1])


if __name__ == "__main__":
