remove policies from static fields
dimarkov committed Jul 3, 2024
1 parent 865ce9e commit 7f2bbb4
Showing 1 changed file with 21 additions and 16 deletions: pymdp/jax/agent.py
@@ -55,10 +55,12 @@ class Agent(Module):
 
     pA: List[Array]
     pB: List[Array]
 
+    policies: Array # matrix of all possible policies (each row is a policy of shape (num_controls[0], num_controls[1], ..., num_controls[num_control_factors-1])
+
     # static parameters not leaves of the PyTree
-    A_dependencies: Optional[List] = field(static=True)
-    B_dependencies: Optional[List] = field(static=True)
+    A_dependencies: Optional[List[int]] = field(static=True)
+    B_dependencies: Optional[List[int]] = field(static=True)
     batch_size: int = field(static=True)
     num_iter: int = field(static=True)
     num_obs: List[int] = field(static=True)
@@ -69,7 +71,6 @@ class Agent(Module):
     control_fac_idx: Optional[List[int]] = field(static=True)
     policy_len: int = field(static=True) # depth of planning during roll-outs (i.e. number of timesteps to look ahead when computing expected free energy of policies)
     inductive_depth: int = field(static=True) # depth of inductive inference (i.e. number of future timesteps to use when computing inductive `I` matrix)
-    policies: Array = field(static=True) # matrix of all possible policies (each row is a policy of shape (num_controls[0], num_controls[1], ..., num_controls[num_control_factors-1])
     use_utility: bool = field(static=True) # flag for whether to use expected utility ("reward" or "preference satisfaction") when computing expected free energy
     use_states_info_gain: bool = field(static=True) # flag for whether to use state information gain ("salience") when computing expected free energy
     use_param_info_gain: bool = field(static=True) # flag for whether to use parameter information gain ("novelty") when computing expected free energy
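
Why this change matters: `policies` holds a JAX array, but `field(static=True)` (equinox-style semantics, which the `Module`/`field` usage here suggests) treats a field as compile-time metadata that must be hashable and is baked into the jit cache key, rather than as a PyTree leaf that `jit`/`vmap` can trace and batch. Moving `policies` to a regular field makes it a leaf. A minimal sketch of the distinction, using a hypothetical `Toy` module and `first_policy` function rather than pymdp's actual `Agent`:

    # Hypothetical example, assuming equinox-style Module semantics.
    import jax
    import jax.numpy as jnp
    import equinox as eqx

    class Toy(eqx.Module):
        policies: jax.Array                       # leaf: traced by jit/vmap
        policy_len: int = eqx.field(static=True)  # static: hashed into the jit cache key

    toy = Toy(policies=jnp.zeros((8, 3)), policy_len=3)

    @jax.jit
    def first_policy(m: Toy):
        return m.policies[0]

    # Works because `policies` is a leaf; had it been declared static,
    # the (unhashable) array would become part of the cache key instead
    # of a traced input.
    print(first_policy(toy))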
@@ -365,7 +366,6 @@ def update_empirical_prior(self, action, qs):
 
         return (pred, qs)
 
-    @vmap
     def infer_policies(self, qs: List):
         """
         Perform policy inference by optimizing a posterior (categorical) distribution over policies.
@@ -381,25 +381,29 @@ def infer_policies(self, qs: List):
             Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
         """
 
-        latest_belief = jtu.tree_map(lambda x: x[-1], qs) # only get the posterior belief held at the current timepoint
-        q_pi, G = control.update_posterior_policies_inductive(
+        latest_belief = jtu.tree_map(lambda x: x[:, -1], qs) # only get the posterior belief held at the current timepoint
+        infer_policies = partial(
+            control.update_posterior_policies_inductive,
             self.policies,
+            A_dependencies=self.A_dependencies,
+            B_dependencies=self.B_dependencies,
+            use_utility=self.use_utility,
+            use_states_info_gain=self.use_states_info_gain,
+            use_param_info_gain=self.use_param_info_gain,
+            use_inductive=self.use_inductive
+        )
+
+        q_pi, G = vmap(infer_policies)(
             latest_belief,
             self.A,
             self.B,
             self.C,
             self.E,
             self.pA,
             self.pB,
-            A_dependencies=self.A_dependencies,
-            B_dependencies=self.B_dependencies,
             I = self.I,
             gamma=self.gamma,
-            inductive_epsilon=self.inductive_epsilon,
-            use_utility=self.use_utility,
-            use_states_info_gain=self.use_states_info_gain,
-            use_param_info_gain=self.use_param_info_gain,
-            use_inductive=self.use_inductive
+            inductive_epsilon=self.inductive_epsilon
         )
 
         return q_pi, G
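
The hunk above replaces the method-level `@vmap` decorator with the explicit pattern of closing over per-instance configuration via `functools.partial` and vmapping only the batched arrays. A minimal sketch of the same pattern, where `score_policies` and its shapes are hypothetical stand-ins for `control.update_posterior_policies_inductive`:

    # Hypothetical example; `score_policies` stands in for the real control function.
    from functools import partial
    import jax
    import jax.numpy as jnp

    def score_policies(policies, beliefs, gamma=16.0):
        # Score each policy row against the current beliefs.
        G = policies @ beliefs
        return jax.nn.softmax(gamma * G), G

    policies = jnp.ones((10, 4))        # shared across the batch (not mapped)
    beliefs = jnp.ones((32, 4)) / 4.0   # leading axis = batch of 32 agents

    infer = partial(score_policies, policies, gamma=16.0)
    q_pi, G = jax.vmap(infer)(beliefs)  # maps over axis 0 of `beliefs` only
    print(q_pi.shape, G.shape)          # (32, 10) (32, 10)

Closing over `policies` broadcasts the same policy matrix to every batch element instead of mapping over it, which is what the diff does with `self.policies` inside `partial`.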
@@ -434,7 +438,6 @@ def multiaction_probabilities(self, q_pi: Array):
 
         return marginals
 
-    @vmap
     def sample_action(self, q_pi: Array, rng_key=None):
         """
         Sample or select a discrete action from the posterior over control states.
@@ -451,9 +454,11 @@ def sample_action(self, q_pi: Array, rng_key=None):
             raise ValueError("Please provide a random number generator key to sample actions stochastically")
 
         if self.sampling_mode == "marginal":
-            action = control.sample_action(q_pi, self.policies, self.num_controls, self.action_selection, self.alpha, rng_key=rng_key)
+            sample_action = partial(control.sample_action, self.policies, self.num_controls, action_selection=self.action_selection)
+            action = vmap(sample_action)(q_pi, alpha=self.alpha, rng_key=rng_key)
         elif self.sampling_mode == "full":
-            action = control.sample_policy(q_pi, self.policies, self.num_controls, self.action_selection, self.alpha, rng_key=rng_key)
+            sample_policy = partial(control.sample_policy, self.policies, action_selection=self.action_selection)
+            action = vmap(sample_policy)(q_pi, alpha=self.alpha, rng_key=rng_key)
 
         return action
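
For vmapped stochastic selection as above, each batch element needs its own PRNG key, e.g. obtained via `jax.random.split`. A minimal sketch of batched sampling, where `sample_one` is a hypothetical stand-in for `control.sample_action`:

    # Hypothetical example; `sample_one` stands in for the real control function.
    import jax
    import jax.numpy as jnp

    def sample_one(q_pi, key):
        # Draw one policy index from the posterior over policies.
        return jax.random.categorical(key, jnp.log(q_pi))

    q_pi = jnp.ones((32, 10)) / 10.0  # batch of 32 posteriors over 10 policies
    keys = jax.random.split(jax.random.PRNGKey(0), 32)

    actions = jax.vmap(sample_one)(q_pi, keys)
    print(actions.shape)              # (32,)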
