
Commit

fix merge conflicts
dimarkov committed Jun 6, 2024
2 parents 09cffaf + 1b50710 commit a7d5a29
Showing 4 changed files with 260 additions and 37 deletions.
88 changes: 67 additions & 21 deletions pymdp/agent.py
@@ -60,11 +60,18 @@ def __init__(
factors_to_learn="all",
lr_pB=1.0,
lr_pD=1.0,
use_BMA = True,
use_BMA=True,
policy_sep_prior=False,
save_belief_hist=False,
A_factor_list=None,
B_factor_list=None
B_factor_list=None,
sophisticated=False,
si_horizon=3,
si_policy_prune_threshold=1/16,
si_state_prune_threshold=1/16,
si_prune_penalty=512,
ii_depth=10,
ii_threshold=1/16,
):

### Constant parameters ###
@@ -86,6 +93,15 @@ def __init__(
self.lr_pB = lr_pB
self.lr_pD = lr_pD

# sophisticated inference parameters
self.sophisticated = sophisticated
if self.sophisticated:
assert self.policy_len == 1, "Sophisticated inference only works with policy_len = 1"
self.si_horizon = si_horizon
self.si_policy_prune_threshold = si_policy_prune_threshold
self.si_state_prune_threshold = si_state_prune_threshold
self.si_prune_penalty = si_prune_penalty

# Initialise observation model (A matrices)
if not isinstance(A, np.ndarray):
raise TypeError(
@@ -129,6 +145,7 @@ def __init__(
self.num_controls = num_controls

# checking that `A_factor_list` and `B_factor_list` are consistent with `num_factors`, `num_states`, and lagging dimensions of `A` and `B` tensors
self.factorized = False
if A_factor_list is None:
self.A_factor_list = self.num_modalities * [list(range(self.num_factors))] # defaults to having all modalities depend on all factors
for m in range(self.num_modalities):
@@ -137,6 +154,7 @@ def __init__(
if self.pA is not None:
assert self.pA[m].shape[1:] == factor_dims, f"Please input an `A_factor_list` whose {m}-th indices pick out the hidden state factors that line up with lagging dimensions of pA{m}..."
else:
self.factorized = True
for m in range(self.num_modalities):
assert max(A_factor_list[m]) <= (self.num_factors - 1), f"Check modality {m} of A_factor_list - must be consistent with `num_states` and `num_factors`..."
factor_dims = tuple([self.num_states[f] for f in A_factor_list[m]])
@@ -164,6 +182,7 @@ def __init__(
if self.pB is not None:
assert self.pB[f].shape[1:-1] == factor_dims, f"Please input a `B_factor_list` whose {f}-th indices pick out the hidden state factors that line up with the all-but-final lagging dimensions of pB{f}..."
else:
self.factorized = True
for f in range(self.num_factors):
assert max(B_factor_list[f]) <= (self.num_factors - 1), f"Check factor {f} of B_factor_list - must be consistent with `num_states` and `num_factors`..."
factor_dims = tuple([self.num_states[f] for f in B_factor_list[f]])
@@ -186,7 +205,7 @@ def __init__(

# Again, the user can specify a set of possible policies, or
# all possible combinations of actions and timesteps will be considered
if policies == None:
if policies is None:
policies = self._construct_policies()
self.policies = policies

@@ -251,8 +270,10 @@ def __init__(

# Construct I for backwards induction (if H specified)
if H is not None:
self.I = control.backwards_induction(H, B, B_factor_list, threshold=1/16, depth=5)
self.H = H
self.I = control.backwards_induction(H, B, B_factor_list, threshold=ii_threshold, depth=ii_depth)
else:
self.H = None
self.I = None

self.edge_handling_params = {}
@@ -616,6 +637,12 @@ def infer_policies_old(self):
gamma=self.gamma
)
elif self.inference_algo == "MMP":
if self.factorized:
raise NotImplementedError("Factorized inference not implemented for MMP")

if self.sophisticated:
raise NotImplementedError("Sophisticated inference not implemented for MMP")


future_qs_seq = self.get_future_qs()

@@ -664,23 +691,42 @@ def infer_policies(self):
"""

if self.inference_algo == "VANILLA":
q_pi, G = control.update_posterior_policies_factorized(
self.qs,
self.A,
self.B,
self.C,
self.A_factor_list,
self.B_factor_list,
self.policies,
self.use_utility,
self.use_states_info_gain,
self.use_param_info_gain,
self.pA,
self.pB,
E=self.E,
I=self.I,
gamma=self.gamma
)
if self.sophisticated:
q_pi, G = control.sophisticated_inference_search(
self.qs,
self.policies,
self.A,
self.B,
self.C,
self.A_factor_list,
self.B_factor_list,
self.I,
self.si_horizon,
self.si_policy_prune_threshold,
self.si_state_prune_threshold,
self.si_prune_penalty,
1.0,
self.inference_params,
n=0
)
else:
q_pi, G = control.update_posterior_policies_factorized(
self.qs,
self.A,
self.B,
self.C,
self.A_factor_list,
self.B_factor_list,
self.policies,
self.use_utility,
self.use_states_info_gain,
self.use_param_info_gain,
self.pA,
self.pB,
E=self.E,
I=self.I,
gamma=self.gamma
)
elif self.inference_algo == "MMP":

future_qs_seq = self.get_future_qs()
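
For context, here is a minimal usage sketch of the constructor options added in this commit (illustrative, not part of the diff; the toy shapes and the random-model helpers from pymdp.utils are assumptions):

from pymdp import utils
from pymdp.agent import Agent

# toy generative model: one observation modality, one hidden state factor, 3 actions
num_obs, num_states, num_controls = [3], [3], [3]
A = utils.random_A_matrix(num_obs, num_states)        # observation model p(o | s)
B = utils.random_B_matrix(num_states, num_controls)   # transition model p(s' | s, u)

agent = Agent(
    A=A,
    B=B,
    policy_len=1,                    # asserted when sophisticated=True
    sophisticated=True,
    si_horizon=3,                    # depth of the recursive search
    si_policy_prune_threshold=1/16,  # drop improbable policy branches
    si_state_prune_threshold=1/16,   # drop improbable predicted outcomes
    si_prune_penalty=512,            # EFE penalty applied to pruned branches
)

agent.infer_states([0])              # condition on an initial observation
q_pi, G = agent.infer_policies()     # dispatches to sophisticated_inference_search
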
166 changes: 159 additions & 7 deletions pymdp/control.py
@@ -5,7 +5,8 @@

import itertools
import numpy as np
from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, spm_log_obj_array
from pymdp.maths import softmax, softmax_obj_arr, spm_dot, spm_wnorm, spm_MDP_G, spm_log_single, kl_div, entropy
from pymdp.inference import update_posterior_states_factorized, average_states_over_policies
from pymdp import utils
import copy

@@ -106,7 +107,7 @@ def update_posterior_policies_full(

if I is not None:
init_qs_all_pi = [qs_seq_pi[p][0] for p in range(num_policies)]
qs_bma = inference.average_states_over_policies(init_qs_all_pi, softmax(E))
qs_bma = average_states_over_policies(init_qs_all_pi, softmax(E))

for p_idx, policy in enumerate(policies):

@@ -236,7 +237,7 @@ def update_posterior_policies_full_factorized(

if I is not None:
init_qs_all_pi = [qs_seq_pi[p][0] for p in range(num_policies)]
qs_bma = inference.average_states_over_policies(init_qs_all_pi, softmax(E))
qs_bma = average_states_over_policies(init_qs_all_pi, softmax(E))

for p_idx, policy in enumerate(policies):

@@ -250,9 +251,9 @@

if use_param_info_gain:
if pA is not None:
G[idx] += calc_pA_info_gain_factorized(pA, qo_seq_pi[p_idx], qs_seq_pi[p_idx], A_factor_list)
G[p_idx] += calc_pA_info_gain_factorized(pA, qo_seq_pi[p_idx], qs_seq_pi[p_idx], A_factor_list)
if pB is not None:
G[idx] += calc_pB_info_gain_interactions(pB, qs_seq_pi[p_idx], qs, B_factor_list, policy)
G[p_idx] += calc_pB_info_gain_interactions(pB, qs_seq_pi[p_idx], qs_seq_pi[p_idx], B_factor_list, policy)

if I is not None:
G[p_idx] += calc_inductive_cost(qs_bma, qs_seq_pi[p_idx], I)
@@ -943,7 +944,7 @@ def calc_inductive_cost(qs, qs_pi, I, epsilon=1e-3):
m = np.where(I[factor][:, idx] == 1)[0]
# we might find no path to goal (i.e. when no goal specified)
if len(m) > 0:
m = np.max(m[0]-1, 0)
m = max(m[0]-1, 0)
I_m = (1-I[factor][m, :]) * np.log(epsilon)
inductive_cost += I_m.dot(qs_pi[t][factor])
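
A note on the np.max -> max fix above (editorial, not part of the diff): np.max is an array reduction whose second positional argument is axis, so np.max(m[0]-1, 0) never compares the value against 0. The builtin max (or np.maximum) computes the intended floor at zero. A minimal sketch:

import numpy as np

m0 = 0                          # e.g. the first goal-reaching row index
print(max(m0 - 1, 0))           # 0: builtin max of two values (the fix)
print(np.maximum(m0 - 1, 0))    # 0: elementwise maximum, equivalent here
# np.max(m0 - 1, 0) is not equivalent: its second argument is `axis`,
# so no comparison against 0 takes place.
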

@@ -1308,7 +1309,158 @@ def backwards_induction(H, B, B_factor_list, threshold, depth):
for i in range(1, depth):
I[factor][i, :] = np.dot(b, I[factor][i-1, :])
I[factor][i, :] = np.where(I[factor][i, :] > 0.1, 1.0, 0.0)
# TODO stop when all 1s?

return I
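
To make the induction concrete, here is a small illustrative sketch (not from this commit) for a deterministic three-state chain with the goal on the last state; the expected output is inferred from the docstring semantics of I and may differ in detail:

import numpy as np
from pymdp import utils
from pymdp.control import backwards_induction

# deterministic chain 0 -> 1 -> 2 under a single action; state 2 is absorbing
B = utils.obj_array(1)
B[0] = np.zeros((3, 3, 1))
B[0][1, 0, 0] = 1.0
B[0][2, 1, 0] = 1.0
B[0][2, 2, 0] = 1.0

H = utils.obj_array(1)
H[0] = np.array([0.0, 0.0, 1.0])   # goal: occupy state 2

I = backwards_induction(H, B, B_factor_list=[[0]], threshold=1/16, depth=3)
# I[0] has shape (depth, num_states); each successive row should extend
# goal-reachability one step further back along the chain, e.g.
# [[0. 0. 1.]
#  [0. 1. 1.]
#  [1. 1. 1.]]
print(I[0])
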


def calc_ambiguity_factorized(qs_pi, A, A_factor_list):
"""
Computes the Ambiguity term.
Parameters
----------
qs_pi: ``list`` of ``numpy.ndarray`` of dtype object
Predictive posterior beliefs over hidden states expected under the policy, where ``qs_pi[t]`` stores the beliefs about
hidden states expected under the policy at time ``t``
A: ``numpy.ndarray`` of dtype object
Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of
this object array stores a multidimensional ``numpy.ndarray`` for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store
the probability of observation level ``i`` given hidden state levels ``j, k, ...``
A_factor_list: ``list`` of ``list`` of ``int``
List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on
Returns
-------
ambiguity: float
"""

n_steps = len(qs_pi)

ambiguity = 0
# TODO check if we do this correctly!
H = entropy(A)
for t in range(n_steps):
for m, H_m in enumerate(H):
factor_idx = A_factor_list[m]
# TODO why does spm_dot return an array here?
# joint_x = maths.spm_cross(qs_pi[t][factor_idx])
# ambiguity += (H_m * joint_x).sum()
ambiguity += np.sum(spm_dot(H_m, qs_pi[t][factor_idx]))

return ambiguity
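
In standard notation (an editorial gloss of the code above), the quantity accumulated here is the expected entropy of the likelihood under the policy-conditioned state posterior, where the expectation for modality m runs only over the factors in A_factor_list[m] (the spm_dot call):

    \mathrm{ambiguity} = \sum_{t=1}^{T} \sum_{m} \mathbb{E}_{Q(s_t \mid \pi)}\big[ \mathrm{H}[P(o_t^m \mid s_t)] \big],
    \qquad \mathrm{H}[P(o^m \mid s)] = -\sum_{i} A^{(m)}[i \mid s] \, \ln A^{(m)}[i \mid s]
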


def sophisticated_inference_search(qs, policies, A, B, C, A_factor_list, B_factor_list, I=None, horizon=1,
policy_prune_threshold=1/16, state_prune_threshold=1/16, prune_penalty=512, gamma=16,
inference_params={"num_iter": 10, "dF": 1.0, "dF_tol": 0.001, "compute_vfe": False}, n=0):
"""
Performs sophisticated inference to find the optimal policy for a given generative model and prior preferences.
Parameters
----------
qs: ``numpy.ndarray`` of dtype object
Marginal posterior beliefs over hidden states at a given timepoint.
policies: ``list`` of 2D ``numpy.ndarray``
``list`` that stores each policy as a 2D array in ``policies[p_idx]``. Shape of ``policies[p_idx]``
is ``(1, num_factors)`` where ``num_factors`` is the number of control factors, since sophisticated inference assumes ``policy_len == 1``.
A: ``numpy.ndarray`` of dtype object
Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of
this object array stores a multidimensional ``numpy.ndarray`` for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store
the probability of observation level ``i`` given hidden state levels ``j, k, ...``
B: ``numpy.ndarray`` of dtype object
Dynamics likelihood mapping or 'transition model', mapping from hidden states at ``t`` to hidden states at ``t+1``, given some control state ``u``.
Each element ``B[f]`` of this object array stores a 3-D tensor for hidden state factor ``f``, whose entries ``B[f][s, v, u]`` store the probability
of hidden state level ``s`` at the current time, given hidden state level ``v`` and action ``u`` at the previous time.
C: ``numpy.ndarray`` of dtype object
Prior over observations or 'prior preferences', storing the "value" of each outcome in terms of relative log probabilities.
This is softmaxed to form a proper probability distribution before being used to compute the expected utility term of the expected free energy.
A_factor_list: ``list`` of ``list`` of ``int``
List of lists, where ``A_factor_list[m]`` is a list of the hidden state factor indices that observation modality with the index ``m`` depends on
B_factor_list: ``list`` of ``list`` of ``int``
List of lists of hidden state factors each hidden state factor depends on. Each element ``B_factor_list[i]`` is a list of the factor indices that factor i's dynamics depend on.
I: ``numpy.ndarray`` of dtype object
For each state factor, contains a 2D ``numpy.ndarray`` whose element ``i, j`` yields the probability
of reaching the goal state from state ``j`` after ``i`` steps, as computed by backwards induction.
horizon: ``int``
The temporal depth of the policy
policy_prune_threshold: ``float``
Policies whose posterior probability falls below this threshold are pruned from the search tree
state_prune_threshold: ``float``
Predicted outcomes whose probability falls below this threshold are not expanded further in the search tree
prune_penalty: ``float``
Penalty to add to the EFE when a policy is pruned
gamma: ``float``, default 16.0
Prior precision over policies, scales the contribution of the expected free energy to the posterior over policies
n: ``int``
Current depth of the recursive search, i.e. the timestep in the future currently being evaluated
Returns
----------
q_pi: 1D ``numpy.ndarray``
Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.
G: 1D ``numpy.ndarray``
Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
"""

n_policies = len(policies)
G = np.zeros(n_policies)
q_pi = np.zeros((n_policies, 1))
qs_pi = utils.obj_array(n_policies)
qo_pi = utils.obj_array(n_policies)

for idx, policy in enumerate(policies):
qs_pi[idx] = get_expected_states_interactions(qs, B, B_factor_list, policy)
qo_pi[idx] = get_expected_obs_factorized(qs_pi[idx], A, A_factor_list)

G[idx] += calc_expected_utility(qo_pi[idx], C)
G[idx] += calc_states_info_gain_factorized(A, qs_pi[idx], A_factor_list)

if I is not None:
G[idx] += calc_inductive_cost(qs, qs_pi[idx], I)

q_pi = softmax(G * gamma)

if n < horizon - 1:
# ignore low probability actions in the search tree
# TODO shouldn't we add an extra penalty for branches that are no longer considered?
# or assume these are already low EFE (high NEFE) anyway?
policies_to_consider = list(np.where(q_pi >= policy_prune_threshold)[0])
for idx in range(n_policies):
if idx not in policies_to_consider:
G[idx] -= prune_penalty
else:
# average over outcomes
qo_next = qo_pi[idx][0]
for k in itertools.product(*[range(s.shape[0]) for s in qo_next]):
prob = 1.0
for i in range(len(k)):
prob *= qo_pi[idx][0][i][k[i]]

# ignore low probability states in the search tree
if prob < state_prune_threshold:
continue

qo_one_hot = utils.obj_array(len(qo_next))
for i in range(len(qo_one_hot)):
qo_one_hot[i] = utils.onehot(k[i], qo_next[i].shape[0])

num_obs = [A[m].shape[0] for m in range(len(A))]
num_states = [B[f].shape[0] for f in range(len(B))]
A_modality_list = []
for f in range(len(B)):
A_modality_list.append( [m for m in range(len(A)) if f in A_factor_list[m]] )
mb_dict = {
'A_factor_list': A_factor_list,
'A_modality_list': A_modality_list
}
qs_next = update_posterior_states_factorized(A, qo_one_hot, num_obs, num_states, mb_dict, qs_pi[idx][0], **inference_params)
q_pi_next, G_next = sophisticated_inference_search(qs_next, policies, A, B, C, A_factor_list, B_factor_list, I,
horizon, policy_prune_threshold, state_prune_threshold,
prune_penalty, gamma, inference_params, n+1)
G_weighted = np.dot(q_pi_next, G_next) * prob
G[idx] += G_weighted

q_pi = softmax(G * gamma)
return q_pi, G
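
Finally, a direct-call sketch of the search itself (again illustrative, not from this commit; in normal use Agent.infer_policies() makes this call when sophisticated=True, and the pymdp.utils helpers are assumptions):

import numpy as np
from pymdp import utils
from pymdp.control import sophisticated_inference_search

num_obs, num_states, num_controls = [3], [3], [2]
A = utils.random_A_matrix(num_obs, num_states)
B = utils.random_B_matrix(num_states, num_controls)
C = utils.obj_array_zeros(num_obs)
C[0][2] = 2.0                                    # prefer observation level 2

qs = utils.obj_array_uniform(num_states)         # uniform beliefs over states
policies = [np.array([[u]]) for u in range(num_controls[0])]   # one-step policies

q_pi, G = sophisticated_inference_search(
    qs, policies, A, B, C,
    A_factor_list=[[0]], B_factor_list=[[0]],
    horizon=2,                                   # recurse one step into the future
)
print(q_pi, G)
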
