Merge pull request #132 from infer-actively/agent_jax
JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference
Showing 47 changed files with 8,639 additions and 486 deletions.
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax.tree_util as jtu\n",
"from jax import random as jr\n",
"from pymdp.jax.agent import Agent as AIFAgent\n",
"from pymdp.utils import random_A_matrix, random_B_matrix"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 10, 5, 4)\n",
"[1 1]\n",
"(10, 3, 3, 3)\n",
"(10, 3, 3, 2)\n"
]
}
],
"source": [
"def scan(f, init, xs, length=None, axis=0):\n",
"    if xs is None:\n",
"        xs = [None] * length\n",
"    carry = init\n",
"    ys = []\n",
"    for x in xs:\n",
"        carry, y = f(carry, x)\n",
"        if y is not None:\n",
"            ys.append(y)\n",
"\n",
"    ys = None if len(ys) < 1 else jtu.tree_map(lambda *x: jnp.stack(x, axis=axis), *ys)\n",
"\n",
"    return carry, ys\n",
"\n",
"def evolve_trials(agent, env, block_idx, num_timesteps, prng_key=jr.PRNGKey(0)):\n",
"\n",
"    batch_keys = jr.split(prng_key, batch_size)\n",
"    def step_fn(carry, xs):\n",
"        actions = carry['actions']\n",
"        outcomes = carry['outcomes']\n",
"        beliefs = agent.infer_states(outcomes, actions, *carry['args'])\n",
"        q_pi, _ = agent.infer_policies(beliefs)\n",
"        actions_t = agent.sample_action(q_pi, rng_key=batch_keys)\n",
"\n",
"        outcome_t = env.step(actions_t)\n",
"        outcomes = jtu.tree_map(\n",
"            lambda prev_o, new_o: jnp.concatenate([prev_o, jnp.expand_dims(new_o, -1)], -1), outcomes, outcome_t\n",
"        )\n",
"\n",
"        if actions is not None:\n",
"            actions = jnp.concatenate([actions, jnp.expand_dims(actions_t, -2)], -2)\n",
"        else:\n",
"            actions = jnp.expand_dims(actions_t, -2)\n",
"\n",
"        args = agent.update_empirical_prior(actions_t, beliefs)\n",
"\n",
"        ### NOTE: shape of policy_probs = (num_blocks, num_trials, batch_size, num_policies) if scan axis = 0, but the shape of `actions` will\n",
"        ### be (num_blocks, batch_size, num_trials, num_controls) -- so we need to 1) swap axes so that both have their first three dimensions aligned,\n",
"        # 2) use the action indices (the integers stored in the last dimension of `actions`) to index into the policy_probs array\n",
"\n",
"        # args = (pred_{t+1}, [post_1, post_{2}, ..., post_{t}])\n",
"        # beliefs = [post_1, post_{2}, ..., post_{t}]\n",
"        return {'args': args, 'outcomes': outcomes, 'beliefs': beliefs, 'actions': actions}, {'policy_probs': q_pi}\n",
"\n",
"    outcome_0 = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), env.step())\n",
"    # qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, -2), agent.D) # add a time dimension to the initial state prior\n",
"    init = {\n",
"        'args': (agent.D, None,),\n",
"        'outcomes': outcome_0,\n",
"        'beliefs': [],\n",
"        'actions': None\n",
"    }\n",
"    last, q_pis_ = scan(step_fn, init, range(num_timesteps), axis=1)\n",
"\n",
"    return last, q_pis_, env\n",
"\n",
"def step_fn(carry, block_idx):\n",
"    agent, env = carry\n",
"    output, q_pis_, env = evolve_trials(agent, env, block_idx, num_timesteps)\n",
"    args = output.pop('args')\n",
"    output['beliefs'] = agent.infer_states(output['outcomes'], output['actions'], *args)\n",
"    output.update(q_pis_)\n",
"\n",
"    # How to deal with contiguous blocks of trials? Two options we can imagine:\n",
"    # A) use the final posterior (over current and past timesteps) to compute the smoothing distribution over qs_{t=0} and update pD, then pass pD as the initial state prior ($D = \\mathbb{E}_{pD}[qs_{t=0}]$);\n",
"    # B) don't assume that blocks 'reset time' -- they are really just adjacent chunks of one long sequence, so set the initial state prior to be the final output (`output['beliefs']`) passed through\n",
"    # the transition model entailed by the action taken at the last timestep of the previous block.\n",
"    # print(output['beliefs'].shape)\n",
"    agent = agent.learning(**output)\n",
"\n",
"    return (agent, env), output\n",
"\n",
"# define an agent and environment here\n",
"batch_size = 10\n",
"num_obs = [3, 3]\n",
"num_states = [3, 3]\n",
"num_controls = [2, 2]\n",
"num_blocks = 2\n",
"num_timesteps = 5\n",
"\n",
"A_np = random_A_matrix(num_obs=num_obs, num_states=num_states)\n",
"B_np = random_B_matrix(num_states=num_states, num_controls=num_controls)\n",
"A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))\n",
"B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))\n",
"C = [jnp.zeros((batch_size, no)) for no in num_obs]\n",
"D = [jnp.ones((batch_size, ns)) / ns for ns in num_states]\n",
"E = jnp.ones((batch_size, 4)) / 4\n",
"\n",
"pA = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))\n",
"pB = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))\n",
"\n",
"class TestEnv:\n",
"    def __init__(self, num_obs, prng_key=jr.PRNGKey(0)):\n",
"        self.num_obs = num_obs\n",
"        self.key = prng_key\n",
"    def step(self, actions=None):\n",
"        # return a list of random observations for each agent or parallel realization (each entry in batch_dim)\n",
"        obs = [jr.randint(self.key, (batch_size,), 0, no) for no in self.num_obs]\n",
"        self.key, _ = jr.split(self.key)\n",
"        return obs\n",
"\n",
"agents = AIFAgent(A, B, C, D, E, pA, pB, use_param_info_gain=True, use_inductive=False, inference_algo='fpi', sampling_mode='marginal', action_selection='stochastic')\n",
"env = TestEnv(num_obs)\n",
"init = (agents, env)\n",
"(agents, env), sequences = scan(step_fn, init, range(num_blocks))\n",
"print(sequences['policy_probs'].shape)\n",
"print(sequences['actions'][0][0][0])\n",
"print(agents.A[0].shape)\n",
"print(agents.B[0].shape)\n",
"# def loss_fn(agents):\n",
"#     env = TestEnv(num_obs)\n",
"#     init = (agents, env)\n",
"#     (agents, env), sequences = scan(step_fn, init, range(num_blocks))\n",
"\n",
"#     return jnp.sum(jnp.log(sequences['policy_probs']))\n",
"\n",
"# dLoss_dAgents = jax.grad(loss_fn)(agents)\n",
"# print(dLoss_dAgents.A[0].shape)\n",
"\n",
"\n",
"# sequences = jtu.tree_map(lambda x: x.swapaxes(1, 2), sequences)\n",
"\n",
"# NOTE: all elements of sequences will have dimensionality blocks, trials, batch_size, ...\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_pymdp_test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
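The `scan` helper in the notebook above is a plain-Python stand-in for `jax.lax.scan`; it is used here because the carry changes shape across steps (the `outcomes` and `actions` arrays grow along the time axis), which `lax.scan` does not allow. Below is a minimal, self-contained sketch of the same output-stacking convention using only `jax`; the toy `step` function is illustrative and not part of this PR.

```python
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax

def step(carry, x):
    # emit a small pytree at every step, analogous to step_fn returning {'policy_probs': q_pi}
    new_carry = carry + x
    return new_carry, {'running_sum': new_carry, 'squared': x ** 2}

xs = jnp.arange(5.0)

# reference: jax.lax.scan stacks the per-step outputs along axis 0
carry_ref, ys_ref = lax.scan(step, 0.0, xs)

# Python-level equivalent of the notebook's scan(..., axis=0)
carry, ys = 0.0, []
for x in xs:
    carry, y = step(carry, x)
    ys.append(y)
ys = jtu.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *ys)

assert jnp.allclose(ys['running_sum'], ys_ref['running_sum'])
print(ys['squared'].shape)  # (5,); in the notebook, axis=1 is used so the time dimension lands after the batch dimension
```

Because the carries in `evolve_trials` grow at every step, swapping this helper for `lax.scan` would require pre-allocating the `outcomes`/`actions` buffers.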
@@ -0,0 +1,146 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pymdp.jax import control\n",
"import jax.numpy as jnp\n",
"import jax.tree_util as jtu\n",
"from jax import nn, vmap, random, lax\n",
"\n",
"from typing import List, Optional\n",
"from jaxtyping import Array\n",
"from jax import random as jr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set up generative model (a random one with a trivial observation model)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Set up a generative model\n",
"num_states = [5, 3]\n",
"num_controls = [2, 2]\n",
"\n",
"# make some arbitrary policies (policy depth 3, 2 control factors)\n",
"policy_1 = jnp.array([[0, 1],\n",
"                      [1, 1],\n",
"                      [0, 0]])\n",
"policy_2 = jnp.array([[1, 0],\n",
"                      [0, 0],\n",
"                      [1, 1]])\n",
"policy_matrix = jnp.stack([policy_1, policy_2])\n",
"\n",
"# observation modalities (isomorphic/identical to hidden states; only included because a likelihood model is needed)\n",
"num_obs = [5, 3]\n",
"num_factors = len(num_states)\n",
"num_modalities = len(num_obs)\n",
"\n",
"# sample parameters of the model (A, B, C)\n",
"key = jr.PRNGKey(1)\n",
"factor_keys = jr.split(key, num_factors)\n",
"\n",
"d = [0.1 * jr.uniform(factor_key, (ns,)) for factor_key, ns in zip(factor_keys, num_states)]\n",
"qs_init = [jr.dirichlet(factor_key, d_f) for factor_key, d_f in zip(factor_keys, d)]\n",
"A = [jnp.eye(no) for no in num_obs]\n",
"\n",
"factor_keys = jr.split(factor_keys[-1], num_factors)\n",
"b = [jr.uniform(factor_keys[f], shape=(num_controls[f], num_states[f], num_states[f])) for f in range(num_factors)]\n",
"b_sparse = [jnp.where(b_f < 0.75, 1e-5, b_f) for b_f in b]\n",
"B = [jnp.swapaxes(jr.dirichlet(factor_keys[f], b_sparse[f]), 2, 0) for f in range(num_factors)]\n",
"\n",
"modality_keys = jr.split(factor_keys[-1], num_modalities)\n",
"C = [nn.one_hot(jr.randint(modality_keys[m], shape=(1,), minval=0, maxval=num_obs[m]), num_obs[m]) for m in range(num_modalities)]\n",
"\n",
"# trivial dependencies -- factor 1 drives modality 1, etc.\n",
"A_dependencies = [[0], [1]]\n",
"B_dependencies = [[0], [1]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate sparse constraint vectors `H` and inductive matrix `I`, using inductive parameters like depth and threshold"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# generate random constraints (H vector)\n",
"factor_keys = jr.split(key, num_factors)\n",
"H = [jr.uniform(factor_key, (ns,)) for factor_key, ns in zip(factor_keys, num_states)]\n",
"H = [jnp.where(h < 0.75, 0., 1.) for h in H]\n",
"\n",
"# depth and threshold for the inductive planning algorithm. I made policy depth equal to inductive planning depth, out of ignorance -- need to ask Tim or Tommaso about this\n",
"inductive_depth, inductive_threshold = 3, 0.5\n",
"I = control.generate_I_matrix(H, B, inductive_threshold, inductive_depth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate the posterior probability of policies and the negative EFE using the new version of `update_posterior_policies`\n",
"#### This function no longer computes info gain (for either states or parameters), since a deterministic model is assumed, and it takes the new inductive matrix `I` and an `inductive_epsilon` parameter"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# evaluate Q(pi) and negative EFE using the inductive planning algorithm\n",
"\n",
"E = jnp.ones(policy_matrix.shape[0])\n",
"pA = jtu.tree_map(lambda a: jnp.ones_like(a), A)\n",
"pB = jtu.tree_map(lambda b: jnp.ones_like(b), B)\n",
"\n",
"q_pi, neg_efe = control.update_posterior_policies_inductive(policy_matrix, qs_init, A, B, C, E, pA, pB, A_dependencies, B_dependencies, I, gamma=16.0, use_utility=True, use_inductive=True, inductive_epsilon=1e-3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "atari_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
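For orientation: in standard active-inference formulations, the policy posterior returned above is a softmax of the negative expected free energy, scaled by the precision `gamma` and combined with the log of the policy prior `E`. The exact internals of `update_posterior_policies_inductive` (including how the inductive term weighted by `inductive_epsilon` enters) may differ, so the sketch below is illustrative only, with made-up numbers.

```python
import jax.numpy as jnp
from jax import nn

# made-up negative EFE values for two policies, and a flat policy prior E (as in the cell above)
neg_efe = jnp.array([-3.2, -4.1])
E = jnp.ones(2)
gamma = 16.0

# q(pi) proportional to exp(log E + gamma * (-G)): a higher (less negative) -G gets more probability mass
q_pi = nn.softmax(jnp.log(E) + gamma * neg_efe)
print(q_pi)  # sums to 1; concentrated on the first policy because its -G is higher
```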