Skip to content

Commit

Permalink
changed backward smoothing function to accomodate batching over spars…
Browse files Browse the repository at this point in the history
…e arrays
  • Loading branch information
dimarkov committed Jun 17, 2024
1 parent 8ee1426 commit 397bca4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 86 deletions.
67 changes: 35 additions & 32 deletions examples/inference_and_learning/inference_methods_comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-17 15:10:30.638413: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
"2024-06-17 17:54:09.793970: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
}
],
Expand Down Expand Up @@ -122,15 +122,7 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"347 µs ± 9.25 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"outputs": [],
"source": [
"prior = agents.D\n",
"action_hist = []\n",
Expand All @@ -144,9 +136,25 @@
" action_hist.append(actions)\n",
"\n",
"v_jso = jit(vmap(smoothing_ovf))\n",
"smoothed_beliefs = v_jso(beliefs, agents.B, jnp.stack(action_hist, 1))\n",
"\n",
"%timeit v_jso(beliefs, agents.B, jnp.stack(action_hist, 1))[0][0].block_until_ready()"
"actions_seq = jnp.stack(action_hist, 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63.1 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"smoothed_beliefs = v_jso(beliefs, agents.B, actions_seq)\n",
"%timeit v_jso(beliefs, agents.B, actions_seq)[0][0].block_until_ready()"
]
},
{
Expand All @@ -158,27 +166,22 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"166 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"87.7 µs ± 8.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"take_first = lambda pytree: jtu.tree_map(lambda leaf: leaf[0], pytree)\n",
"\n",
"beliefs_single = take_first(beliefs)\n",
"sparse_B_single = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b[0]), agents.B)\n",
"actions_single = jnp.stack(action_hist, 1)[0]\n",
"sparse_B = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b, n_batch=1), agents.B)\n",
"\n",
"jso = jit(smoothing_ovf)\n",
"smoothed_beliefs_sparse = jso(beliefs_single, sparse_B_single, actions_single)\n",
"%timeit jso(beliefs_single, sparse_B_single, actions_single)[0][0].block_until_ready()"
"smoothed_beliefs_sparse = v_jso(beliefs, sparse_B, actions_seq)\n",
"%timeit v_jso(beliefs, sparse_B, actions_seq)[0][0].block_until_ready()"
]
},
{
Expand All @@ -190,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -199,7 +202,7 @@
"Text(0.5, 1.0, 'Filtered beliefs')"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -229,7 +232,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -238,7 +241,7 @@
"Text(0.5, 1.0, 'Filtered beliefs')"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -257,11 +260,11 @@
"#with sparse matrices\n",
"fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)\n",
"\n",
"sns.heatmap(beliefs_single[0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"sns.heatmap(beliefs_single[1].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"sns.heatmap(beliefs[0][0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"sns.heatmap(beliefs[1][0].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"\n",
"sns.heatmap(smoothed_beliefs_sparse[0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"sns.heatmap(smoothed_beliefs_sparse[1][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"sns.heatmap(smoothed_beliefs_sparse[0][0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"sns.heatmap(smoothed_beliefs_sparse[1][0][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n",
"\n",
"axes[0, 0].set_title('Filtered beliefs')"
]
Expand All @@ -275,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down
78 changes: 24 additions & 54 deletions pymdp/jax/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from .algos import run_factorized_fpi, run_mmp, run_vmp
from jax import tree_util as jtu, lax
from jax.experimental.sparse._base import JAXSparse
from jaxtyping import Array
from jax.experimental import sparse
from jaxtyping import Array, ArrayLike

eps = jnp.finfo('float').eps

def update_posterior_states(
A,
Expand Down Expand Up @@ -58,83 +61,50 @@ def update_posterior_states(

return qs_hist

def joint_dist_factor_dense(b: Array, filtered_qs: list[Array], actions: Array):
def joint_dist_factor(b: ArrayLike, filtered_qs: list[Array], actions: Array):
qs_last = filtered_qs[-1]
qs_filter = filtered_qs[:-1]

# conditional dist - timestep x s_{t+1} | s_{t}
time_b = jnp.moveaxis(b[..., actions], -1, 0)
# time_b = b[...,actions].transpose([b.ndim-1] + list(range(b.ndim-1)))

# joint dist - timestep x s_{t+1} x s_{t}
qs_joint = time_b * jnp.expand_dims(qs_filter, -1)

# cond dist - timestep x s_{t} | s_{t+1}
qs_backward_cond = jnp.moveaxis(
qs_joint / qs_joint.sum(-2, keepdims=True), -2, -1
)
# tranpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2]
# qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(tranpose_idx)

def step_fn(qs_smooth_past, backward_b):
qs_joint = backward_b * qs_smooth_past
def step_fn(qs_smooth, xs):
qs_f, action = xs
time_b = b[..., action]
qs_j = time_b * qs_f
norm = qs_j.sum(-1, keepdims=True)
if isinstance(norm, JAXSparse):
norm = sparse.todense(norm)
norm = jnp.where(norm == 0, eps, norm)
qs_backward_cond = (qs_j / norm).T
qs_joint = qs_backward_cond * qs_smooth
qs_smooth = qs_joint.sum(-1)
if isinstance(qs_smooth, JAXSparse):
qs_smooth = sparse.todense(qs_smooth)

# returns q(s_t), (q(s_t), q(s_t, s_t+1))
return qs_smooth, (qs_smooth, qs_joint)

# seq_qs will contain a sequence of smoothed marginals and joints
_, seq_qs = lax.scan(
step_fn,
qs_last,
qs_backward_cond,
reverse=True,
unroll=2
)

# we add the last filtered belief to smoothed beliefs
qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
return qs_smooth_all, seq_qs[1]

def joint_dist_factor_sparse(b: JAXSparse, filtered_qs: list[Array], actions: Array):
qs_last = filtered_qs[-1]
qs_filter = filtered_qs[:-1]

# conditional dist - timestep x s_{t+1} | s_{t}
time_b = b[...,actions].transpose([b.ndim-1] + list(range(b.ndim-1)))

# joint dist - timestep x s_{t+1} x s_{t}
qs_joint = time_b * jnp.expand_dims(qs_filter, -1)

# cond dist - timestep x s_{t} | s_{t+1}
tranpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2]
qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(tranpose_idx)

def step_fn(qs_smooth_past, t):
qs_joint = qs_backward_cond[t] * qs_smooth_past
qs_smooth = qs_joint.sum(-1)

return qs_smooth.todense(), (qs_smooth.todense(), qs_joint)

# seq_qs will contain a sequence of smoothed marginals and joints
_, seq_qs = lax.scan(
step_fn,
qs_last,
jnp.arange(qs_backward_cond.shape[0]),
(qs_filter, actions),
reverse=True,
unroll=2
)

# we add the last filtered belief to smoothed beliefs

qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0)
return qs_smooth_all, seq_qs[1]
qs_joint_all = seq_qs[1]
if isinstance(qs_joint_all, JAXSparse):
qs_joint_all.shape = (len(actions),) + qs_joint_all.shape
return qs_smooth_all, qs_joint_all


def smoothing_ovf(filtered_post, B, past_actions):
assert len(filtered_post) == len(B)
nf = len(B) # number of factors

joint = lambda b, qs, f: joint_dist_factor_sparse(b, qs, past_actions[..., f]) if isinstance(b, JAXSparse) else joint_dist_factor_dense(b, qs, past_actions[..., f])
joint = lambda b, qs, f: joint_dist_factor(b, qs, past_actions[..., f])

marginals_and_joints = []
for b, qs, f in zip(B, filtered_post, list(range(nf))):
Expand Down

0 comments on commit 397bca4

Please sign in to comment.