Merge pull request #132 from infer-actively/agent_jax
JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference
conorheins authored Jun 6, 2024
2 parents 6c23ab9 + 540e855 commit bfc1346
Showing 47 changed files with 8,639 additions and 486 deletions.
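The two new example notebooks in this diff exercise the JAX backend end to end. For orientation, below is a minimal sketch condensed from `examples/building_up_agent_loop.ipynb` of how the new `pymdp.jax.agent.Agent` class is constructed and stepped. The dummy random observations, the all-ones Dirichlet priors, and the single perception-action cycle are illustrative simplifications of the notebook's full scan-based block/trial loop, not a definitive API reference.

```python
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random as jr
from pymdp.jax.agent import Agent as AIFAgent
from pymdp.utils import random_A_matrix, random_B_matrix

# Batched generative model: the leading dimension holds batch_size parallel agents/realizations.
batch_size, num_obs, num_states, num_controls = 10, [3, 3], [3, 3], [2, 2]
A_np = random_A_matrix(num_obs=num_obs, num_states=num_states)
B_np = random_B_matrix(num_states=num_states, num_controls=num_controls)
A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))
B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))
C = [jnp.zeros((batch_size, no)) for no in num_obs]         # flat preferences over observations
D = [jnp.ones((batch_size, ns)) / ns for ns in num_states]  # uniform initial-state priors
E = jnp.ones((batch_size, 4)) / 4                           # uniform prior over the 4 policies
pA = jtu.tree_map(jnp.ones_like, A)                         # illustrative all-ones Dirichlet priors
pB = jtu.tree_map(jnp.ones_like, B)

agent = AIFAgent(A, B, C, D, E, pA, pB,
                 inference_algo='fpi', sampling_mode='marginal',
                 action_selection='stochastic')

# One perception-action cycle (the notebook wraps this pattern in its own scan over timesteps and blocks).
key = jr.PRNGKey(0)
obs_keys = jr.split(key, len(num_obs))
obs_0 = [jnp.expand_dims(jr.randint(k, (batch_size,), 0, no), -1)  # dummy observations with a trailing time axis
         for k, no in zip(obs_keys, num_obs)]
beliefs = agent.infer_states(obs_0, None, agent.D, None)           # no past actions yet; D is the empirical prior
q_pi, _ = agent.infer_policies(beliefs)
actions = agent.sample_action(q_pi, rng_key=jr.split(key, batch_size))
```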
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
60 changes: 30 additions & 30 deletions docs/notebooks/active_inference_from_scratch.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs/notebooks/cue_chaining_demo.ipynb

Large diffs are not rendered by default.

40 changes: 20 additions & 20 deletions docs/notebooks/free_energy_calculation.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/notebooks/pymdp_fundamentals.ipynb
@@ -171,7 +171,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[0.53712305 0.46287695]\n"
"[0.13370366 0.86629634]\n"
]
}
],
@@ -533,7 +533,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 6]\n"
"[2, 2, 0]\n"
]
}
],
@@ -630,7 +630,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
"version": "3.8.8"
},
"vscode": {
"interpreter": {
6 changes: 3 additions & 3 deletions docs/notebooks/tmaze_demo.ipynb
@@ -628,17 +628,17 @@
"output_type": "stream",
"text": [
" === Starting experiment === \n",
" Reward condition: Right, Observation: [CENTER, No reward, Cue Right]\n",
" Reward condition: Right, Observation: [CENTER, No reward, Cue Left]\n",
"[Step 0] Action: [Move to CUE LOCATION]\n",
"[Step 0] Observation: [CUE LOCATION, No reward, Cue Right]\n",
"[Step 1] Action: [Move to RIGHT ARM]\n",
"[Step 1] Observation: [RIGHT ARM, Reward!, Cue Right]\n",
"[Step 2] Action: [Move to RIGHT ARM]\n",
"[Step 2] Observation: [RIGHT ARM, Reward!, Cue Left]\n",
"[Step 3] Action: [Move to RIGHT ARM]\n",
"[Step 3] Observation: [RIGHT ARM, Reward!, Cue Left]\n",
"[Step 3] Observation: [RIGHT ARM, Reward!, Cue Right]\n",
"[Step 4] Action: [Move to RIGHT ARM]\n",
"[Step 4] Observation: [RIGHT ARM, Reward!, Cue Left]\n"
"[Step 4] Observation: [RIGHT ARM, Reward!, Cue Right]\n"
]
}
],
92 changes: 46 additions & 46 deletions docs/notebooks/using_the_agent_class.ipynb

Large diffs are not rendered by default.

182 changes: 182 additions & 0 deletions examples/building_up_agent_loop.ipynb
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax.tree_util as jtu\n",
"from jax import random as jr\n",
"from pymdp.jax.agent import Agent as AIFAgent\n",
"from pymdp.utils import random_A_matrix, random_B_matrix"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 10, 5, 4)\n",
"[1 1]\n",
"(10, 3, 3, 3)\n",
"(10, 3, 3, 2)\n"
]
}
],
"source": [
"def scan(f, init, xs, length=None, axis=0):\n",
" if xs is None:\n",
" xs = [None] * length\n",
" carry = init\n",
" ys = []\n",
" for x in xs:\n",
" carry, y = f(carry, x)\n",
" if y is not None:\n",
" ys.append(y)\n",
" \n",
" ys = None if len(ys) < 1 else jtu.tree_map(lambda *x: jnp.stack(x,axis=axis), *ys)\n",
"\n",
" return carry, ys\n",
"\n",
"def evolve_trials(agent, env, block_idx, num_timesteps, prng_key=jr.PRNGKey(0)):\n",
"\n",
" batch_keys = jr.split(prng_key, batch_size)\n",
" def step_fn(carry, xs):\n",
" actions = carry['actions']\n",
" outcomes = carry['outcomes']\n",
" beliefs = agent.infer_states(outcomes, actions, *carry['args'])\n",
" q_pi, _ = agent.infer_policies(beliefs)\n",
" actions_t = agent.sample_action(q_pi, rng_key=batch_keys)\n",
"\n",
" outcome_t = env.step(actions_t)\n",
" outcomes = jtu.tree_map(\n",
" lambda prev_o, new_o: jnp.concatenate([prev_o, jnp.expand_dims(new_o, -1)], -1), outcomes, outcome_t\n",
" )\n",
"\n",
" if actions is not None:\n",
" actions = jnp.concatenate([actions, jnp.expand_dims(actions_t, -2)], -2)\n",
" else:\n",
" actions = jnp.expand_dims(actions_t, -2)\n",
"\n",
" args = agent.update_empirical_prior(actions_t, beliefs)\n",
"\n",
" ### @ NOTE !!!!: Shape of policy_probs = (num_blocks, num_trials, batch_size, num_policies) if scan axis = 0, but size of `actions` will \n",
" ### be (num_blocks, batch_size, num_trials, num_controls) -- so we need to 1) swap axes to both to have the same first three dimensiosn aligned,\n",
" # 2) use the action indices (the integers stored in the last dimension of `actions`) to index into the policy_probs array\n",
" \n",
" # args = (pred_{t+1}, [post_1, post_{2}, ..., post_{t}])\n",
" # beliefs = [post_1, post_{2}, ..., post_{t}]\n",
" return {'args': args, 'outcomes': outcomes, 'beliefs': beliefs, 'actions': actions}, {'policy_probs': q_pi}\n",
"\n",
" \n",
" outcome_0 = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), env.step())\n",
" # qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, -2), agent.D) # add a time dimension to the initial state prior\n",
" init = {\n",
" 'args': (agent.D, None,),\n",
" 'outcomes': outcome_0, \n",
" 'beliefs': [],\n",
" 'actions': None\n",
" }\n",
" last, q_pis_ = scan(step_fn, init, range(num_timesteps), axis=1)\n",
"\n",
" return last, q_pis_, env\n",
"\n",
"def step_fn(carry, block_idx):\n",
" agent, env = carry\n",
" output, q_pis_, env = evolve_trials(agent, env, block_idx, num_timesteps)\n",
" args = output.pop('args')\n",
" output['beliefs'] = agent.infer_states(output['outcomes'], output['actions'], *args)\n",
" output.update(q_pis_)\n",
"\n",
" # How to deal with contiguous blocks of trials? Two options we can imagine: \n",
" # A) you use final posterior (over current and past timesteps) to compute the smoothing distribution over qs_{t=0} and update pD, and then pass pD as the initial state prior ($D = \\mathbb{E}_{pD}[qs_{t=0}]$);\n",
" # B) we don't assume that blocks 'reset time', and are really just adjacent chunks of one long sequence, so you set the initial state prior to be the final output (`output['beliefs']`) passed through\n",
" # the transition model entailed by the action taken at the last timestep of the previous block.\n",
" # print(output['beliefs'].shape)\n",
" agent = agent.learning(**output)\n",
" \n",
" return (agent, env), output\n",
"\n",
"# define an agent and environment here\n",
"batch_size = 10\n",
"num_obs = [3, 3]\n",
"num_states = [3, 3]\n",
"num_controls = [2, 2]\n",
"num_blocks = 2\n",
"num_timesteps = 5\n",
"\n",
"A_np = random_A_matrix(num_obs=num_obs, num_states=num_states)\n",
"B_np = random_B_matrix(num_states=num_states, num_controls=num_controls)\n",
"A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))\n",
"B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))\n",
"C = [jnp.zeros((batch_size, no)) for no in num_obs]\n",
"D = [jnp.ones((batch_size, ns)) / ns for ns in num_states]\n",
"E = jnp.ones((batch_size, 4 )) / 4 \n",
"\n",
"pA = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))\n",
"pB = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))\n",
"\n",
"class TestEnv:\n",
" def __init__(self, num_obs, prng_key=jr.PRNGKey(0)):\n",
" self.num_obs=num_obs\n",
" self.key = prng_key\n",
" def step(self, actions=None):\n",
" # return a list of random observations for each agent or parallel realization (each entry in batch_dim)\n",
" obs = [jr.randint(self.key, (batch_size,), 0, no) for no in self.num_obs]\n",
" self.key, _ = jr.split(self.key)\n",
" return obs\n",
"\n",
"agents = AIFAgent(A, B, C, D, E, pA, pB, use_param_info_gain=True, use_inductive=False, inference_algo='fpi', sampling_mode='marginal', action_selection='stochastic')\n",
"env = TestEnv(num_obs)\n",
"init = (agents, env)\n",
"(agents, env), sequences = scan(step_fn, init, range(num_blocks) )\n",
"print(sequences['policy_probs'].shape)\n",
"print(sequences['actions'][0][0][0])\n",
"print(agents.A[0].shape)\n",
"print(agents.B[0].shape)\n",
"# def loss_fn(agents):\n",
"# env = TestEnv(num_obs)\n",
"# init = (agents, env)\n",
"# (agents, env), sequences = scan(step_fn, init, range(num_blocks)) \n",
"\n",
"# return jnp.sum(jnp.log(sequences['policy_probs']))\n",
"\n",
"# dLoss_dAgents = jax.grad(loss_fn)(agents)\n",
"# print(dLoss_dAgents.A[0].shape)\n",
"\n",
"\n",
"# sequences = jtu.tree_map(lambda x: x.swapaxes(1, 2), sequences)\n",
"\n",
"# NOTE: all elements of sequences will have dimensionality blocks, trials, batch_size, ...\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_pymdp_test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
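A small standalone sketch of the shape bookkeeping flagged in the NOTE inside the cell above (and in its commented-out `swapaxes` line): quantities stacked by the outer scan come out as (num_blocks, batch_size, num_trials, ...), matching the printed `(2, 10, 5, 4)`, and swapping axes 1 and 2 puts every entry in a common (num_blocks, num_trials, batch_size, ...) order. The arrays below are placeholders with those layouts, not outputs from the run.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

# Placeholder arrays with the layouts reported by the notebook's print statements.
num_blocks, batch_size, num_trials, num_policies, num_factors = 2, 10, 5, 4, 2
sequences = {
    'policy_probs': jnp.zeros((num_blocks, batch_size, num_trials, num_policies)),           # printed: (2, 10, 5, 4)
    'actions': jnp.zeros((num_blocks, batch_size, num_trials, num_factors), dtype=jnp.int32),  # per-step action indices
}

# Align every entry to (num_blocks, num_trials, batch_size, ...), as the commented-out line suggests.
aligned = jtu.tree_map(lambda x: x.swapaxes(1, 2), sequences)
assert aligned['policy_probs'].shape == (num_blocks, num_trials, batch_size, num_policies)
assert aligned['actions'].shape == (num_blocks, num_trials, batch_size, num_factors)
```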
146 changes: 146 additions & 0 deletions examples/inductive_inference_example.ipynb
@@ -0,0 +1,146 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pymdp.jax import control\n",
"import jax.numpy as jnp\n",
"import jax.tree_util as jtu\n",
"from jax import nn, vmap, random, lax\n",
"\n",
"from typing import List, Optional\n",
"from jaxtyping import Array\n",
"from jax import random as jr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set up generative model (random one with trivial observation model)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Set up a generative model\n",
"num_states = [5, 3]\n",
"num_controls = [2, 2]\n",
"\n",
"# make some arbitrary policies (policy depth 3, 2 control factors)\n",
"policy_1 = jnp.array([[0, 1],\n",
" [1, 1],\n",
" [0, 0]])\n",
"policy_2 = jnp.array([[1, 0],\n",
" [0, 0],\n",
" [1, 1]])\n",
"policy_matrix = jnp.stack([policy_1, policy_2]) \n",
"\n",
"# observation modalities (isomorphic/identical to hidden states, just need to include for the need to include likleihood model)\n",
"num_obs = [5, 3]\n",
"num_factors = len(num_states)\n",
"num_modalities = len(num_obs)\n",
"\n",
"# sample parameters of the model (A, B, C)\n",
"key = jr.PRNGKey(1)\n",
"factor_keys = jr.split(key, num_factors)\n",
"\n",
"d = [0.1* jr.uniform(factor_key, (ns,)) for factor_key, ns in zip(factor_keys, num_states)]\n",
"qs_init = [jr.dirichlet(factor_key, d_f) for factor_key, d_f in zip(factor_keys, d)]\n",
"A = [jnp.eye(no) for no in num_obs]\n",
"\n",
"factor_keys = jr.split(factor_keys[-1], num_factors)\n",
"b = [jr.uniform(factor_keys[f], shape=(num_controls[f], num_states[f], num_states[f])) for f in range(num_factors)]\n",
"b_sparse = [jnp.where(b_f < 0.75, 1e-5, b_f) for b_f in b]\n",
"B = [jnp.swapaxes(jr.dirichlet(factor_keys[f], b_sparse[f]), 2, 0) for f in range(num_factors)]\n",
"\n",
"modality_keys = jr.split(factor_keys[-1], num_modalities)\n",
"C = [nn.one_hot(jr.randint(modality_keys[m], shape=(1,), minval=0, maxval=num_obs[m]), num_obs[m]) for m in range(num_modalities)]\n",
"\n",
"# trivial dependencies -- factor 1 drives modality 1, etc.\n",
"A_dependencies = [[0], [1]]\n",
"B_dependencies = [[0], [1]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate sparse constraints vectors `H` and inductive matrix `I`, using inductive parameters like depth and threshold "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# generate random constraints (H vector)\n",
"factor_keys = jr.split(key, num_factors)\n",
"H = [jr.uniform(factor_key, (ns,)) for factor_key, ns in zip(factor_keys, num_states)]\n",
"H = [jnp.where(h < 0.75, 0., 1.) for h in H]\n",
"\n",
"# depth and threshold for inductive planning algorithm. I made policy-depth equal to inductive planning depth, out of ignorance -- need to ask Tim or Tommaso about this\n",
"inductive_depth, inductive_threshold = 3, 0.5\n",
"I = control.generate_I_matrix(H, B, inductive_threshold, inductive_depth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate posterior probability of policies and negative EFE using new version of `update_posterior_policies`\n",
"#### This function no longer computes info gain (for both states and parameters) since deterministic model is assumed, and includes new inductive matrix `I` and `inductive_epsilon` parameter"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# evaluate Q(pi) and negative EFE using the inductive planning algorithm\n",
"\n",
"E = jnp.ones(policy_matrix.shape[0])\n",
"pA = jtu.tree_map(lambda a: jnp.ones_like(a), A)\n",
"pB = jtu.tree_map(lambda b: jnp.ones_like(b), B)\n",
"\n",
"q_pi, neg_efe = control.update_posterior_policies_inductive(policy_matrix, qs_init, A, B, C, E, pA, pB, A_dependencies, B_dependencies, I, gamma=16.0, use_utility=True, use_inductive=True, inductive_epsilon=1e-3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "atari_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
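As a hypothetical follow-up (not part of the notebook), the planning output can be inspected by reading off the most probable policy; this assumes the `q_pi` returned above is a flat distribution over the two policies stacked in `policy_matrix`.

```python
# Hypothetical inspection of the inductive-planning output (assumes q_pi has shape (num_policies,)).
best_policy_idx = jnp.argmax(q_pi)
print("posterior over policies:", q_pi)
print("negative expected free energy per policy:", neg_efe)
print("most probable policy (policy_len x num_control_factors):\n", policy_matrix[best_policy_idx])
```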