From 397bca4bf6627ec2f055fbe01cecccdb4f077e3a Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:55:08 +0200 Subject: [PATCH] changed backward smoothing function to accomodate batching over sparse arrays --- .../inference_methods_comparison.ipynb | 67 ++++++++-------- pymdp/jax/inference.py | 78 ++++++------------- 2 files changed, 59 insertions(+), 86 deletions(-) diff --git a/examples/inference_and_learning/inference_methods_comparison.ipynb b/examples/inference_and_learning/inference_methods_comparison.ipynb index dae9fa93..0d78f77e 100644 --- a/examples/inference_and_learning/inference_methods_comparison.ipynb +++ b/examples/inference_and_learning/inference_methods_comparison.ipynb @@ -32,7 +32,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-06-17 15:10:30.638413: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" + "2024-06-17 17:54:09.793970: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" ] } ], @@ -122,15 +122,7 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "347 µs ± 9.25 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" - ] - } - ], + "outputs": [], "source": [ "prior = agents.D\n", "action_hist = []\n", @@ -144,9 +136,25 @@ " action_hist.append(actions)\n", "\n", "v_jso = jit(vmap(smoothing_ovf))\n", - "smoothed_beliefs = v_jso(beliefs, agents.B, jnp.stack(action_hist, 1))\n", - "\n", - "%timeit v_jso(beliefs, agents.B, jnp.stack(action_hist, 1))[0][0].block_until_ready()" + "actions_seq = jnp.stack(action_hist, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "63.1 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "source": [ + "smoothed_beliefs = v_jso(beliefs, agents.B, actions_seq)\n", + "%timeit v_jso(beliefs, agents.B, actions_seq)[0][0].block_until_ready()" ] }, { @@ -158,27 +166,22 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "166 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "87.7 µs ± 8.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ - "take_first = lambda pytree: jtu.tree_map(lambda leaf: leaf[0], pytree)\n", - "\n", - "beliefs_single = take_first(beliefs)\n", - "sparse_B_single = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b[0]), agents.B)\n", - "actions_single = jnp.stack(action_hist, 1)[0]\n", + "sparse_B = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b, n_batch=1), agents.B)\n", "\n", - "jso = jit(smoothing_ovf)\n", - "smoothed_beliefs_sparse = jso(beliefs_single, sparse_B_single, actions_single)\n", - "%timeit jso(beliefs_single, sparse_B_single, actions_single)[0][0].block_until_ready()" + "smoothed_beliefs_sparse = v_jso(beliefs, sparse_B, actions_seq)\n", + "%timeit v_jso(beliefs, sparse_B, actions_seq)[0][0].block_until_ready()" ] }, { @@ -190,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -199,7 +202,7 @@ "Text(0.5, 1.0, 'Filtered beliefs')" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, @@ -229,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -238,7 +241,7 @@ "Text(0.5, 1.0, 'Filtered beliefs')" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, @@ -257,11 +260,11 @@ "#with sparse matrices\n", "fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)\n", "\n", - "sns.heatmap(beliefs_single[0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", - "sns.heatmap(beliefs_single[1].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(beliefs[0][0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(beliefs[1][0].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", "\n", - "sns.heatmap(smoothed_beliefs_sparse[0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", - "sns.heatmap(smoothed_beliefs_sparse[1][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(smoothed_beliefs_sparse[0][0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(smoothed_beliefs_sparse[1][0][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", "\n", "axes[0, 0].set_title('Filtered beliefs')" ] @@ -275,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { diff --git a/pymdp/jax/inference.py b/pymdp/jax/inference.py index dc2b254c..350d3662 100644 --- a/pymdp/jax/inference.py +++ b/pymdp/jax/inference.py @@ -6,7 +6,10 @@ from .algos import run_factorized_fpi, run_mmp, run_vmp from jax import tree_util as jtu, lax from jax.experimental.sparse._base import JAXSparse -from jaxtyping import Array +from jax.experimental import sparse +from jaxtyping import Array, ArrayLike + +eps = jnp.finfo('float').eps def update_posterior_states( A, @@ -58,68 +61,32 @@ def update_posterior_states( return qs_hist -def joint_dist_factor_dense(b: Array, filtered_qs: list[Array], actions: Array): +def joint_dist_factor(b: ArrayLike, filtered_qs: list[Array], actions: Array): qs_last = filtered_qs[-1] qs_filter = filtered_qs[:-1] - # conditional dist - timestep x s_{t+1} | s_{t} - time_b = jnp.moveaxis(b[..., actions], -1, 0) - # time_b = b[...,actions].transpose([b.ndim-1] + list(range(b.ndim-1))) - - # joint dist - timestep x s_{t+1} x s_{t} - qs_joint = time_b * jnp.expand_dims(qs_filter, -1) - - # cond dist - timestep x s_{t} | s_{t+1} - qs_backward_cond = jnp.moveaxis( - qs_joint / qs_joint.sum(-2, keepdims=True), -2, -1 - ) - # tranpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2] - # qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(tranpose_idx) - - def step_fn(qs_smooth_past, backward_b): - qs_joint = backward_b * qs_smooth_past + def step_fn(qs_smooth, xs): + qs_f, action = xs + time_b = b[..., action] + qs_j = time_b * qs_f + norm = qs_j.sum(-1, keepdims=True) + if isinstance(norm, JAXSparse): + norm = sparse.todense(norm) + norm = jnp.where(norm == 0, eps, norm) + qs_backward_cond = (qs_j / norm).T + qs_joint = qs_backward_cond * qs_smooth qs_smooth = qs_joint.sum(-1) + if isinstance(qs_smooth, JAXSparse): + qs_smooth = sparse.todense(qs_smooth) + # returns q(s_t), (q(s_t), q(s_t, s_t+1)) return qs_smooth, (qs_smooth, qs_joint) # seq_qs will contain a sequence of smoothed marginals and joints _, seq_qs = lax.scan( step_fn, qs_last, - qs_backward_cond, - reverse=True, - unroll=2 - ) - - # we add the last filtered belief to smoothed beliefs - qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0) - return qs_smooth_all, seq_qs[1] - -def joint_dist_factor_sparse(b: JAXSparse, filtered_qs: list[Array], actions: Array): - qs_last = filtered_qs[-1] - qs_filter = filtered_qs[:-1] - - # conditional dist - timestep x s_{t+1} | s_{t} - time_b = b[...,actions].transpose([b.ndim-1] + list(range(b.ndim-1))) - - # joint dist - timestep x s_{t+1} x s_{t} - qs_joint = time_b * jnp.expand_dims(qs_filter, -1) - - # cond dist - timestep x s_{t} | s_{t+1} - tranpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2] - qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(tranpose_idx) - - def step_fn(qs_smooth_past, t): - qs_joint = qs_backward_cond[t] * qs_smooth_past - qs_smooth = qs_joint.sum(-1) - - return qs_smooth.todense(), (qs_smooth.todense(), qs_joint) - - # seq_qs will contain a sequence of smoothed marginals and joints - _, seq_qs = lax.scan( - step_fn, - qs_last, - jnp.arange(qs_backward_cond.shape[0]), + (qs_filter, actions), reverse=True, unroll=2 ) @@ -127,14 +94,17 @@ def step_fn(qs_smooth_past, t): # we add the last filtered belief to smoothed beliefs qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0) - return qs_smooth_all, seq_qs[1] + qs_joint_all = seq_qs[1] + if isinstance(qs_joint_all, JAXSparse): + qs_joint_all.shape = (len(actions),) + qs_joint_all.shape + return qs_smooth_all, qs_joint_all def smoothing_ovf(filtered_post, B, past_actions): assert len(filtered_post) == len(B) nf = len(B) # number of factors - joint = lambda b, qs, f: joint_dist_factor_sparse(b, qs, past_actions[..., f]) if isinstance(b, JAXSparse) else joint_dist_factor_dense(b, qs, past_actions[..., f]) + joint = lambda b, qs, f: joint_dist_factor(b, qs, past_actions[..., f]) marginals_and_joints = [] for b, qs, f in zip(B, filtered_post, list(range(nf))):