diff --git a/examples/inference_and_learning/inference_methods_comparison.ipynb b/examples/inference_and_learning/inference_methods_comparison.ipynb new file mode 100644 index 00000000..dae9fa93 --- /dev/null +++ b/examples/inference_and_learning/inference_methods_comparison.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "from jax import tree_util as jtu, vmap, jit\n", + "from jax.experimental import sparse\n", + "from pymdp.jax.agent import Agent\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "from pymdp.jax.inference import smoothing_ovf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set up generative model and a sequence of observations. The A tensors, B tensors and observations are specified in such a way that only later observations ($o_{t > 1}$) help disambiguate hidden states at earlier time points. This will demonstrate the importance of \"smoothing\" or retrospective inference" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-17 15:10:30.638413: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" + ] + } + ], + "source": [ + "num_states = [3, 2]\n", + "num_obs = [2]\n", + "n_batch = 2\n", + "\n", + "A_1 = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.0, 1.]])\n", + "A_2 = jnp.array([[1.0, 1.0], [1., 0.]])\n", + "\n", + "A_tensor = A_1[..., None] * A_2[:, None]\n", + "\n", + "A_tensor /= A_tensor.sum(0)\n", + "\n", + "A = [jnp.broadcast_to(A_tensor, (n_batch, num_obs[0], 3, 2)) ]\n", + "\n", + "# create two transition matrices, one for each state factor\n", + "B_1 = jnp.broadcast_to(\n", + " jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]), (n_batch, 3, 3)\n", + ")\n", + "\n", + "B_2 = jnp.broadcast_to(\n", + " jnp.array([[0.0, 1.0], [1.0, 0.0]]), (n_batch, 2, 2)\n", + " )\n", + "\n", + "B = [B_1[..., None], B_2[..., None]]\n", + "\n", + "# for the single modality, a sequence over time of observations (one hot vectors)\n", + "obs = [jnp.broadcast_to(jnp.array([[1., 0.], # observation 0 is ambiguous with respect state factors\n", + " [1., 0], # observation 0 is ambiguous with respect state factors\n", + " [1., 0], # observation 0 is ambiguous with respect state factors\n", + " [0., 1.]])[:, None], (4, n_batch, num_obs[0]) )] # observation 1 provides information about exact state of both factors \n", + "C = [jnp.zeros((n_batch, num_obs[0]))] # flat preferences\n", + "D = [jnp.ones((n_batch, 3)) / 3., jnp.ones((n_batch, 2)) / 2.] # flat prior\n", + "E = jnp.ones((n_batch, 1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Construct the `Agent`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "pA = None\n", + "pB = None\n", + "\n", + "agents = Agent(\n", + " A=A,\n", + " B=B,\n", + " C=C,\n", + " D=D,\n", + " E=E,\n", + " pA=pA,\n", + " pB=pB,\n", + " policy_len=3,\n", + " onehot_obs=True,\n", + " action_selection=\"deterministic\",\n", + " sampling_mode=\"full\",\n", + " inference_algo=\"ovf\",\n", + " num_iter=16\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using `obs` and `policies`, pass in the arguments `outcomes`, `past_actions`, `empirical_prior` and `qs_hist` to `agent.infer_states(...)`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run first timestep of inference using `obs[0]`, no past actions, empirical prior set to actual prior, no qs_hist" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "347 µs ± 9.25 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "prior = agents.D\n", + "action_hist = []\n", + "qs_hist=None\n", + "for t in range(len(obs[0])):\n", + " first_obs = jtu.tree_map(lambda x: jnp.moveaxis(x[:t+1], 0, 1), obs)\n", + " beliefs = agents.infer_states(first_obs, prior, qs_hist=qs_hist)\n", + " actions = jnp.broadcast_to(agents.policies[0, 0], (2, 2))\n", + " prior, qs_hist = agents.update_empirical_prior(actions, beliefs)\n", + " if t < len(obs[0]) - 1:\n", + " action_hist.append(actions)\n", + "\n", + "v_jso = jit(vmap(smoothing_ovf))\n", + "smoothed_beliefs = v_jso(beliefs, agents.B, jnp.stack(action_hist, 1))\n", + "\n", + "%timeit v_jso(beliefs, agents.B, jnp.stack(action_hist, 1))[0][0].block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Try the non-vmapped version of `smoothing_ovf`\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "166 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "source": [ + "take_first = lambda pytree: jtu.tree_map(lambda leaf: leaf[0], pytree)\n", + "\n", + "beliefs_single = take_first(beliefs)\n", + "sparse_B_single = jtu.tree_map(lambda b: sparse.BCOO.fromdense(b[0]), agents.B)\n", + "actions_single = jnp.stack(action_hist, 1)[0]\n", + "\n", + "jso = jit(smoothing_ovf)\n", + "smoothed_beliefs_sparse = jso(beliefs_single, sparse_B_single, actions_single)\n", + "%timeit jso(beliefs_single, sparse_B_single, actions_single)[0][0].block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now we can plot that pair of filtering / smoothing distributions for the single batch / single agent, that we ran" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Filtered beliefs')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABQMAAAKqCAYAAACO80jyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqcklEQVR4nO3de5CV9X348c9ZLisibJOAC3glGAPeMN0oXkawohI7NSFjxdhpC16CNkhr0DRhfmnQZEZ0NJISvCAdxdhfmphGaNpJyiiCxvxoUSjGXBTbaMyogBQDdcWj7D6/P/Jzf9nIZZ+9cHb5vF4z5499zjnPfh7m2eXLm3OeUymKoggAAAAA4IBXV+sBAAAAAID9QwwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwE9ujFF1+MSqUSS5cubdt2ww03RKVSqd1QJXV03rPPPjtOOOGEbv3eRx99dMyYMaPt69WrV0elUonVq1d3an9PPvlknHHGGTF48OCoVCqxYcOGbpkTAACAPMRASGzp0qVRqVR2e/vCF77Q4f3cdNNNsXz58p4blHjnnXfi4osvjm3btsWCBQvigQceiKOOOqrWYwEAANDH9K/1AEDtffnLX47Ro0e323bCCSfEUUcdFTt37owBAwbs9fk33XRT/PEf/3FMnTq1B6fs+yZOnBg7d+6MgQMHln7uf/3Xf8Uvf/nLWLJkSVx55ZU9MB0AAAAZiIFAXHDBBfHRj350t/cddNBB+3ma33jrrbdi4MCBUVd34LyAua6urtN/nlu2bImIiN/7vd/rxokAAADI5sD5VzbQ7XZ3zcDfValUorm5Oe6///62txj/9nXyXn755bj88sujsbEx6uvr4/jjj49777233T7evZbet771rfjiF78Yhx12WBx88MGxY8eOiIj493//9/jYxz4WDQ0NcfDBB8ekSZPiRz/60XtmeeKJJ+KUU06Jgw46KMaMGROLFy8ufczr1q2LM844IwYNGhSjR4+Ou++++z2PqVarMW/evDjmmGOivr4+jjjiiPjrv/7rqFare933nq4ZuK/jmzFjRkyaNCkiIi6++OKoVCpx9tlnR0TEpk2b4rLLLovDDz886uvrY+TIkfGJT3wiXnzxxdLHDgAAwIHPKwOB2L59e2zdurXdtmHDhnXouQ888EBceeWVceqpp8bMmTMjImLMmDEREbF58+Y47bTTolKpxDXXXBPDhw+PH/zgB3HFFVfEjh074tprr223r6985SsxcODAuP7666NarcbAgQPj0UcfjQsuuCCamppi3rx5UVdXF/fdd1+cc8458cMf/jBOPfXUiIh45pln4vzzz4/hw4fHDTfcELt27Yp58+ZFY2Njh/8cXn/99fjDP/zDmDZtWlx66aXx4IMPxl/8xV/EwIED4/LLL4+IiNbW1vj4xz8eTzzxRMycOTPGjRsXzzzzTCxYsCA2btxY+tqJHTm+q666Kg477LC46aab4i//8i/jlFNOaTuuiy66KH7605/G7Nmz4+ijj44tW7bEww8/HC+99FIcffTRpWYBAAAggQJI67777isiYre3oiiKF154oYiI4r777mt7zrx584rf/dUxePDgYvr06e/Z/xVXXFGMHDmy2Lp1a7vtn/rUp4qGhobizTffLIqiKFatWlVERPHBD36wbVtRFEVra2vxoQ99qJgyZUrR2tratv3NN98sRo8eXZx33nlt26ZOnVocdNBBxS9/+cu2bT/72c+Kfv36vWfe3Zk0aVIREcVXv/rVtm3VarU4+eSTi0MPPbR4++23i6IoigceeKCoq6srfvjDH7Z7/t13311ERPGjH/2obdtRRx3V7s/l3eNctWpV6eN797nf+c532ra9/vrrRUQUt9566z6PDwAAAIqiKLxNGIg77rgjHn744Xa3riqKIr773e/GhRdeGEVRxNatW9tuU6ZMie3bt8f69evbPWf69OkxaNCgtq83bNgQzz//fPzJn/xJ/Pd//3fb85ubm2Py5Mnx+OOPR2tra7S0tMSKFSti6tSpceSRR7Y9f9y4cTFlypQOz9y/f/+46qqr2r4eOHBgXHXVVbFly5ZYt25dRER85zvfiXHjxsXYsWPbHdM555wTERGrVq3q8Pfr6PHtyaBBg2LgwIGxevXqeP311zv8fQEAAMjL24SBOPXUU/f4ASKd9dprr8Wvf/3ruOeee+Kee+7Z7WPe/VCMd/3uJxo///zzEfGbSLgn27dvj2q1Gjt37owPfehD77n/wx/+cHz/+9/v0MyjRo2KwYMHt9t27LHHRsRvrp942mmnxfPPPx8///nPY/jw4bvdx+8e09509Pje97737fa++vr6uOWWW+K6666LxsbGOO200+KP/uiP4s///M9jxIgRHZ4DAACAPMRAoEe8+4q2P/3TP91j7DrppJPaff3brwr87X3ceuutcfLJJ+92H4cccsg+P7ijO7W2tsaJJ54Yt99++27vP+KII0rtK2Lfx7c31157bVx44YWxfPnyWLFiRfzN3/xNzJ8/Px599NH4yEc+0uFZAAAAyEEMBLqsUqm8Z9vw4cNjyJAh0dLSEueee26n9vvuB5EMHTp0r/sYPnx4DBo0qO2Vdr/tueee6/D3e+WVV6K5ubndqwM3btwYEdH2YRxjxoyJp59+OiZPnrzb4y6jo8fXkf1cd911cd1118Xzzz8fJ598cnz1q1+Nv//7v+/SfAAAABx4XDMQ6LLBgwfHr3/963bb+vXrFxdddFF897vfjZ/85Cfvec5rr722z/02NTXFmDFj4rbbbos33nhjj/vo169fTJkyJZYvXx4vvfRS2/0///nPY8WKFR0+jl27dsXixYvbvn777bdj8eLFMXz48GhqaoqIiGnTpsXLL78cS5Ysec/zd+7cGc3NzR3+fh09vj15880346233mq3bcyYMTFkyJD9+mpJAAAA+g6vDAS6rKmpKR555JG4/fbbY9SoUTF69OiYMGFC3HzzzbFq1aqYMGFCfPrTn47jjjsutm3bFuvXr49HHnkktm3bttf91tXVxd/93d/FBRdcEMcff3xcdtllcdhhh8XLL78cq1atiqFDh8Y///M/R0TEjTfeGP/6r/8aZ511VnzmM5+JXbt2xde//vU4/vjj48c//nGHjmPUqFFxyy23xIsvvhjHHntsfPvb344NGzbEPffcEwMGDIiIiD/7sz+LBx98MK6++upYtWpVnHnmmdHS0hLPPvtsPPjgg7FixYoOX3+xzPHtzsaNG2Py5Mkxbdq0OO6446J///6xbNmy2Lx5c3zqU5/q0AwAAADkIgYCXXb77bfHzJkz44tf/GLs3Lkzpk+fHhMmTIjGxsZYu3ZtfPnLX46HHnoo7rzzzvjABz4Qxx9/fNxyyy0d2vfZZ58da9asia985SuxaNGieOONN2LEiBExYcKEdp/8e9JJJ8WKFStizpw58aUvfSkOP/zwuPHGG+PVV1/tcAx83/veF/fff3/Mnj07lixZEo2NjbFo0aL49Kc/3faYurq6WL58eSxYsCC+8Y1vxLJly+Lggw+OD37wg/FXf/VXbR840lEdPb7dOeKII+LSSy+NlStXxgMPPBD9+/ePsWPHxoMPPhgXXXRRqTkAAADIoVIURVHrIQAAAACAnueagQAAAACQhBgIAAAAAEmIgQAAAACQhBgIAAAAAEmIgQAAAACQhBgIAAAAAEmIgQAAAACQRP9aD/Cukz67oNYj0MfsOrjWE9DXHLSt1hPQl3xg8f+p9Qj0MQ+3fqfWI9BF59VdXOsRgAPcileervUI9DFTRo2v9Qj0MR1Zk3plIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAk0b/sE7Zu3Rr33ntvrFmzJjZt2hQRESNGjIgzzjgjZsyYEcOHD+/2IQEA4F3WowAAnVfqlYFPPvlkHHvssbFw4cJoaGiIiRMnxsSJE6OhoSEWLlwYY8eOjaeeeqqnZgUAIDnrUQCArin1ysDZs2fHxRdfHHfffXdUKpV29xVFEVdffXXMnj071qxZs9f9VKvVqFar7ba17toVdf1Lv1ARAIBEenQ9WrREXaVft88MANCblHpl4NNPPx2f/exn37PwioioVCrx2c9+NjZs2LDP/cyfPz8aGhra3V578pEyowAAkFBPrkdfiGd7YGIAgN6lVAwcMWJErF27do/3r127NhobG/e5n7lz58b27dvb3Yafcm6ZUQAASKgn16OjY2x3jgoA0CuVel/u9ddfHzNnzox169bF5MmT2xZamzdvjpUrV8aSJUvitttu2+d+6uvro76+vt02bxEGAGBfenQ96i3CAEACpQrcrFmzYtiwYbFgwYK48847o6WlJSIi+vXrF01NTbF06dKYNm1ajwwKAADWowAAXVP65XiXXHJJXHLJJfHOO+/E1q1bIyJi2LBhMWDAgG4fDgAAfpf1KABA53X6vbkDBgyIkSNHducsAADQYdajAADllfoAEQAAAACg7xIDAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkuhf6wGgsw55uaj1CPQxu+ortR4BADiArHjl6VqPQB8zZdT4Wo8A4JWBAAAAAJCFGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASXR7DPzVr34Vl19++V4fU61WY8eOHe1urbt2dfcoAAAk1On1aNGynyYEAKidbo+B27Zti/vvv3+vj5k/f340NDS0u7325CPdPQoAAAl1dj36Qjy7nyYEAKid/mWf8L3vfW+v9//iF7/Y5z7mzp0bc+bMabftjP+1uOwoAAAk1FPr0U82zOjKWAAAfULpGDh16tSoVCpRFMUeH1OpVPa6j/r6+qivr2+3ra5/6VEAAEiox9ajlX7dMh8AQG9W+m3CI0eOjIceeihaW1t3e1u/fn1PzAkAABFhPQoA0BWlY2BTU1OsW7duj/fv639pAQCgK6xHAQA6r/R7cz/3uc9Fc3PzHu8/5phjYtWqVV0aCgAA9sR6FACg80rHwLPOOmuv9w8ePDgmTZrU6YEAAGBvrEcBADqv9NuEAQAAAIC+SQwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCT613qAdx3ySmutR6CPaamv1HoE+piiX60noC9Z8crTtR4B2M/83FPWlFHjaz0CcIDzdxM9wSsDAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkigdA3fu3BlPPPFE/OxnP3vPfW+99VZ84xvf6JbBAABgd6xHAQA6r1QM3LhxY4wbNy4mTpwYJ554YkyaNCleffXVtvu3b98el1122T73U61WY8eOHe1urS27yk8PAEAqPbkerVZbe3J0AIBeoVQM/PznPx8nnHBCbNmyJZ577rkYMmRInHnmmfHSSy+V+qbz58+PhoaGdreXf76y1D4AAMinJ9ejN3/99R6aGgCg96gURVF09MGNjY3xyCOPxIknnhgREUVRxGc+85n4/ve/H6tWrYrBgwfHqFGjoqWlZa/7qVarUa1W2207//K7oq5f/04cAlm11FdqPQJ9zNuHOGfouHVfuqvWI9DH1I3YWOsRUujJ9eiA138/6utdUpuOmzJqfK1HAA5wK155utYj0Md0ZE1aarWzc+fO6N///we7SqUSd911V1x44YUxadKk2LixY4vg+vr6GDp0aLubEAgAwL705HpUCAQAMihV4MaOHRtPPfVUjBs3rt32RYsWRUTExz/+8e6bDAAAfof1KABA15T6789PfvKT8Q//8A+7vW/RokVx6aWXRol3HQMAQCnWowAAXVPqmoE96YxLvlrrEehjXDOQslwzkDJcM5CyXDOw72vddGytR6CPcc1AoKe5ZiBldfs1AwEAAACAvksMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkKkVRFLUegt2rVqsxf/78mDt3btTX19d6HPoA5wxlOWcoyzkD+fi5pwznC2U5ZyjLOdN1YmAvtmPHjmhoaIjt27fH0KFDaz0OfYBzhrKcM5TlnIF8/NxThvOFspwzlOWc6TpvEwYAAACAJMRAAAAAAEhCDAQAAACAJMTAXqy+vj7mzZvngph0mHOGspwzlOWcgXz83FOG84WynDOU5ZzpOh8gAgAAAABJeGUgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACTRv9YDvOvYmxbUegT6mF0HF7UegT6m/xuVWo9AH/LslXfVegT6mLoRG2s9Al10Xt3FtR6BPmbFK0/XegT6mCmjxtd6BPoYv2coqyNrUq8MBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAk+pd9wtatW+Pee++NNWvWxKZNmyIiYsSIEXHGGWfEjBkzYvjw4d0+JAAAvMt6FACg80q9MvDJJ5+MY489NhYuXBgNDQ0xceLEmDhxYjQ0NMTChQtj7Nix8dRTT+1zP9VqNXbs2NHu1rprV6cPAgCAHHp0PVq07IcjAACorUpRFEVHH3zaaafF+PHj4+67745KpdLuvqIo4uqrr44f//jHsWbNmr3u54Ybbogbb7yx3bb3n3N+fGDyx0qMTna7Du7wqQsREdH/jcq+HwT/z7NX3lXrEehj6kZsrPUIKfTkenR0jIsxleO7fWYOXCteebrWI9DHTBk1vtYj0Mf4PUNZHVmTloqBgwYNiv/4j/+IsWPH7vb+Z599Nj7ykY/Ezp0797qfarUa1Wq13bbf/9riqOtf+l3LJCYGUpYYSBliIGWJgftHT65HP9kwI+oq/bptVg58/pFOWWIgZfk9Q1kdWZOWqm8jRoyItWvX7nHxtXbt2mhsbNznfurr66O+vr7dNiEQAIB96dH1qBAIACRQqsBdf/31MXPmzFi3bl1Mnjy5baG1efPmWLlyZSxZsiRuu+22HhkUAACsRwEAuqZUDJw1a1YMGzYsFixYEHfeeWe0tPzmIsv9+vWLpqamWLp0aUybNq1HBgUAAOtRAICuKf3e3EsuuSQuueSSeOedd2Lr1q0RETFs2LAYMGBAtw8HAAC/y3oUAKDzOn2hvgEDBsTIkSO7cxYAAOgw61EAgPLqaj0AAAAAALB/iIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJ9K/1AO+qHrqr1iPQx9Tt1LIppzrK7xk6bsqo8bUegT7m4dZaT0BXrXjl6VqPQB/j7wrK8nuGsvyeoayOrEnVFAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCS6PQb+6le/issvv3yvj6lWq7Fjx452t+KdXd09CgAACXV2PVqttu6nCQEAaqfbY+C2bdvi/vvv3+tj5s+fHw0NDe1u23/waHePAgBAQp1dj9789df304QAALXTv+wTvve97+31/l/84hf73MfcuXNjzpw57bad8L/vKDsKAAAJ9dR6dMDrv9+luQAA+oLSMXDq1KlRqVSiKIo9PqZSqex1H/X19VFfX9/+OQNKjwIAQEI9tR5tfdPltAGAA1/pFc/IkSPjoYceitbW1t3e1q9f3xNzAgBARFiPAgB0RekY2NTUFOvWrdvj/fv6X1oAAOgK61EAgM4r/d7cz33uc9Hc3LzH+4855phYtWpVl4YCAIA9sR4FAOi80jHwrLPO2uv9gwcPjkmTJnV6IAAA2BvrUQCAznOVZAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIolIURVHrIdi9arUa8+fPj7lz50Z9fX2tx6EPcM5QlnOGspwzkI+fe8pwvlCWc4aynDNdJwb2Yjt27IiGhobYvn17DB06tNbj0Ac4ZyjLOUNZzhnIx889ZThfKMs5Q1nOma7zNmEAAAAASEIMBAAAAIAkxEAAAAAASEIM7MXq6+tj3rx5LohJhzlnKMs5Q1nOGcjHzz1lOF8oyzlDWc6ZrvMBIgAAAACQhFcGAgAAAEASYiAAAAAAJCEGAgAAAEASYiAAAAAAJCEG9mJ33HFHHH300XHQQQfFhAkTYu3atbUeiV7q8ccfjwsvvDBGjRoVlUolli9fXuuR6OXmz58fp5xySgwZMiQOPfTQmDp1ajz33HO1Hote7K677oqTTjophg4dGkOHDo3TTz89fvCDH9R6LKCHWY9ShjUpZViPUpb1aPcRA3upb3/72zFnzpyYN29erF+/PsaPHx9TpkyJLVu21Ho0eqHm5uYYP3583HHHHbUehT7isccei1mzZsW//du/xcMPPxzvvPNOnH/++dHc3Fzr0eilDj/88Lj55ptj3bp18dRTT8U555wTn/jEJ+KnP/1prUcDeoj1KGVZk1KG9ShlWY92n0pRFEWth+C9JkyYEKecckosWrQoIiJaW1vjiCOOiNmzZ8cXvvCFGk9Hb1apVGLZsmUxderUWo9CH/Laa6/FoYceGo899lhMnDix1uPQR7z//e+PW2+9Na644opajwL0AOtRusKalLKsR+kM69HO8crAXujtt9+OdevWxbnnntu2ra6uLs4999xYs2ZNDScDDlTbt2+PiN/8ZQr70tLSEt/61reiubk5Tj/99FqPA/QA61Fgf7MepQzr0a7pX+sBeK+tW7dGS0tLNDY2ttve2NgYzz77bI2mAg5Ura2tce2118aZZ54ZJ5xwQq3HoRd75pln4vTTT4+33norDjnkkFi2bFkcd9xxtR4L6AHWo8D+ZD1KR1mPdg8xECC5WbNmxU9+8pN44oknaj0KvdyHP/zh2LBhQ2zfvj3+8R//MaZPnx6PPfaYBRgA0CXWo3SU9Wj3EAN7oWHDhkW/fv1i8+bN7bZv3rw5RowYUaOpgAPRNddcE//yL/8Sjz/+eBx++OG1HodebuDAgXHMMcdERERTU1M8+eST8bd/+7exePHiGk8GdDfrUWB/sR6lDOvR7uGagb3QwIEDo6mpKVauXNm2rbW1NVauXOm98EC3KIoirrnmmli2bFk8+uijMXr06FqPRB/U2toa1Wq11mMAPcB6FOhp1qN0B+vRzvHKwF5qzpw5MX369PjoRz8ap556anzta1+L5ubmuOyyy2o9Gr3QG2+8Ef/5n//Z9vULL7wQGzZsiPe///1x5JFH1nAyeqtZs2bFN7/5zfinf/qnGDJkSGzatCkiIhoaGmLQoEE1no7eaO7cuXHBBRfEkUceGf/zP/8T3/zmN2P16tWxYsWKWo8G9BDrUcqyJqUM61HKsh7tPpWiKIpaD8HuLVq0KG699dbYtGlTnHzyybFw4cKYMGFCrceiF1q9enX8wR/8wXu2T58+PZYuXbr/B6LXq1Qqu91+3333xYwZM/bvMPQJV1xxRaxcuTJeffXVaGhoiJNOOik+//nPx3nnnVfr0YAeZD1KGdaklGE9SlnWo91HDAQAAACAJFwzEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIIn/C+frtYfPwrLqAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# with dense matrices\n", + "fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)\n", + "\n", + "sns.heatmap(beliefs[0][0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(beliefs[1][0].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "\n", + "sns.heatmap(smoothed_beliefs[0][0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(smoothed_beliefs[1][0][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "\n", + "axes[0, 0].set_title('Filtered beliefs')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Filtered beliefs')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABQMAAAKqCAYAAACO80jyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqcklEQVR4nO3de5CV9X348c9ZLisibJOAC3glGAPeMN0oXkawohI7NSFjxdhpC16CNkhr0DRhfmnQZEZ0NJISvCAdxdhfmphGaNpJyiiCxvxoUSjGXBTbaMyogBQDdcWj7D6/P/Jzf9nIZZ+9cHb5vF4z5499zjnPfh7m2eXLm3OeUymKoggAAAAA4IBXV+sBAAAAAID9QwwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwE9ujFF1+MSqUSS5cubdt2ww03RKVSqd1QJXV03rPPPjtOOOGEbv3eRx99dMyYMaPt69WrV0elUonVq1d3an9PPvlknHHGGTF48OCoVCqxYcOGbpkTAACAPMRASGzp0qVRqVR2e/vCF77Q4f3cdNNNsXz58p4blHjnnXfi4osvjm3btsWCBQvigQceiKOOOqrWYwEAANDH9K/1AEDtffnLX47Ro0e323bCCSfEUUcdFTt37owBAwbs9fk33XRT/PEf/3FMnTq1B6fs+yZOnBg7d+6MgQMHln7uf/3Xf8Uvf/nLWLJkSVx55ZU9MB0AAAAZiIFAXHDBBfHRj350t/cddNBB+3ma33jrrbdi4MCBUVd34LyAua6urtN/nlu2bImIiN/7vd/rxokAAADI5sD5VzbQ7XZ3zcDfValUorm5Oe6///62txj/9nXyXn755bj88sujsbEx6uvr4/jjj49777233T7evZbet771rfjiF78Yhx12WBx88MGxY8eOiIj493//9/jYxz4WDQ0NcfDBB8ekSZPiRz/60XtmeeKJJ+KUU06Jgw46KMaMGROLFy8ufczr1q2LM844IwYNGhSjR4+Ou++++z2PqVarMW/evDjmmGOivr4+jjjiiPjrv/7rqFare933nq4ZuK/jmzFjRkyaNCkiIi6++OKoVCpx9tlnR0TEpk2b4rLLLovDDz886uvrY+TIkfGJT3wiXnzxxdLHDgAAwIHPKwOB2L59e2zdurXdtmHDhnXouQ888EBceeWVceqpp8bMmTMjImLMmDEREbF58+Y47bTTolKpxDXXXBPDhw+PH/zgB3HFFVfEjh074tprr223r6985SsxcODAuP7666NarcbAgQPj0UcfjQsuuCCamppi3rx5UVdXF/fdd1+cc8458cMf/jBOPfXUiIh45pln4vzzz4/hw4fHDTfcELt27Yp58+ZFY2Njh/8cXn/99fjDP/zDmDZtWlx66aXx4IMPxl/8xV/EwIED4/LLL4+IiNbW1vj4xz8eTzzxRMycOTPGjRsXzzzzTCxYsCA2btxY+tqJHTm+q666Kg477LC46aab4i//8i/jlFNOaTuuiy66KH7605/G7Nmz4+ijj44tW7bEww8/HC+99FIcffTRpWYBAAAggQJI67777isiYre3oiiKF154oYiI4r777mt7zrx584rf/dUxePDgYvr06e/Z/xVXXFGMHDmy2Lp1a7vtn/rUp4qGhobizTffLIqiKFatWlVERPHBD36wbVtRFEVra2vxoQ99qJgyZUrR2tratv3NN98sRo8eXZx33nlt26ZOnVocdNBBxS9/+cu2bT/72c+Kfv36vWfe3Zk0aVIREcVXv/rVtm3VarU4+eSTi0MPPbR4++23i6IoigceeKCoq6srfvjDH7Z7/t13311ERPGjH/2obdtRRx3V7s/l3eNctWpV6eN797nf+c532ra9/vrrRUQUt9566z6PDwAAAIqiKLxNGIg77rgjHn744Xa3riqKIr773e/GhRdeGEVRxNatW9tuU6ZMie3bt8f69evbPWf69OkxaNCgtq83bNgQzz//fPzJn/xJ/Pd//3fb85ubm2Py5Mnx+OOPR2tra7S0tMSKFSti6tSpceSRR7Y9f9y4cTFlypQOz9y/f/+46qqr2r4eOHBgXHXVVbFly5ZYt25dRER85zvfiXHjxsXYsWPbHdM555wTERGrVq3q8Pfr6PHtyaBBg2LgwIGxevXqeP311zv8fQEAAMjL24SBOPXUU/f4ASKd9dprr8Wvf/3ruOeee+Kee+7Z7WPe/VCMd/3uJxo///zzEfGbSLgn27dvj2q1Gjt37owPfehD77n/wx/+cHz/+9/v0MyjRo2KwYMHt9t27LHHRsRvrp942mmnxfPPPx8///nPY/jw4bvdx+8e09509Pje97737fa++vr6uOWWW+K6666LxsbGOO200+KP/uiP4s///M9jxIgRHZ4DAACAPMRAoEe8+4q2P/3TP91j7DrppJPaff3brwr87X3ceuutcfLJJ+92H4cccsg+P7ijO7W2tsaJJ54Yt99++27vP+KII0rtK2Lfx7c31157bVx44YWxfPnyWLFiRfzN3/xNzJ8/Px599NH4yEc+0uFZAAAAyEEMBLqsUqm8Z9vw4cNjyJAh0dLSEueee26n9vvuB5EMHTp0r/sYPnx4DBo0qO2Vdr/tueee6/D3e+WVV6K5ubndqwM3btwYEdH2YRxjxoyJp59+OiZPnrzb4y6jo8fXkf1cd911cd1118Xzzz8fJ598cnz1q1+Nv//7v+/SfAAAABx4XDMQ6LLBgwfHr3/963bb+vXrFxdddFF897vfjZ/85Cfvec5rr722z/02NTXFmDFj4rbbbos33nhjj/vo169fTJkyJZYvXx4vvfRS2/0///nPY8WKFR0+jl27dsXixYvbvn777bdj8eLFMXz48GhqaoqIiGnTpsXLL78cS5Ysec/zd+7cGc3NzR3+fh09vj15880346233mq3bcyYMTFkyJD9+mpJAAAA+g6vDAS6rKmpKR555JG4/fbbY9SoUTF69OiYMGFC3HzzzbFq1aqYMGFCfPrTn47jjjsutm3bFuvXr49HHnkktm3bttf91tXVxd/93d/FBRdcEMcff3xcdtllcdhhh8XLL78cq1atiqFDh8Y///M/R0TEjTfeGP/6r/8aZ511VnzmM5+JXbt2xde//vU4/vjj48c//nGHjmPUqFFxyy23xIsvvhjHHntsfPvb344NGzbEPffcEwMGDIiIiD/7sz+LBx98MK6++upYtWpVnHnmmdHS0hLPPvtsPPjgg7FixYoOX3+xzPHtzsaNG2Py5Mkxbdq0OO6446J///6xbNmy2Lx5c3zqU5/q0AwAAADkIgYCXXb77bfHzJkz44tf/GLs3Lkzpk+fHhMmTIjGxsZYu3ZtfPnLX46HHnoo7rzzzvjABz4Qxx9/fNxyyy0d2vfZZ58da9asia985SuxaNGieOONN2LEiBExYcKEdp/8e9JJJ8WKFStizpw58aUvfSkOP/zwuPHGG+PVV1/tcAx83/veF/fff3/Mnj07lixZEo2NjbFo0aL49Kc/3faYurq6WL58eSxYsCC+8Y1vxLJly+Lggw+OD37wg/FXf/VXbR840lEdPb7dOeKII+LSSy+NlStXxgMPPBD9+/ePsWPHxoMPPhgXXXRRqTkAAADIoVIURVHrIQAAAACAnueagQAAAACQhBgIAAAAAEmIgQAAAACQhBgIAAAAAEmIgQAAAACQhBgIAAAAAEmIgQAAAACQRP9aD/Cukz67oNYj0MfsOrjWE9DXHLSt1hPQl3xg8f+p9Qj0MQ+3fqfWI9BF59VdXOsRgAPcileervUI9DFTRo2v9Qj0MR1Zk3plIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAkIQYCAAAAQBJiIAAAAAAk0b/sE7Zu3Rr33ntvrFmzJjZt2hQRESNGjIgzzjgjZsyYEcOHD+/2IQEA4F3WowAAnVfqlYFPPvlkHHvssbFw4cJoaGiIiRMnxsSJE6OhoSEWLlwYY8eOjaeeeqqnZgUAIDnrUQCArin1ysDZs2fHxRdfHHfffXdUKpV29xVFEVdffXXMnj071qxZs9f9VKvVqFar7ba17toVdf1Lv1ARAIBEenQ9WrREXaVft88MANCblHpl4NNPPx2f/exn37PwioioVCrx2c9+NjZs2LDP/cyfPz8aGhra3V578pEyowAAkFBPrkdfiGd7YGIAgN6lVAwcMWJErF27do/3r127NhobG/e5n7lz58b27dvb3Yafcm6ZUQAASKgn16OjY2x3jgoA0CuVel/u9ddfHzNnzox169bF5MmT2xZamzdvjpUrV8aSJUvitttu2+d+6uvro76+vt02bxEGAGBfenQ96i3CAEACpQrcrFmzYtiwYbFgwYK48847o6WlJSIi+vXrF01NTbF06dKYNm1ajwwKAADWowAAXVP65XiXXHJJXHLJJfHOO+/E1q1bIyJi2LBhMWDAgG4fDgAAfpf1KABA53X6vbkDBgyIkSNHducsAADQYdajAADllfoAEQAAAACg7xIDAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkuhf6wGgsw55uaj1CPQxu+ortR4BADiArHjl6VqPQB8zZdT4Wo8A4JWBAAAAAJCFGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASYiBAAAAAJCEGAgAAAAASXR7DPzVr34Vl19++V4fU61WY8eOHe1urbt2dfcoAAAk1On1aNGynyYEAKidbo+B27Zti/vvv3+vj5k/f340NDS0u7325CPdPQoAAAl1dj36Qjy7nyYEAKid/mWf8L3vfW+v9//iF7/Y5z7mzp0bc+bMabftjP+1uOwoAAAk1FPr0U82zOjKWAAAfULpGDh16tSoVCpRFMUeH1OpVPa6j/r6+qivr2+3ra5/6VEAAEiox9ajlX7dMh8AQG9W+m3CI0eOjIceeihaW1t3e1u/fn1PzAkAABFhPQoA0BWlY2BTU1OsW7duj/fv639pAQCgK6xHAQA6r/R7cz/3uc9Fc3PzHu8/5phjYtWqVV0aCgAA9sR6FACg80rHwLPOOmuv9w8ePDgmTZrU6YEAAGBvrEcBADqv9NuEAQAAAIC+SQwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCT613qAdx3ySmutR6CPaamv1HoE+piiX60noC9Z8crTtR4B2M/83FPWlFHjaz0CcIDzdxM9wSsDAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkhADAQAAACAJMRAAAAAAkigdA3fu3BlPPPFE/OxnP3vPfW+99VZ84xvf6JbBAABgd6xHAQA6r1QM3LhxY4wbNy4mTpwYJ554YkyaNCleffXVtvu3b98el1122T73U61WY8eOHe1urS27yk8PAEAqPbkerVZbe3J0AIBeoVQM/PznPx8nnHBCbNmyJZ577rkYMmRInHnmmfHSSy+V+qbz58+PhoaGdreXf76y1D4AAMinJ9ejN3/99R6aGgCg96gURVF09MGNjY3xyCOPxIknnhgREUVRxGc+85n4/ve/H6tWrYrBgwfHqFGjoqWlZa/7qVarUa1W2207//K7oq5f/04cAlm11FdqPQJ9zNuHOGfouHVfuqvWI9DH1I3YWOsRUujJ9eiA138/6utdUpuOmzJqfK1HAA5wK155utYj0Md0ZE1aarWzc+fO6N///we7SqUSd911V1x44YUxadKk2LixY4vg+vr6GDp0aLubEAgAwL705HpUCAQAMihV4MaOHRtPPfVUjBs3rt32RYsWRUTExz/+8e6bDAAAfof1KABA15T6789PfvKT8Q//8A+7vW/RokVx6aWXRol3HQMAQCnWowAAXVPqmoE96YxLvlrrEehjXDOQslwzkDJcM5CyXDOw72vddGytR6CPcc1AoKe5ZiBldfs1AwEAAACAvksMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkKkVRFLUegt2rVqsxf/78mDt3btTX19d6HPoA5wxlOWcoyzkD+fi5pwznC2U5ZyjLOdN1YmAvtmPHjmhoaIjt27fH0KFDaz0OfYBzhrKcM5TlnIF8/NxThvOFspwzlOWc6TpvEwYAAACAJMRAAAAAAEhCDAQAAACAJMTAXqy+vj7mzZvngph0mHOGspwzlOWcgXz83FOG84WynDOU5ZzpOh8gAgAAAABJeGUgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACQhBgIAAABAEmIgAAAAACTRv9YDvOvYmxbUegT6mF0HF7UegT6m/xuVWo9AH/LslXfVegT6mLoRG2s9Al10Xt3FtR6BPmbFK0/XegT6mCmjxtd6BPoYv2coqyNrUq8MBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAkxEAAAAAASEIMBAAAAIAk+pd9wtatW+Pee++NNWvWxKZNmyIiYsSIEXHGGWfEjBkzYvjw4d0+JAAAvMt6FACg80q9MvDJJ5+MY489NhYuXBgNDQ0xceLEmDhxYjQ0NMTChQtj7Nix8dRTT+1zP9VqNXbs2NHu1rprV6cPAgCAHHp0PVq07IcjAACorUpRFEVHH3zaaafF+PHj4+67745KpdLuvqIo4uqrr44f//jHsWbNmr3u54Ybbogbb7yx3bb3n3N+fGDyx0qMTna7Du7wqQsREdH/jcq+HwT/z7NX3lXrEehj6kZsrPUIKfTkenR0jIsxleO7fWYOXCteebrWI9DHTBk1vtYj0Mf4PUNZHVmTloqBgwYNiv/4j/+IsWPH7vb+Z599Nj7ykY/Ezp0797qfarUa1Wq13bbf/9riqOtf+l3LJCYGUpYYSBliIGWJgftHT65HP9kwI+oq/bptVg58/pFOWWIgZfk9Q1kdWZOWqm8jRoyItWvX7nHxtXbt2mhsbNznfurr66O+vr7dNiEQAIB96dH1qBAIACRQqsBdf/31MXPmzFi3bl1Mnjy5baG1efPmWLlyZSxZsiRuu+22HhkUAACsRwEAuqZUDJw1a1YMGzYsFixYEHfeeWe0tPzmIsv9+vWLpqamWLp0aUybNq1HBgUAAOtRAICuKf3e3EsuuSQuueSSeOedd2Lr1q0RETFs2LAYMGBAtw8HAAC/y3oUAKDzOn2hvgEDBsTIkSO7cxYAAOgw61EAgPLqaj0AAAAAALB/iIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJiIEAAAAAkIQYCAAAAABJ9K/1AO+qHrqr1iPQx9Tt1LIppzrK7xk6bsqo8bUegT7m4dZaT0BXrXjl6VqPQB/j7wrK8nuGsvyeoayOrEnVFAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCS6PQb+6le/issvv3yvj6lWq7Fjx452t+KdXd09CgAACXV2PVqttu6nCQEAaqfbY+C2bdvi/vvv3+tj5s+fHw0NDe1u23/waHePAgBAQp1dj9789df304QAALXTv+wTvve97+31/l/84hf73MfcuXNjzpw57bad8L/vKDsKAAAJ9dR6dMDrv9+luQAA+oLSMXDq1KlRqVSiKIo9PqZSqex1H/X19VFfX9/+OQNKjwIAQEI9tR5tfdPltAGAA1/pFc/IkSPjoYceitbW1t3e1q9f3xNzAgBARFiPAgB0RekY2NTUFOvWrdvj/fv6X1oAAOgK61EAgM4r/d7cz33uc9Hc3LzH+4855phYtWpVl4YCAIA9sR4FAOi80jHwrLPO2uv9gwcPjkmTJnV6IAAA2BvrUQCAznOVZAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIQgwEAAAAgCTEQAAAAABIolIURVHrIdi9arUa8+fPj7lz50Z9fX2tx6EPcM5QlnOGspwzkI+fe8pwvlCWc4aynDNdJwb2Yjt27IiGhobYvn17DB06tNbj0Ac4ZyjLOUNZzhnIx889ZThfKMs5Q1nOma7zNmEAAAAASEIMBAAAAIAkxEAAAAAASEIM7MXq6+tj3rx5LohJhzlnKMs5Q1nOGcjHzz1lOF8oyzlDWc6ZrvMBIgAAAACQhFcGAgAAAEASYiAAAAAAJCEGAgAAAEASYiAAAAAAJCEG9mJ33HFHHH300XHQQQfFhAkTYu3atbUeiV7q8ccfjwsvvDBGjRoVlUolli9fXuuR6OXmz58fp5xySgwZMiQOPfTQmDp1ajz33HO1Hote7K677oqTTjophg4dGkOHDo3TTz89fvCDH9R6LKCHWY9ShjUpZViPUpb1aPcRA3upb3/72zFnzpyYN29erF+/PsaPHx9TpkyJLVu21Ho0eqHm5uYYP3583HHHHbUehT7isccei1mzZsW//du/xcMPPxzvvPNOnH/++dHc3Fzr0eilDj/88Lj55ptj3bp18dRTT8U555wTn/jEJ+KnP/1prUcDeoj1KGVZk1KG9ShlWY92n0pRFEWth+C9JkyYEKecckosWrQoIiJaW1vjiCOOiNmzZ8cXvvCFGk9Hb1apVGLZsmUxderUWo9CH/Laa6/FoYceGo899lhMnDix1uPQR7z//e+PW2+9Na644opajwL0AOtRusKalLKsR+kM69HO8crAXujtt9+OdevWxbnnntu2ra6uLs4999xYs2ZNDScDDlTbt2+PiN/8ZQr70tLSEt/61reiubk5Tj/99FqPA/QA61Fgf7MepQzr0a7pX+sBeK+tW7dGS0tLNDY2ttve2NgYzz77bI2mAg5Ura2tce2118aZZ54ZJ5xwQq3HoRd75pln4vTTT4+33norDjnkkFi2bFkcd9xxtR4L6AHWo8D+ZD1KR1mPdg8xECC5WbNmxU9+8pN44oknaj0KvdyHP/zh2LBhQ2zfvj3+8R//MaZPnx6PPfaYBRgA0CXWo3SU9Wj3EAN7oWHDhkW/fv1i8+bN7bZv3rw5RowYUaOpgAPRNddcE//yL/8Sjz/+eBx++OG1HodebuDAgXHMMcdERERTU1M8+eST8bd/+7exePHiGk8GdDfrUWB/sR6lDOvR7uGagb3QwIEDo6mpKVauXNm2rbW1NVauXOm98EC3KIoirrnmmli2bFk8+uijMXr06FqPRB/U2toa1Wq11mMAPcB6FOhp1qN0B+vRzvHKwF5qzpw5MX369PjoRz8ap556anzta1+L5ubmuOyyy2o9Gr3QG2+8Ef/5n//Z9vULL7wQGzZsiPe///1x5JFH1nAyeqtZs2bFN7/5zfinf/qnGDJkSGzatCkiIhoaGmLQoEE1no7eaO7cuXHBBRfEkUceGf/zP/8T3/zmN2P16tWxYsWKWo8G9BDrUcqyJqUM61HKsh7tPpWiKIpaD8HuLVq0KG699dbYtGlTnHzyybFw4cKYMGFCrceiF1q9enX8wR/8wXu2T58+PZYuXbr/B6LXq1Qqu91+3333xYwZM/bvMPQJV1xxRaxcuTJeffXVaGhoiJNOOik+//nPx3nnnVfr0YAeZD1KGdaklGE9SlnWo91HDAQAAACAJFwzEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIAkxEAAAAACSEAMBAAAAIIn/C+frtYfPwrLqAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#with sparse matrices\n", + "fig, axes = plt.subplots(2, 2, figsize=(16, 8), sharex=True)\n", + "\n", + "sns.heatmap(beliefs_single[0].mT, ax=axes[0, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(beliefs_single[1].mT, ax=axes[1, 0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "\n", + "sns.heatmap(smoothed_beliefs_sparse[0][0].mT, ax=axes[0, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(smoothed_beliefs_sparse[1][0].mT, ax=axes[1, 1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "\n", + "axes[0, 0].set_title('Filtered beliefs')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compare to marginal message passing" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABQMAAAGHCAYAAAAEKUSHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmlUlEQVR4nO3deZCU9Z348U9zOBzCGG4sVFiRgKJYS4xBZfCAIBoRY8AzCl6bFXDFwujUbhSyuoPRjXKIaDyDYhQ8oqyIEdDoiouAGHUTj/UmCiI6RMRBmef3h8X87HA2DDbj9/Wq6irn2888/emmwafe9fTTuSzLsgAAAAAAvvXqFXsAAAAAAOCbIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgDfKocffngcfvjhO/QxxowZE7lcboc+xs4ol8vFiBEjdvjjPPHEE5HL5eKJJ57Y7Hbr/xxWrFhRa489dOjQ6NixY95aLpeLMWPGbNP+li1bFj/5yU+iZcuWkcvl4rrrrtvuGQEAtocYCABsldtvvz1yuVzkcrl4+umnN7g/y7LYY489IpfLxY9+9KMiTEhteOaZZ2LMmDHxySefFHuUb4VRo0bF7Nmzo7y8PKZOnRpHH310sUcCABLXoNgDAAB1S6NGjWLatGlx2GGH5a0/+eST8d5770VJSUmRJvvKY489VtTHr+ueeeaZGDt2bAwdOjR22223Yo+zU1izZk00aLBth81z586N448/PkaPHl3LUwEAbBtnBgIABTnmmGNi+vTp8eWXX+atT5s2LXr27Bnt2rWrtceqrq6Ozz//vKDf2WWXXWKXXXaptRmgUaNG2xwDly9fLqoCADsVMRAAKMgpp5wSH330UfzhD3+oWVu7dm3MmDEjTj311I3+zjXXXBOHHHJItGzZMho3bhw9e/aMGTNmbLDd+mvS3XXXXbHffvtFSUlJPProoxER8ac//Sn69OkTjRs3jg4dOsQVV1wRt912W+RyuXjrrbdq9vH31wxcf/25e++9N6688sro0KFDNGrUKI466qh4/fXX8x7/qaeeisGDB8eee+4ZJSUlsccee8SoUaNizZo12/Ravfbaa3HiiSdGu3btolGjRtGhQ4c4+eSTo7KycoPnPH369Nh3332jcePG0atXr3jxxRcjIuLGG2+Mzp07R6NGjeLwww/Pe67rTZ8+PXr27BmNGzeOVq1axemnnx5Lly7dYLu5c+dG7969o2nTprHbbrvF8ccfH3/+859r7h8zZkxcfPHFERHRqVOnmo+F//1jPvjgg9G9e/coKSmJ/fbbr+bP6OuWLl0aZ511VrRt27Zmu1tvvXWD7d57770YNGhQNG3aNNq0aROjRo2KqqqqrXp911uxYkUMGTIkmjdvHi1btox/+Zd/2WhEvvPOO2tepxYtWsTJJ58c77777hb3v7FrBm7p+a3/WH2WZXH99dfXvJYREV988UWMHTs29tlnn2jUqFG0bNkyDjvssLy/UwAAO4qPCQMABenYsWP06tUr7r777hgwYEBERMyaNSsqKyvj5JNPjgkTJmzwO+PHj4+BAwfGaaedFmvXro3f/e53MXjw4Jg5c2Yce+yxedvOnTs37r333hgxYkS0atUqOnbsGEuXLo0jjjgicrlclJeXR9OmTePmm28u6CPJ48aNi3r16sXo0aOjsrIyfvWrX8Vpp50W//M//1OzzfTp0+Ozzz6Lf/7nf46WLVvGggULYuLEifHee+/F9OnTC3qd1q5dG/3794+qqqoYOXJktGvXLpYuXRozZ86MTz75JEpLS2u2feqpp+Khhx6K4cOHR0RERUVF/OhHP4qf//znMXny5Dj//PPj448/jl/96ldx1llnxdy5c2t+9/bbb49hw4bFQQcdFBUVFbFs2bIYP358/Pd//3c8//zzNWelPf744zFgwID4h3/4hxgzZkysWbMmJk6cGIceemgsXrw4OnbsGD/+8Y/j1VdfjbvvvjuuvfbaaNWqVUREtG7duubxnn766bj//vvj/PPPj2bNmsWECRPixBNPjHfeeSdatmwZEV99acYPfvCDmtDZunXrmDVrVpx99tmxatWquPDCCyPiq4/fHnXUUfHOO+/EBRdcELvvvntMnTo17/ltjSFDhkTHjh2joqIinn322ZgwYUJ8/PHH8dvf/rZmmyuvvDJ+8YtfxJAhQ+Kcc86JDz/8MCZOnBhlZWV5r9PW2JrnV1ZWFlOnTo2f/vSn0a9fvzjjjDNqfn/MmDFRUVER55xzTnz/+9+PVatWxcKFC2Px4sXRr1+/gp47AEDBMgCArXDbbbdlEZE999xz2aRJk7JmzZpln332WZZlWTZ48ODsiCOOyLIsy/baa6/s2GOPzfvd9dutt3bt2qx79+7ZkUcembceEVm9evWyl19+OW995MiRWS6Xy55//vmatY8++ihr0aJFFhHZm2++WbPep0+frE+fPjU/z5s3L4uIrFu3bllVVVXN+vjx47OIyF588cVNzpllWVZRUZHlcrns7bffrlm7/PLLsy0dRj3//PNZRGTTp0/f7HYRkZWUlOQ9hxtvvDGLiKxdu3bZqlWratbLy8vznu/atWuzNm3aZN27d8/WrFlTs93MmTOziMguu+yymrUDDzwwa9OmTfbRRx/VrL3wwgtZvXr1sjPOOKNm7eqrr97gNf36rLvsskv2+uuv5+0jIrKJEyfWrJ199tlZ+/btsxUrVuT9/sknn5yVlpbWvM7XXXddFhHZvffeW7PN6tWrs86dO2cRkc2bN2+zr936P4eBAwfmrZ9//vlZRGQvvPBClmVZ9tZbb2X169fPrrzyyrztXnzxxaxBgwZ562eeeWa21157bfC8L7/88oKf3/rfHT58eN52PXr02ODvCADAN8XHhAGAgg0ZMiTWrFkTM2fOjL/97W8xc+bMTX5EOCKicePGNf/98ccfR2VlZfTu3TsWL168wbZ9+vSJfffdN2/t0UcfjV69esWBBx5Ys9aiRYs47bTTtnrmYcOG5V1LsHfv3hER8cYbb2x0ztWrV8eKFSvikEMOiSzL4vnnn9/qx4qImjP/Zs+eHZ999tlmtz3qqKOiY8eONT8ffPDBERFx4oknRrNmzTZYXz/zwoULY/ny5XH++edHo0aNarY79thjo2vXrvFf//VfERHx/vvvx5IlS2Lo0KHRokWLmu0OOOCA6NevXzzyyCNb/bz69u0be++9d94+mjdvXjNTlmVx3333xXHHHRdZlsWKFStqbv3794/KysqaP/dHHnkk2rdvHz/5yU9q9tekSZM477zztnqeiKg5o3K9kSNH1uw/IuL++++P6urqGDJkSN487dq1i3322SfmzZu31Y9VyPPblN122y1efvnleO211wp6ngAAtcHHhAGAgrVu3Tr69u0b06ZNi88++yzWrVuXF3T+3syZM+OKK66IJUuW5F0Pbv011L6uU6dOG6y9/fbb0atXrw3WO3fuvNUz77nnnnk/f+c734mIr+Lkeu+8805cdtll8dBDD+WtR0Tedf62RqdOneKiiy6KX//613HXXXdF7969Y+DAgXH66afnfUR4Y7Otv3+PPfbY6Pr62d5+++2IiPjud7+7weN37do1nn766S1u161bt5g9e3asXr06mjZtusXn9fezRnz1Wq6f6cMPP4xPPvkkbrrpprjppps2uo/ly5fXzNW5c+cN3gcbm3Nz9tlnn7yf995776hXr17NtQ5fe+21yLJsg+3Wa9iw4VY/ViHPb1N++ctfxvHHHx9dunSJ7t27x9FHHx0//elP44ADDtjqOQAAtpUYCABsk1NPPTXOPffc+OCDD2LAgAGbvObaU089FQMHDoyysrKYPHlytG/fPho2bBi33XZbTJs2bYPtv352Xm2qX7/+RtezLIuIiHXr1kW/fv1i5cqVcckll0TXrl2jadOmsXTp0hg6dGhUV1cX/Jj/+Z//GUOHDo3f//738dhjj8UFF1xQc127Dh06bHG2Lc1cDFuaaf3rdPrpp8eZZ5650W13dPT6+7hYXV0duVwuZs2atdH5d911163ed208v7Kysvi///u/mvfFzTffHNdee21MmTIlzjnnnK2eBQBgW4iBAMA2OeGEE+Kf/umf4tlnn4177rlnk9vdd9990ahRo5g9e3beF37cdtttW/1Ye+211wbf/BsRG13bVi+++GK8+uqrcccdd+R92cP2fsPr/vvvH/vvv3/827/9WzzzzDNx6KGHxpQpU+KKK67Y3pFjr732ioiIV155JY488si8+1555ZWa+7++3d/7y1/+Eq1atao5K3BjZ2sWonXr1tGsWbNYt25d9O3bd4vzv/TSS5FlWd7jbmzOzXnttdfyzih9/fXXo7q6uuaj13vvvXdkWRadOnWKLl26FLTvv1fI89ucFi1axLBhw2LYsGHx6aefRllZWYwZM0YMBAB2ONcMBAC2ya677ho33HBDjBkzJo477rhNble/fv3I5XKxbt26mrW33norHnzwwa1+rP79+8f8+fNjyZIlNWsrV66Mu+66a1tG3+ScEfln3WVZFuPHj9+m/a1atSq+/PLLvLX9998/6tWrl/dR6e3xve99L9q0aRNTpkzJ2+esWbPiz3/+c803Nbdv3z4OPPDAuOOOO+KTTz6p2e6ll16Kxx57LI455piatfVR8OvbFaJ+/fpx4oknxn333RcvvfTSBvd/+OGHNf99zDHHxF//+teYMWNGzdpnn322yY/fbsr111+f9/PEiRMjImq+7frHP/5x1K9fP8aOHbvBWZVZlsVHH3201Y9VyPPblL9/vF133TU6d+5ca+8LAIDNcWYgALDNNvUxya879thj49e//nUcffTRceqpp8by5cvj+uuvj86dO8ef/vSnrXqcn//853HnnXdGv379YuTIkdG0adO4+eabY88994yVK1du99lsEV9dY2/vvfeO0aNHx9KlS6N58+Zx3333bXDtwK01d+7cGDFiRAwePDi6dOkSX375ZUydOrUmJtWGhg0bxlVXXRXDhg2LPn36xCmnnBLLli2L8ePHR8eOHWPUqFE121599dUxYMCA6NWrV5x99tmxZs2amDhxYpSWlsaYMWNqtuvZs2dERPzrv/5rnHzyydGwYcM47rjjtup6guuNGzcu5s2bFwcffHCce+65se+++8bKlStj8eLF8fjjj8fKlSsjIuLcc8+NSZMmxRlnnBGLFi2K9u3bx9SpU6NJkyYFvQ5vvvlmDBw4MI4++uiYP39+3HnnnXHqqadGjx49IuKrMwOvuOKKKC8vj7feeisGDRoUzZo1izfffDMeeOCBOO+882L06NG1/vw2Zd99943DDz88evbsGS1atIiFCxfGjBkzYsSIEQU9bwCAbSEGAgA71JFHHhm33HJLjBs3Li688MLo1KlTXHXVVfHWW29tdQzcY489Yt68eXHBBRfEf/zHf0Tr1q1j+PDh0bRp07jgggvyvkl3WzVs2DAefvjhmuv6NWrUKE444YQYMWJETVQqRI8ePaJ///7x8MMPx9KlS6NJkybRo0ePmDVrVvzgBz/Y7nnXGzp0aDRp0iTGjRsXl1xySTRt2jROOOGEuOqqq/Ku49i3b9949NFH4/LLL4/LLrssGjZsGH369Imrrroq7yO2Bx10UPz7v/97TJkyJR599NGorq6ON998s6AY2LZt21iwYEH88pe/jPvvvz8mT54cLVu2jP322y+uuuqqmu2aNGkSc+bMiZEjR8bEiROjSZMmcdppp8WAAQPi6KOP3urHu+eee+Kyyy6LSy+9NBo0aBAjRoyIq6++Om+bSy+9NLp06RLXXnttjB07NiK+el/98Ic/jIEDB271YxXy/DblggsuiIceeigee+yxqKqqir322iuuuOKKuPjiiwuaAwBgW+SyYl6BGgBgO1x44YVx4403xqeffrrJL7YAAAD+P9cMBADqhDVr1uT9/NFHH8XUqVPjsMMOEwIBAGAr+ZgwAFAn9OrVKw4//PDo1q1bLFu2LG655ZZYtWpV/OIXvyj2aAAAUGeIgQBAnXDMMcfEjBkz4qabbopcLhf/+I//GLfcckuUlZUVezQAAKgzXDMQAAAAABLhmoEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEhEg2IPsF6/eoOLPQLwLTf7ry8UewTqkP679yj2CNQxf6ieXuwR2E6ORymUYwsK5fiCQvl3hkLVa/fqlrf5BuYAAAAAAHYCYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARDQr9hRUrVsStt94a8+fPjw8++CAiItq1axeHHHJIDB06NFq3bl3rQwIAAAAA26+gMwOfe+656NKlS0yYMCFKS0ujrKwsysrKorS0NCZMmBBdu3aNhQsXbnE/VVVVsWrVqrxbdbZum58EAAAAALBlBZ0ZOHLkyBg8eHBMmTIlcrlc3n1ZlsXPfvazGDlyZMyfP3+z+6moqIixY8fmrXWKbrF37FfIOAAAAABAAQo6M/CFF16IUaNGbRACIyJyuVyMGjUqlixZssX9lJeXR2VlZd6tU3QtZBQAAAAAoEAFnRnYrl27WLBgQXTtuvFwt2DBgmjbtu0W91NSUhIlJSV5a/Vy9QsZBQAAAAAoUEExcPTo0XHeeefFokWL4qijjqoJf8uWLYs5c+bEb37zm7jmmmt2yKAAAAAAwPYpKAYOHz48WrVqFddee21Mnjw51q376ks/6tevHz179ozbb789hgwZskMGBQAAAAC2T0ExMCLipJNOipNOOim++OKLWLFiRUREtGrVKho2bFjrwwEAAAAAtafgGLhew4YNo3379rU5CwAAAACwAxX0bcIAAAAAQN0lBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJCIBsUeAAAACrFixYq49dZbY/78+fHBBx9ERES7du3ikEMOiaFDh0br1q2LPCEAwM7LmYEAANQZzz33XHTp0iUmTJgQpaWlUVZWFmVlZVFaWhoTJkyIrl27xsKFC7e4n6qqqli1alXerTpb9w08AwCA4nJmIAAAdcbIkSNj8ODBMWXKlMjlcnn3ZVkWP/vZz2LkyJExf/78ze6noqIixo4dm7fWKbrF3rFfrc8MALAzcWYgAAB1xgsvvBCjRo3aIARGRORyuRg1alQsWbJki/spLy+PysrKvFun6LoDJgYA2Lk4MxAAgDqjXbt2sWDBgujadePhbsGCBdG2bdst7qekpCRKSkry1url6tfKjAAAOzMxEACAOmP06NFx3nnnxaJFi+Koo46qCX/Lli2LOXPmxG9+85u45pprijwlAMDOSwwEAKDOGD58eLRq1SquvfbamDx5cqxb99WXftSvXz969uwZt99+ewwZMqTIUwIA7LzEQAAA6pSTTjopTjrppPjiiy9ixYoVERHRqlWraNiwYZEnAwDY+YmBAADUSQ0bNoz27dsXewwAgDrFtwkDAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiWhQ7AFgW83+6wvFHoE6pv/uPYo9AgAAABSVMwMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBG1HgPffffdOOussza7TVVVVaxatSrvVp2tq+1RAAAAAICvaVDbO1y5cmXccccdceutt25ym4qKihg7dmzeWqfoFnvHfrU9DgAAbJXZf32h2CNQx/TfvUexR6CO8e8MhfLvDIX6Q/WWtyk4Bj700EObvf+NN97Y4j7Ky8vjoosuyls7oXRooaMAAAAAAAUoOAYOGjQocrlcZFm2yW1yudxm91FSUhIlJSV5a/Vy9QsdBQAAAAAoQMHXDGzfvn3cf//9UV1dvdHb4sWLd8ScAAAAAMB2KjgG9uzZMxYtWrTJ+7d01iAAAAAAUBwFf0z44osvjtWrV2/y/s6dO8e8efO2aygAAAAAoPYVHAN79+692fubNm0affr02eaBAAAAAIAdo+CPCQMAAAAAdZMYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARDQo9gDrzf7rC8UegTqm/+49ij0C8C3m/0sAAMC3kTMDAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAwLfKu+++G2edddZmt6mqqopVq1bl3aqqqr+hCQEAikcMBADgW2XlypVxxx13bHabioqKKC0tzbuNm/jxNzQhAEDxNCj2AAAAUIiHHnpos/e/8cYbW9xHeXl5XHTRRXlrDT/+x+2aCwCgLhADAQCoUwYNGhS5XC6yLNvkNrlcbrP7KCkpiZKSkry16s98aAYA+PZzxAMAQJ3Svn37uP/++6O6unqjt8WLFxd7RACAnZYYCABAndKzZ89YtGjRJu/f0lmDAAAp8zFhAADqlIsvvjhWr169yfs7d+4c8+bN+wYnAgCoO8RAAADqlN69e2/2/qZNm0afPn2+oWkAAOoWHxMGAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJCIgmPgmjVr4umnn47//d//3eC+zz//PH77299ucR9VVVWxatWqvFtVVXWhowAAAAAABSgoBr766qvRrVu3KCsri/333z/69OkT77//fs39lZWVMWzYsC3up6KiIkpLS/Nu4yZ+XPj0AAAAAMBWKygGXnLJJdG9e/dYvnx5vPLKK9GsWbM49NBD45133inoQcvLy6OysjLvdunI7xS0DwAAAACgMA0K2fiZZ56Jxx9/PFq1ahWtWrWKhx9+OM4///zo3bt3zJs3L5o2bbpV+ykpKYmSkpK8terPXL4QAAAAAHakggrcmjVrokGD/98Pc7lc3HDDDXHcccdFnz594tVXX631AQEAAACA2lHQmYFdu3aNhQsXRrdu3fLWJ02aFBERAwcOrL3JAAAAAIBaVdCZgSeccELcfffdG71v0qRJccopp0SWZbUyGAAAAABQuwqKgeXl5fHII49s8v7JkydHdXX1dg8FAAAAANQ+39oBAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBE5LIsy4o9BBtXVVUVFRUVUV5eHiUlJcUehzrAe4ZCec9QKO8ZSI+/9xTC+4VCec9QKO+Z7ScG7sRWrVoVpaWlUVlZGc2bNy/2ONQB3jMUynuGQnnPQHr8vacQ3i8UynuGQnnPbD8fEwYAAACARIiBAAAAAJAIMRAAAAAAEiEG7sRKSkri8ssvd0FMtpr3DIXynqFQ3jOQHn/vKYT3C4XynqFQ3jPbzxeIAAAAAEAinBkIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMXAndv3110fHjh2jUaNGcfDBB8eCBQuKPRI7qT/+8Y9x3HHHxe677x65XC4efPDBYo/ETq6ioiIOOuigaNasWbRp0yYGDRoUr7zySrHHYid2ww03xAEHHBDNmzeP5s2bR69evWLWrFnFHgvYwRyPUgjHpBTC8SiFcjxae8TAndQ999wTF110UVx++eWxePHi6NGjR/Tv3z+WL19e7NHYCa1evTp69OgR119/fbFHoY548sknY/jw4fHss8/GH/7wh/jiiy/ihz/8YaxevbrYo7GT6tChQ4wbNy4WLVoUCxcujCOPPDKOP/74ePnll4s9GrCDOB6lUI5JKYTjUQrleLT25LIsy4o9BBs6+OCD46CDDopJkyZFRER1dXXsscceMXLkyLj00kuLPB07s1wuFw888EAMGjSo2KNQh3z44YfRpk2bePLJJ6OsrKzY41BHtGjRIq6++uo4++yziz0KsAM4HmV7OCalUI5H2RaOR7eNMwN3QmvXro1FixZF3759a9bq1asXffv2jfnz5xdxMuDbqrKyMiK++p8pbMm6devid7/7XaxevTp69epV7HGAHcDxKPBNczxKIRyPbp8GxR6ADa1YsSLWrVsXbdu2zVtv27Zt/OUvfynSVMC3VXV1dVx44YVx6KGHRvfu3Ys9DjuxF198MXr16hWff/557LrrrvHAAw/EvvvuW+yxgB3A8SjwTXI8ytZyPFo7xECAxA0fPjxeeumlePrpp4s9Cju57373u7FkyZKorKyMGTNmxJlnnhlPPvmkAzAAYLs4HmVrOR6tHWLgTqhVq1ZRv379WLZsWd76smXLol27dkWaCvg2GjFiRMycOTP++Mc/RocOHYo9Dju5XXbZJTp37hwRET179oznnnsuxo8fHzfeeGORJwNqm+NR4JvieJRCOB6tHa4ZuBPaZZddomfPnjFnzpyaterq6pgzZ47PwgO1IsuyGDFiRDzwwAMxd+7c6NSpU7FHog6qrq6OqqqqYo8B7ACOR4EdzfEotcHx6LZxZuBO6qKLLoozzzwzvve978X3v//9uO6662L16tUxbNiwYo/GTujTTz+N119/vebnN998M5YsWRItWrSIPffcs4iTsbMaPnx4TJs2LX7/+99Hs2bN4oMPPoiIiNLS0mjcuHGRp2NnVF5eHgMGDIg999wz/va3v8W0adPiiSeeiNmzZxd7NGAHcTxKoRyTUgjHoxTK8WjtyWVZlhV7CDZu0qRJcfXVV8cHH3wQBx54YEyYMCEOPvjgYo/FTuiJJ56II444YoP1M888M26//fZvfiB2erlcbqPrt912WwwdOvSbHYY64eyzz445c+bE+++/H6WlpXHAAQfEJZdcEv369Sv2aMAO5HiUQjgmpRCORymU49HaIwYCAAAAQCJcMxAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEjE/wNyswhV/yxGqAAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mmp_agents = agents = Agent(\n", + " A=A,\n", + " B=B,\n", + " C=C,\n", + " D=D,\n", + " E=E,\n", + " pA=pA,\n", + " pB=pB,\n", + " policy_len=3,\n", + " control_fac_idx=None,\n", + " onehot_obs=True,\n", + " action_selection=\"deterministic\",\n", + " sampling_mode=\"full\",\n", + " inference_algo=\"mmp\",\n", + " num_iter=16\n", + ")\n", + "\n", + "mmp_obs = [jnp.moveaxis(obs[0], 0, 1)]\n", + "post_marg_beliefs = mmp_agents.infer_states(mmp_obs, mmp_agents.D, past_actions=jnp.stack(action_hist, 1))\n", + "\n", + "#with sparse matrices\n", + "fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True)\n", + "\n", + "sns.heatmap(post_marg_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(post_marg_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "\n", + "fig.suptitle('Marginal smoothed beliefs');" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compare to variational message passing" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABQMAAAGHCAYAAAAEKUSHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkh0lEQVR4nO3deZCU9Z348U9zDQg4iJzRCBouDYjJKIRbxYhEw2IqHhFLxphVN4AlFkTRUsA1kl11RUTF6CrGCBo1XqsGEPGIskFw0UjWK8KqiVweoKOMyDy/P1LMLx3OloGe8ft6VXVZ832efvrTYwNPvevp7lyWZVkAAAAAAF959Yo9AAAAAACwZ4iBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAX3Hl5eXRrFmzPfJYHTt2jPLy8u3us2LFisjlcnH11VfX2OM+9dRTkcvl4qmnnqpeKy8vj44dO37pY1511VVx0EEHRf369eOwww7b5RkBAGoDMRAA2GOGDRsWe+21V3z88cfb3GfEiBHRqFGjeP/99yMiIpfLRS6Xi5/85Cdb3f+SSy6p3mft2rXV6+Xl5dXruVwu9t577+jZs2dcc801UVlZWbNPrBb49NNPY9KkSXkxjC9v7ty58bOf/Sz69esXt99+e1x55ZXFHgkAoEY0KPYAAEA6RowYEY888kg88MADccYZZ2yx/dNPP42HHnoojjvuuNh3332r1xs3bhz3339/3HjjjdGoUaO8+8yePTsaN24cGzZs2OJ4JSUlceutt0ZExEcffRT3339/jBs3Ll544YW4++67a/jZFdenn34akydPjoiII488srjD1BK33HJLVFVVfan7Pvnkk1GvXr34z//8zy1ecwAAdZkrAwGAPWbYsGHRvHnzmDVr1la3P/TQQ1FRUREjRozIWz/uuONi/fr18fjjj+etP//887F8+fI4/vjjt3q8Bg0axOmnnx6nn356jB49OubPnx+HH3543HPPPfHXv/61Zp4UtVbDhg2jpKTkS9139erV0aRJEyEQAPjKEQMBgD2mSZMm8YMf/CDmz58fq1ev3mL7rFmzonnz5jFs2LC89f322y8GDhy4RUS86667okePHtG9e/edevx69epVXzW3YsWKbe63cePGmDx5cnTu3DkaN24c++67b/Tv3z/mzZtXvc/mz+F7++2344QTTohmzZrFfvvtFzfccENERPzxj3+Mo48+Opo2bRodOnTYagB966234qSTToqWLVvGXnvtFd/5znfi0Ucf3WK/1atXx1lnnRVt27aNxo0bR8+ePeOOO+6o3r5ixYpo3bp1RERMnjy5+q3RkyZNyjvOX/7ylxg+fHg0a9YsWrduHePGjYtNmzbl7VNVVRVTp06Nb37zm9G4ceNo27ZtnHPOOfHhhx/m7ZdlWVxxxRWx//77x1577RVHHXVULFu2bJu/02259tpro0OHDtGkSZMYNGhQvPLKK1vs8+qrr8YPf/jDaNmyZTRu3DgOP/zwePjhh3d47K19ZuDOPL9cLhe33357VFRUVP8uZ86cGRER8+bNi/79+0eLFi2iWbNm0bVr17j44osLft4AAMUiBgIAe9SIESPiiy++iN/85jd56x988EHMmTMnTjzxxGjSpMkW9zvttNPikUceiU8++SQiIr744ou4995747TTTivo8f/85z9HROS9DfkfTZo0KSZPnhxHHXVUTJ8+PS655JI44IAD4sUXX8zbb9OmTTF06ND4+te/Hv/+7/8eHTt2jNGjR8fMmTPjuOOOi8MPPzz+7d/+LZo3bx5nnHFGLF++vPq+q1atir59+8acOXPipz/9afz85z+PDRs2xLBhw+KBBx6o3u+zzz6LI488Mu68884YMWJEXHXVVVFaWhrl5eVx3XXXRURE69at46abboqIiBNPPDHuvPPOuPPOO+MHP/hB3qxDhgyJfffdN66++uoYNGhQXHPNNfHLX/4y7zmdc845MX78+OjXr19cd911ceaZZ8Zdd90VQ4YMiY0bN1bvd9lll8Wll14aPXv2rP6ijWOPPTYqKip2+v/Fr371q5g2bVqMGjUqJkyYEK+88kocffTRsWrVqup9li1bFt/5znfif//3f+Oiiy6Ka665Jpo2bRrDhw/P+z3trJ15fnfeeWcMGDAgSkpKqn+XAwcOjGXLlsUJJ5wQlZWVcfnll8c111wTw4YNi+eee67gOQAAiiYDANiDvvjii6x9+/ZZnz598tZnzJiRRUQ2Z86cvPWIyEaNGpV98MEHWaNGjbI777wzy7Ise/TRR7NcLpetWLEimzhxYhYR2Zo1a6rvN3LkyKxp06bZmjVrsjVr1mRvvvlmduWVV2a5XC479NBDtztjz549s+OPP367+4wcOTKLiOzKK6+sXvvwww+zJk2aZLlcLrv77rur11999dUsIrKJEydWr51//vlZRGTPPvts9drHH3+cHXjggVnHjh2zTZs2ZVmWZVOnTs0iIvv1r39dvd/nn3+e9enTJ2vWrFm2fv36LMuybM2aNVs8xj/Oevnll+etf+tb38rKysqqf3722WeziMjuuuuuvP1+97vf5a2vXr06a9SoUXb88cdnVVVV1ftdfPHFWURkI0eO3O7vbvny5VlEZE2aNMnefffd6vU//OEPWURkY8eOrV4bPHhw1qNHj2zDhg3Va1VVVVnfvn2zzp07V68tWLAgi4hswYIFec+7Q4cOBT+/zfdt2rRp3n7XXnvtFq8zAIC6xpWBAMAeVb9+/Tj11FNj4cKFeW/VnTVrVrRt2zYGDx681fvts88+cdxxx8Xs2bOr9+/bt2906NBhm49VUVERrVu3jtatW0enTp3i4osvjj59+uzwirIWLVrEsmXL4o033tjh8/n7bzlu0aJFdO3aNZo2bRonn3xy9XrXrl2jRYsW8dZbb1WvPfbYY9GrV6/o379/9VqzZs3i7LPPjhUrVsSf/vSn6v3atWsXP/rRj6r3a9iwYZx33nnxySefxNNPP73DGTc799xz834eMGBA3kz33ntvlJaWxne/+91Yu3Zt9a2srCyaNWsWCxYsiIiIJ554Ij7//PMYM2ZM5HK56vuff/75Oz1LRMTw4cNjv/32q/65V69e0bt373jsscci4m9Xiz755JNx8sknx8cff1w9z/vvvx9DhgyJN954I/7yl7/s9OPt7PPblhYtWkTE3z7b8st+MQkAQLGJgQDAHrf5C0I2f47eu+++G88++2yceuqpUb9+/W3e77TTTot58+bF22+/HQ8++OAO3yLcuHHjmDdvXsybNy+eeeaZeOedd+K5556Lgw46aLv3u/zyy+Ojjz6KLl26RI8ePWL8+PHx8ssvb/X4mz+rb7PS0tLYf//98yLZ5vW//1y6//u//4uuXbtuccyDDz64evvm/3bu3Dnq1au33f12ZGuz7rPPPnkzvfHGG7Fu3bpo06ZNdUTdfPvkk0+qP+dx82N27tw573itW7eOffbZZ6fm2dr9IyK6dOlSHYnffPPNyLIsLr300i3mmThxYkTEVj97clt29vltyymnnBL9+vWLn/zkJ9G2bds49dRT4ze/+Y0wCADUKQ2KPQAAkJ6ysrLo1q1bzJ49Oy6++OKYPXt2ZFm2xbcI/6Nhw4ZFSUlJjBw5MiorK/Ouvtua+vXrxzHHHFPwfAMHDow///nP8dBDD8XcuXPj1ltvjWuvvTZmzJiRdyXgtsLlttazLCt4lpqyvci6WVVVVbRp0ybuuuuurW7/x5i4u22ObOPGjYshQ4ZsdZ9OnToVdLxdeX5NmjSJZ555JhYsWBCPPvpo/O53v4t77rknjj766Jg7d+5O/Y4BAIpNDAQAimLEiBFx6aWXxssvvxyzZs2Kzp07xxFHHLHd+zRp0iSGDx8ev/71r2Po0KHRqlWr3TZfy5Yt48wzz4wzzzwzPvnkkxg4cGBMmjQpLwbuig4dOsRrr722xfqrr75avX3zf19++eWoqqrKuzrwH/f7xysRv4xvfOMb8cQTT0S/fv22+iUufz97xN+utPv7qyzXrFmzxbcOb8/W3ob9+uuvV38D8OZjN2zY8EtF3X+0s89ve+rVqxeDBw+OwYMHx3/8x3/ElVdeGZdcckksWLCgRmYEANjdvE0YACiKzVcBXnbZZbF06dIdXhW42bhx42LixIlx6aWX7rbZ3n///byfmzVrFp06dYrKysoae4zvfe97sWjRoli4cGH1WkVFRfzyl7+Mjh07xiGHHFK938qVK+Oee+6p3u+LL76I66+/Ppo1axaDBg2KiIi99torIiI++uijLz3TySefHJs2bYp//dd/3WLbF198UX3sY445Jho2bBjXX3993tWOU6dOLejxHnzwwbzP/Fu0aFH84Q9/iKFDh0ZERJs2beLII4+Mm2++Od57770t7r9mzZqCHm9nn9+2fPDBB1usHXbYYRERNfraAADYnVwZCAAUxYEHHhh9+/aNhx56KCJip2Ngz549o2fPnrtztDjkkEPiyCOPjLKysmjZsmUsXrw47rvvvhg9enSNPcZFF10Us2fPjqFDh8Z5550XLVu2jDvuuCOWL18e999/f/VVgGeffXbcfPPNUV5eHkuWLImOHTvGfffdF88991xMnTo1mjdvHhF/u2rykEMOiXvuuSe6dOkSLVu2jO7du0f37t13eqZBgwbFOeecE1OmTImlS5fGscceGw0bNow33ngj7r333rjuuuvihz/8YbRu3TrGjRsXU6ZMiRNOOCG+973vxf/8z//E448/XtDVmp06dYr+/fvHv/zLv0RlZWVMnTo19t133/jZz35Wvc8NN9wQ/fv3jx49esQ///M/x0EHHRSrVq2KhQsXxrvvvhsvvfRSjT+/bbn88svjmWeeieOPPz46dOgQq1evjhtvvDH233//vC+CAQCozcRAAKBoRowYEc8//3z06tWroM9+293OO++8ePjhh2Pu3LlRWVkZHTp0iCuuuCLGjx9fY4/Rtm3beP755+PCCy+M66+/PjZs2BCHHnpoPPLII3H88cdX79ekSZN46qmn4qKLLoo77rgj1q9fH127do3bb789ysvL84556623xpgxY2Ls2LHx+eefx8SJEwuKgRERM2bMiLKysrj55pvj4osvjgYNGkTHjh3j9NNPj379+lXvd8UVV0Tjxo1jxowZsWDBgujdu3fMnTs3b/YdOeOMM6JevXoxderUWL16dfTq1SumT58e7du3r97nkEMOicWLF8fkyZNj5syZ8f7770ebNm3iW9/6Vlx22WUFPbdCnt/WDBs2LFasWBG33XZbrF27Nlq1ahWDBg2KyZMnR2lpacGzAAAUQy4r5idZAwAAAAB7jM8MBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJKJBsQfY7Lv1Tir2CMBX3Jy/vlTsEahDhnytZ7FHoI6ZV3VvsUdgFzkfpVDOLSiU8wsK5e8ZClWv3es73mcPzAEAAAAA1AJiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBENCr3D2rVr47bbbouFCxfGypUrIyKiXbt20bdv3ygvL4/WrVvX+JAAAAAAwK4r6MrAF154Ibp06RLTpk2L0tLSGDhwYAwcODBKS0tj2rRp0a1bt1i8ePEOj1NZWRnr16/Pu1Vlm770kwAAAAAAdqygKwPHjBkTJ510UsyYMSNyuVzetizL4txzz40xY8bEwoULt3ucKVOmxOTJk/PWDoyD4xvxzULGAQAAAAAKUNCVgS+99FKMHTt2ixAYEZHL5WLs2LGxdOnSHR5nwoQJsW7durzbgdGtkFEAAAAAgAIVdGVgu3btYtGiRdGt29bD3aJFi6Jt27Y7PE5JSUmUlJTkrdXL1S9kFAAAAACgQAXFwHHjxsXZZ58dS5YsicGDB1eHv1WrVsX8+fPjlltuiauvvnq3DAoAAAAA7JqCYuCoUaOiVatWce2118aNN94Ymzb97Us/6tevH2VlZTFz5sw4+eSTd8ugAAAAAMCuKSgGRkSccsopccopp8TGjRtj7dq1ERHRqlWraNiwYY0PBwAAAADUnIJj4GYNGzaM9u3b1+QsAAAAAMBuVNC3CQMAAAAAdZcYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCIaFHsAAAAoxNq1a+O2226LhQsXxsqVKyMiol27dtG3b98oLy+P1q1bF3lCAIDay5WBAADUGS+88EJ06dIlpk2bFqWlpTFw4MAYOHBglJaWxrRp06Jbt26xePHiHR6nsrIy1q9fn3eryjbtgWcAAFBcrgwEAKDOGDNmTJx00kkxY8aMyOVyeduyLItzzz03xowZEwsXLtzucaZMmRKTJ0/OWzswDo5vxDdrfGYAgNrElYEAANQZL730UowdO3aLEBgRkcvlYuzYsbF06dIdHmfChAmxbt26vNuB0W03TAwAULu4MhAAgDqjXbt2sWjRoujWbevhbtGiRdG2bdsdHqekpCRKSkry1url6tfIjAAAtZkYCABAnTFu3Lg4++yzY8mSJTF48ODq8Ldq1aqYP39+3HLLLXH11VcXeUoAgNpLDAQAoM4YNWpUtGrVKq699tq48cYbY9Omv33pR/369aOsrCxmzpwZJ598cpGnBACovcRAAADqlFNOOSVOOeWU2LhxY6xduzYiIlq1ahUNGzYs8mQAALWfGAgAQJ3UsGHDaN++fbHHAACoU3ybMAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJCIBsUeAGBPGfK1nsUeAQAAAIrKlYEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkIgaj4HvvPNO/PjHP97uPpWVlbF+/fq8W1W2qaZHAQAAAAD+Ti7LsqwmD/jSSy/Ft7/97di0adtxb9KkSTF58uS8tQPj4PhG7ps1OQoAwB4zr+reYo/ALqpa2aXYI1DHDPlaz2KPQB0z568vFXsE6hh/z1ConTknbVDoQR9++OHtbn/rrbd2eIwJEybEBRdckLd2Yml5oaMAAAAAAAUoOAYOHz48crlcbO+Cwlwut91jlJSURElJSd5avVz9QkcBAAAAAApQ8GcGtm/fPn77299GVVXVVm8vvvji7pgTAAAAANhFBcfAsrKyWLJkyTa37+iqQQAAAACgOAp+m/D48eOjoqJim9s7deoUCxYs2KWhAAAAAICaV3AMHDBgwHa3N23aNAYNGvSlBwIAAAAAdo+C3yYMAAAAANRNYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABLRoNgDbDbnry8VewTgK27I13oWewTqEP8uAQAAX0WuDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAABfKe+88078+Mc/3u4+lZWVsX79+rxbZWXVHpoQAKB4xEAAAL5SPvjgg7jjjju2u8+UKVOitLQ07/aL6z/cQxMCABRPg2IPAAAAhXj44Ye3u/2tt97a4TEmTJgQF1xwQd5aww+/vUtzAQDUBWIgAAB1yvDhwyOXy0WWZdvcJ5fLbfcYJSUlUVJSkrdW9ak3zQAAX33OeAAAqFPat28fv/3tb6OqqmqrtxdffLHYIwIA1FpiIAAAdUpZWVksWbJkm9t3dNUgAEDKvE0YAIA6Zfz48VFRUbHN7Z06dYoFCxbswYkAAOoOMRAAgDplwIAB293etGnTGDRo0B6aBgCgbvE2YQAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiSg4Bn722Wfx+9//Pv70pz9tsW3Dhg3xq1/9aofHqKysjPXr1+fdKiurCh0FAAAAAChAQTHw9ddfj4MPPjgGDhwYPXr0iEGDBsV7771XvX3dunVx5pln7vA4U6ZMidLS0rzbL67/sPDpAQAAAICdVlAMvPDCC6N79+6xevXqeO2116J58+bRr1+/ePvttwt60AkTJsS6devybheN2aegYwAAAAAAhWlQyM7PP/98PPHEE9GqVato1apVPPLII/HTn/40BgwYEAsWLIimTZvu1HFKSkqipKQkb63qUx9fCAAAAAC7U0EF7rPPPosGDf5/P8zlcnHTTTfF97///Rg0aFC8/vrrNT4gAAAAAFAzCroysFu3brF48eI4+OCD89anT58eERHDhg2ruckAAAAAgBpV0JWBJ554YsyePXur26ZPnx4/+tGPIsuyGhkMAAAAAKhZBcXACRMmxGOPPbbN7TfeeGNUVVXt8lAAAAAAQM3zrR0AAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEiEGAgAAAAAiRADAQAAACARYiAAAAAAJEIMBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEAMBAAAAIBFiIAAAAAAkQgwEAAAAgESIgQAAAACQCDEQAAAAABIhBgIAAABAIsRAAAAAAEhELsuyrNhDsHWVlZUxZcqUmDBhQpSUlBR7HOoArxkK5TVDobxmID3+3FMIrxcK5TVDobxmdp0YWIutX78+SktLY926dbH33nsXexzqAK8ZCuU1Q6G8ZiA9/txTCK8XCuU1Q6G8ZnadtwkDAAAAQCLEQAAAAABIhBgIAAAAAIkQA2uxkpKSmDhxog/EZKd5zVAorxkK5TUD6fHnnkJ4vVAorxkK5TWz63yBCAAAAAAkwpWBAAAAAJAIMRAAAAAAEiEGAgAAAEAixEAAAAAASIQYCAAAAACJEANrsRtuuCE6duwYjRs3jt69e8eiRYuKPRK11DPPPBPf//7342tf+1rkcrl48MEHiz0StdyUKVPiiCOOiObNm0ebNm1i+PDh8dprrxV7LGqxm266KQ499NDYe++9Y++9944+ffrE448/XuyxgN3M+SiFcE5KIZyPUijnozVHDKyl7rnnnrjgggti4sSJ8eKLL0bPnj1jyJAhsXr16mKPRi1UUVERPXv2jBtuuKHYo1BHPP300zFq1Kj47//+75g3b15s3Lgxjj322KioqCj2aNRS+++/f/ziF7+IJUuWxOLFi+Poo4+Of/qnf4ply5YVezRgN3E+SqGck1II56MUyvlozcllWZYVewi21Lt37zjiiCNi+vTpERFRVVUVX//612PMmDFx0UUXFXk6arNcLhcPPPBADB8+vNijUIesWbMm2rRpE08//XQMHDiw2ONQR7Rs2TKuuuqqOOuss4o9CrAbOB9lVzgnpVDOR/kynI9+Oa4MrIU+//zzWLJkSRxzzDHVa/Xq1YtjjjkmFi5cWMTJgK+qdevWRcTf/jGFHdm0aVPcfffdUVFREX369Cn2OMBu4HwU2NOcj1II56O7pkGxB2BLa9eujU2bNkXbtm3z1tu2bRuvvvpqkaYCvqqqqqri/PPPj379+kX37t2LPQ612B//+Mfo06dPbNiwIZo1axYPPPBAHHLIIcUeC9gNnI8Ce5LzUXaW89GaIQYCJG7UqFHxyiuvxO9///tij0It17Vr11i6dGmsW7cu7rvvvhg5cmQ8/fTTTsAAgF3ifJSd5Xy0ZoiBtVCrVq2ifv36sWrVqrz1VatWRbt27Yo0FfBVNHr06Piv//qveOaZZ2L//fcv9jjUco0aNYpOnTpFRERZWVm88MILcd1118XNN99c5MmAmuZ8FNhTnI9SCOejNcNnBtZCjRo1irKyspg/f371WlVVVcyfP9974YEakWVZjB49Oh544IF48skn48ADDyz2SNRBVVVVUVlZWewxgN3A+SiwuzkfpSY4H/1yXBlYS11wwQUxcuTIOPzww6NXr14xderUqKioiDPPPLPYo1ELffLJJ/Hmm29W/7x8+fJYunRptGzZMg444IAiTkZtNWrUqJg1a1Y89NBD0bx581i5cmVERJSWlkaTJk2KPB210YQJE2Lo0KFxwAEHxMcffxyzZs2Kp556KubMmVPs0YDdxPkohXJOSiGcj1Io56M1J5dlWVbsIdi66dOnx1VXXRUrV66Mww47LKZNmxa9e/cu9ljUQk899VQcddRRW6yPHDkyZs6cuecHotbL5XJbXb/99tujvLx8zw5DnXDWWWfF/Pnz47333ovS0tI49NBD48ILL4zvfve7xR4N2I2cj1II56QUwvkohXI+WnPEQAAAAABIhM8MBAAAAIBEiIEAAAAAkAgxEAAAAAASIQYCAAAAQCLEQAAAAABIhBgIAAAAAIkQAwEAAAAgEWIgAAAAACRCDAQAAACARIiBAAAAAJAIMRAAAAAAEvH/AIQkWF9ePsqIAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "vmp_agents = agents = Agent(\n", + " A=A,\n", + " B=B,\n", + " C=C,\n", + " D=D,\n", + " E=E,\n", + " pA=pA,\n", + " pB=pB,\n", + " policy_len=3,\n", + " control_fac_idx=None,\n", + " onehot_obs=True,\n", + " action_selection=\"deterministic\",\n", + " sampling_mode=\"full\",\n", + " inference_algo=\"vmp\",\n", + " num_iter=16\n", + ")\n", + "\n", + "vmp_obs = [jnp.moveaxis(obs[0], 0, 1)]\n", + "post_vmp_beliefs = vmp_agents.infer_states(vmp_obs, vmp_agents.D, past_actions=jnp.stack(action_hist, 1))\n", + "\n", + "#with sparse matrices\n", + "fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharex=True)\n", + "\n", + "sns.heatmap(post_vmp_beliefs[0][0].mT, ax=axes[0], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "sns.heatmap(post_vmp_beliefs[1][0].mT, ax=axes[1], cbar=False, vmax=1., vmin=0., cmap='viridis')\n", + "\n", + "fig.suptitle('VMP smoothed beliefs');" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymdp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index 776a65dd..e3d22e8a 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -298,8 +298,55 @@ def learning(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB return agent + + @vmap + def infer_parameters(self, beliefs_A, outcomes, actions, beliefs_B=None, lr_pA=1., lr_pB=1., **kwargs): + agent = self + beliefs_B = beliefs_A if beliefs_B is None else beliefs_B + if self.inference_algo == 'ovf': + smoothed_marginals_and_joints = inference.smoothing_ovf(beliefs_A, self.B, actions) + marginal_beliefs = smoothed_marginals_and_joints[0] + joint_beliefs = smoothed_marginals_and_joints[1] + else: + marginal_beliefs = beliefs_A + if self.learn_B: + nf = len(beliefs_B) + joint_fn = lambda f: [beliefs_B[f][1:]] + [beliefs_B[f_idx][:-1] for f_idx in self.B_dependencies[f]] + joint_beliefs = jtu.tree_map(joint_fn, list(range(nf))) + + if self.learn_A: + qA, E_qA = learning.update_obs_likelihood_dirichlet( + self.pA, + outcomes, + marginal_beliefs, + A_dependencies=self.A_dependencies, + num_obs=self.num_obs, + onehot_obs=self.onehot_obs, + lr=lr_pA, + ) + + agent = tree_at(lambda x: (x.A, x.pA), agent, (E_qA, qA)) + + if self.learn_B: + assert beliefs_B[0].shape[0] == actions.shape[0] + 1 + qB, E_qB = learning.update_state_transition_dirichlet( + self.pB, + joint_beliefs, + actions, + num_controls=self.num_controls, + lr=lr_pB + ) + + # if you have updated your beliefs about transitions, you need to re-compute the I matrix used for inductive inferenece + if self.use_inductive and self.H is not None: + I_updated = control.generate_I_matrix(self.H, E_qB, self.inductive_threshold, self.inductive_depth) + else: + I_updated = self.I + + agent = tree_at(lambda x: (x.B, x.pB, x.I), agent, (E_qB, qB, I_updated)) + @vmap - def infer_states(self, observations, past_actions, empirical_prior, qs_hist, mask=None): + def infer_states(self, observations, empirical_prior, *, past_actions=None, qs_hist=None, mask=None): """ Update approximate posterior over hidden states by solving variational inference problem, given an observation. diff --git a/pymdp/jax/inference.py b/pymdp/jax/inference.py index 790ae354..dc2b254c 100644 --- a/pymdp/jax/inference.py +++ b/pymdp/jax/inference.py @@ -4,7 +4,9 @@ import jax.numpy as jnp from .algos import run_factorized_fpi, run_mmp, run_vmp -from jax import tree_util as jtu +from jax import tree_util as jtu, lax +from jax.experimental.sparse._base import JAXSparse +from jaxtyping import Array def update_posterior_states( A, @@ -55,4 +57,90 @@ def update_posterior_states( qs_hist = qs return qs_hist - + +def joint_dist_factor_dense(b: Array, filtered_qs: list[Array], actions: Array): + qs_last = filtered_qs[-1] + qs_filter = filtered_qs[:-1] + + # conditional dist - timestep x s_{t+1} | s_{t} + time_b = jnp.moveaxis(b[..., actions], -1, 0) + # time_b = b[...,actions].transpose([b.ndim-1] + list(range(b.ndim-1))) + + # joint dist - timestep x s_{t+1} x s_{t} + qs_joint = time_b * jnp.expand_dims(qs_filter, -1) + + # cond dist - timestep x s_{t} | s_{t+1} + qs_backward_cond = jnp.moveaxis( + qs_joint / qs_joint.sum(-2, keepdims=True), -2, -1 + ) + # tranpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2] + # qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(tranpose_idx) + + def step_fn(qs_smooth_past, backward_b): + qs_joint = backward_b * qs_smooth_past + qs_smooth = qs_joint.sum(-1) + + return qs_smooth, (qs_smooth, qs_joint) + + # seq_qs will contain a sequence of smoothed marginals and joints + _, seq_qs = lax.scan( + step_fn, + qs_last, + qs_backward_cond, + reverse=True, + unroll=2 + ) + + # we add the last filtered belief to smoothed beliefs + qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0) + return qs_smooth_all, seq_qs[1] + +def joint_dist_factor_sparse(b: JAXSparse, filtered_qs: list[Array], actions: Array): + qs_last = filtered_qs[-1] + qs_filter = filtered_qs[:-1] + + # conditional dist - timestep x s_{t+1} | s_{t} + time_b = b[...,actions].transpose([b.ndim-1] + list(range(b.ndim-1))) + + # joint dist - timestep x s_{t+1} x s_{t} + qs_joint = time_b * jnp.expand_dims(qs_filter, -1) + + # cond dist - timestep x s_{t} | s_{t+1} + tranpose_idx = list(range(len(qs_joint.shape[:-2]))) + [qs_joint.ndim-1, qs_joint.ndim-2] + qs_backward_cond = (qs_joint / qs_joint.sum(-2, keepdims=True).todense()).transpose(tranpose_idx) + + def step_fn(qs_smooth_past, t): + qs_joint = qs_backward_cond[t] * qs_smooth_past + qs_smooth = qs_joint.sum(-1) + + return qs_smooth.todense(), (qs_smooth.todense(), qs_joint) + + # seq_qs will contain a sequence of smoothed marginals and joints + _, seq_qs = lax.scan( + step_fn, + qs_last, + jnp.arange(qs_backward_cond.shape[0]), + reverse=True, + unroll=2 + ) + + # we add the last filtered belief to smoothed beliefs + + qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0) + return qs_smooth_all, seq_qs[1] + + +def smoothing_ovf(filtered_post, B, past_actions): + assert len(filtered_post) == len(B) + nf = len(B) # number of factors + + joint = lambda b, qs, f: joint_dist_factor_sparse(b, qs, past_actions[..., f]) if isinstance(b, JAXSparse) else joint_dist_factor_dense(b, qs, past_actions[..., f]) + + marginals_and_joints = [] + for b, qs, f in zip(B, filtered_post, list(range(nf))): + marginals_and_joints.append( joint(b, qs, f) ) + + return marginals_and_joints + + + \ No newline at end of file diff --git a/pymdp/jax/learning.py b/pymdp/jax/learning.py index c075aab6..6c681751 100644 --- a/pymdp/jax/learning.py +++ b/pymdp/jax/learning.py @@ -2,11 +2,9 @@ # -*- coding: utf-8 -*- # pylint: disable=no-member -import numpy as np -from .maths import multidimensional_outer +from .maths import multidimensional_outer, dirichlet_expected_value from jax.tree_util import tree_map -from jax import vmap -import jax.numpy as jnp +from jax import vmap, nn def update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=1.0): """ JAX version of ``pymdp.learning.update_obs_likelihood_dirichlet_m`` """ @@ -26,17 +24,27 @@ def update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=1.0): dfda = vmap(multidimensional_outer)([obs_m] + relevant_factors).sum(axis=0) - return pA_m + lr * dfda + new_pA_m = pA_m + lr * dfda + + return new_pA_m, dirichlet_expected_value(new_pA_m) -def update_obs_likelihood_dirichlet(pA, obs, qs, A_dependencies, lr=1.0): +def update_obs_likelihood_dirichlet(pA, obs, qs, *, A_dependencies, onehot_obs, num_obs, lr): """ JAX version of ``pymdp.learning.update_obs_likelihood_dirichlet`` """ - update_A_fn = lambda pA_m, obs_m, dependencies_m: update_obs_likelihood_dirichlet_m(pA_m, obs_m, qs, dependencies_m, lr=lr) - qA = tree_map(update_A_fn, pA, obs, A_dependencies) + obs_m = lambda o, dim: nn.one_hot(o, dim) if not onehot_obs else o + update_A_fn = lambda pA_m, o_m, dim, dependencies_m: update_obs_likelihood_dirichlet_m( + pA_m, obs_m(o_m, dim), qs, dependencies_m, lr=lr + ) + result = tree_map(update_A_fn, pA, obs, num_obs, A_dependencies) + qA = [] + E_qA = [] + for r in result: + qA.append(r[0]) + E_qA.append(r[1]) - return qA + return qA, E_qA -def update_state_likelihood_dirichlet_f(pB_f, actions_f, current_qs, qs_seq, dependencies_f, lr=1.0): +def update_state_transition_dirichlet_f(pB_f, actions_f, joint_qs_f, lr=1.0): """ JAX version of ``pymdp.learning.update_state_likelihood_dirichlet_f`` """ # pB_f - parameters of the dirichlet from the prior # pB_f.shape = (num_states[f] x num_states[f] x num_actions[f]) where f is the index of the hidden state factor @@ -50,265 +58,275 @@ def update_state_likelihood_dirichlet_f(pB_f, actions_f, current_qs, qs_seq, dep # \otimes is a multidimensional outer product, not just a outer product of two vectors # \kappa is an optional learning rate - past_qs = tree_map(lambda f_idx: qs_seq[f_idx][:-1], dependencies_f) - dfdb = vmap(multidimensional_outer)([current_qs[1:]] + past_qs + [actions_f]).sum(axis=0) + dfdb = vmap(multidimensional_outer)(joint_qs_f + [actions_f]).sum(axis=0) qB_f = pB_f + lr * dfdb - return qB_f + return qB_f, dirichlet_expected_value(qB_f) -def update_state_likelihood_dirichlet(pB, beliefs, actions_onehot, B_dependencies, lr=1.0): +def update_state_transition_dirichlet(pB, joint_beliefs, actions, *, num_controls, lr): - update_B_f_fn = lambda pB_f, action_f, qs_f, dependencies_f: update_state_likelihood_dirichlet_f(pB_f, action_f, qs_f, beliefs, dependencies_f, lr=lr) - qB = tree_map(update_B_f_fn, pB, actions_onehot, beliefs, B_dependencies) + nf = len(pB) + actions_onehot_fn = lambda f, dim: nn.one_hot(actions[..., f], dim, axis=-1) + update_B_f_fn = lambda pB_f, joint_qs_f, f, na: update_state_transition_dirichlet_f( + pB_f, actions_onehot_fn(f, na), joint_qs_f, lr=lr + ) + result = tree_map( + update_B_f_fn, pB, joint_beliefs, list(range(nf)), num_controls + ) - return qB - + qB = [] + E_qB = [] + for r in result: + qB.append(r[0]) + E_qB.append(r[1]) -def update_state_prior_dirichlet( - pD, qs, lr=1.0, factors="all" -): - """ - Update Dirichlet parameters of the initial hidden state distribution - (prior beliefs about hidden states at the beginning of the inference window). - - Parameters - ----------- - pD: ``numpy.ndarray`` of dtype object - Prior Dirichlet parameters over initial hidden state prior (same shape as ``qs``) - qs: 1D ``numpy.ndarray`` or ``numpy.ndarray`` of dtype object - Marginal posterior beliefs over hidden states at current timepoint - lr: float, default ``1.0`` - Learning rate, scale of the Dirichlet pseudo-count update. - factors: ``list``, default "all" - Indices (ranging from 0 to ``n_factors - 1``) of the hidden state factors to include - in learning. Defaults to "all", meaning that factor-specific sub-vectors of ``pD`` - are all updated using the corresponding hidden state distributions. + return qB, E_qB + +# def update_state_prior_dirichlet( +# pD, qs, lr=1.0, factors="all" +# ): +# """ +# Update Dirichlet parameters of the initial hidden state distribution +# (prior beliefs about hidden states at the beginning of the inference window). + +# Parameters +# ----------- +# pD: ``numpy.ndarray`` of dtype object +# Prior Dirichlet parameters over initial hidden state prior (same shape as ``qs``) +# qs: 1D ``numpy.ndarray`` or ``numpy.ndarray`` of dtype object +# Marginal posterior beliefs over hidden states at current timepoint +# lr: float, default ``1.0`` +# Learning rate, scale of the Dirichlet pseudo-count update. +# factors: ``list``, default "all" +# Indices (ranging from 0 to ``n_factors - 1``) of the hidden state factors to include +# in learning. Defaults to "all", meaning that factor-specific sub-vectors of ``pD`` +# are all updated using the corresponding hidden state distributions. - Returns - ----------- - qD: ``numpy.ndarray`` of dtype object - Posterior Dirichlet parameters over initial hidden state prior (same shape as ``qs``), after having updated it with state beliefs. - """ +# Returns +# ----------- +# qD: ``numpy.ndarray`` of dtype object +# Posterior Dirichlet parameters over initial hidden state prior (same shape as ``qs``), after having updated it with state beliefs. +# """ - num_factors = len(pD) +# num_factors = len(pD) - qD = copy.deepcopy(pD) +# qD = copy.deepcopy(pD) - if factors == "all": - factors = list(range(num_factors)) +# if factors == "all": +# factors = list(range(num_factors)) - for factor in factors: - idx = pD[factor] > 0 # only update those state level indices that have some prior probability - qD[factor][idx] += (lr * qs[factor][idx]) +# for factor in factors: +# idx = pD[factor] > 0 # only update those state level indices that have some prior probability +# qD[factor][idx] += (lr * qs[factor][idx]) - return qD +# return qD -def _prune_prior(prior, levels_to_remove, dirichlet = False): - """ - Function for pruning a prior Categorical distribution (e.g. C, D) +# def _prune_prior(prior, levels_to_remove, dirichlet = False): +# """ +# Function for pruning a prior Categorical distribution (e.g. C, D) - Parameters - ----------- - prior: 1D ``numpy.ndarray`` or ``numpy.ndarray`` of dtype object - The vector(s) containing the priors over hidden states of a generative model, e.g. the prior over hidden states (``D`` vector). - levels_to_remove: ``list`` of ``int``, ``list`` of ``list`` - A ``list`` of the levels (indices of the support) to remove. If the prior in question has multiple hidden state factors / multiple observation modalities, - then this will be a ``list`` of ``list``, where each sub-list within ``levels_to_remove`` will contain the levels to prune for a particular hidden state factor or modality - dirichlet: ``Bool``, default ``False`` - A Boolean flag indicating whether the input vector(s) is/are a Dirichlet distribution, and therefore should not be normalized at the end. - @TODO: Instead, the dirichlet parameters from the pruned levels should somehow be re-distributed among the remaining levels +# Parameters +# ----------- +# prior: 1D ``numpy.ndarray`` or ``numpy.ndarray`` of dtype object +# The vector(s) containing the priors over hidden states of a generative model, e.g. the prior over hidden states (``D`` vector). +# levels_to_remove: ``list`` of ``int``, ``list`` of ``list`` +# A ``list`` of the levels (indices of the support) to remove. If the prior in question has multiple hidden state factors / multiple observation modalities, +# then this will be a ``list`` of ``list``, where each sub-list within ``levels_to_remove`` will contain the levels to prune for a particular hidden state factor or modality +# dirichlet: ``Bool``, default ``False`` +# A Boolean flag indicating whether the input vector(s) is/are a Dirichlet distribution, and therefore should not be normalized at the end. +# @TODO: Instead, the dirichlet parameters from the pruned levels should somehow be re-distributed among the remaining levels - Returns - ----------- - reduced_prior: 1D ``numpy.ndarray`` or ``numpy.ndarray`` of dtype object - The prior vector(s), after pruning, that lacks the hidden state or modality levels indexed by ``levels_to_remove`` - """ +# Returns +# ----------- +# reduced_prior: 1D ``numpy.ndarray`` or ``numpy.ndarray`` of dtype object +# The prior vector(s), after pruning, that lacks the hidden state or modality levels indexed by ``levels_to_remove`` +# """ - if utils.is_obj_array(prior): # in case of multiple hidden state factors +# if utils.is_obj_array(prior): # in case of multiple hidden state factors - assert all([type(levels) == list for levels in levels_to_remove]) +# assert all([type(levels) == list for levels in levels_to_remove]) - num_factors = len(prior) +# num_factors = len(prior) - reduced_prior = utils.obj_array(num_factors) +# reduced_prior = utils.obj_array(num_factors) - factors_to_remove = [] - for f, s_i in enumerate(prior): # loop over factors (or modalities) +# factors_to_remove = [] +# for f, s_i in enumerate(prior): # loop over factors (or modalities) - ns = len(s_i) - levels_to_keep = list(set(range(ns)) - set(levels_to_remove[f])) - if len(levels_to_keep) == 0: - print(f'Warning... removing ALL levels of factor {f} - i.e. the whole hidden state factor is being removed\n') - factors_to_remove.append(f) - else: - if not dirichlet: - reduced_prior[f] = utils.norm_dist(s_i[levels_to_keep]) - else: - raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned levels, across remaining levels")) - - - if len(factors_to_remove) > 0: - factors_to_keep = list(set(range(num_factors)) - set(factors_to_remove)) - reduced_prior = reduced_prior[factors_to_keep] - - else: # in case of one hidden state factor - - assert all([type(level_i) == int for level_i in levels_to_remove]) - - ns = len(prior) - levels_to_keep = list(set(range(ns)) - set(levels_to_remove)) - - if not dirichlet: - reduced_prior = utils.norm_dist(prior[levels_to_keep]) - else: - raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned levels, across remaining levels")) - - return reduced_prior - -def _prune_A(A, obs_levels_to_prune, state_levels_to_prune, dirichlet = False): - """ - Function for pruning a observation likelihood model (with potentially multiple hidden state factors) - :meta private: - Parameters - ----------- - A: ``numpy.ndarray`` with ``ndim >= 2``, or ``numpy.ndarray`` of dtype object - Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of - stores an ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store - the probability of observation level ``i`` given hidden state levels ``j, k, ...`` - obs_levels_to_prune: ``list`` of int or ``list`` of ``list``: - A ``list`` of the observation levels to remove. If the likelihood in question has multiple observation modalities, - then this will be a ``list`` of ``list``, where each sub-list within ``obs_levels_to_prune`` will contain the observation levels - to remove for a particular observation modality - state_levels_to_prune: ``list`` of ``int`` - A ``list`` of the hidden state levels to remove (this will be the same across modalities) - dirichlet: ``Bool``, default ``False`` - A Boolean flag indicating whether the input array(s) is/are a Dirichlet distribution, and therefore should not be normalized at the end. - @TODO: Instead, the dirichlet parameters from the pruned columns should somehow be re-distributed among the remaining columns - - Returns - ----------- - reduced_A: ``numpy.ndarray`` with ndim >= 2, or ``numpy.ndarray ``of dtype object - The observation model, after pruning, which lacks the observation or hidden state levels given by the arguments ``obs_levels_to_prune`` and ``state_levels_to_prune`` - """ - - columns_to_keep_list = [] - if utils.is_obj_array(A): - num_states = A[0].shape[1:] - for f, ns in enumerate(num_states): - indices_f = np.array( list(set(range(ns)) - set(state_levels_to_prune[f])), dtype = np.intp) - columns_to_keep_list.append(indices_f) - else: - num_states = A.shape[1] - indices = np.array( list(set(range(num_states)) - set(state_levels_to_prune)), dtype = np.intp ) - columns_to_keep_list.append(indices) - - if utils.is_obj_array(A): # in case of multiple observation modality - - assert all([type(o_m_levels) == list for o_m_levels in obs_levels_to_prune]) - - num_modalities = len(A) - - reduced_A = utils.obj_array(num_modalities) +# ns = len(s_i) +# levels_to_keep = list(set(range(ns)) - set(levels_to_remove[f])) +# if len(levels_to_keep) == 0: +# print(f'Warning... removing ALL levels of factor {f} - i.e. the whole hidden state factor is being removed\n') +# factors_to_remove.append(f) +# else: +# if not dirichlet: +# reduced_prior[f] = utils.norm_dist(s_i[levels_to_keep]) +# else: +# raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned levels, across remaining levels")) + + +# if len(factors_to_remove) > 0: +# factors_to_keep = list(set(range(num_factors)) - set(factors_to_remove)) +# reduced_prior = reduced_prior[factors_to_keep] + +# else: # in case of one hidden state factor + +# assert all([type(level_i) == int for level_i in levels_to_remove]) + +# ns = len(prior) +# levels_to_keep = list(set(range(ns)) - set(levels_to_remove)) + +# if not dirichlet: +# reduced_prior = utils.norm_dist(prior[levels_to_keep]) +# else: +# raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned levels, across remaining levels")) + +# return reduced_prior + +# def _prune_A(A, obs_levels_to_prune, state_levels_to_prune, dirichlet = False): +# """ +# Function for pruning a observation likelihood model (with potentially multiple hidden state factors) +# :meta private: +# Parameters +# ----------- +# A: ``numpy.ndarray`` with ``ndim >= 2``, or ``numpy.ndarray`` of dtype object +# Sensory likelihood mapping or 'observation model', mapping from hidden states to observations. Each element ``A[m]`` of +# stores an ``numpy.ndarray`` multidimensional array for observation modality ``m``, whose entries ``A[m][i, j, k, ...]`` store +# the probability of observation level ``i`` given hidden state levels ``j, k, ...`` +# obs_levels_to_prune: ``list`` of int or ``list`` of ``list``: +# A ``list`` of the observation levels to remove. If the likelihood in question has multiple observation modalities, +# then this will be a ``list`` of ``list``, where each sub-list within ``obs_levels_to_prune`` will contain the observation levels +# to remove for a particular observation modality +# state_levels_to_prune: ``list`` of ``int`` +# A ``list`` of the hidden state levels to remove (this will be the same across modalities) +# dirichlet: ``Bool``, default ``False`` +# A Boolean flag indicating whether the input array(s) is/are a Dirichlet distribution, and therefore should not be normalized at the end. +# @TODO: Instead, the dirichlet parameters from the pruned columns should somehow be re-distributed among the remaining columns + +# Returns +# ----------- +# reduced_A: ``numpy.ndarray`` with ndim >= 2, or ``numpy.ndarray ``of dtype object +# The observation model, after pruning, which lacks the observation or hidden state levels given by the arguments ``obs_levels_to_prune`` and ``state_levels_to_prune`` +# """ + +# columns_to_keep_list = [] +# if utils.is_obj_array(A): +# num_states = A[0].shape[1:] +# for f, ns in enumerate(num_states): +# indices_f = np.array( list(set(range(ns)) - set(state_levels_to_prune[f])), dtype = np.intp) +# columns_to_keep_list.append(indices_f) +# else: +# num_states = A.shape[1] +# indices = np.array( list(set(range(num_states)) - set(state_levels_to_prune)), dtype = np.intp ) +# columns_to_keep_list.append(indices) + +# if utils.is_obj_array(A): # in case of multiple observation modality + +# assert all([type(o_m_levels) == list for o_m_levels in obs_levels_to_prune]) + +# num_modalities = len(A) + +# reduced_A = utils.obj_array(num_modalities) - for m, A_i in enumerate(A): # loop over modalities +# for m, A_i in enumerate(A): # loop over modalities - no = A_i.shape[0] - rows_to_keep = np.array(list(set(range(no)) - set(obs_levels_to_prune[m])), dtype = np.intp) +# no = A_i.shape[0] +# rows_to_keep = np.array(list(set(range(no)) - set(obs_levels_to_prune[m])), dtype = np.intp) - reduced_A[m] = A_i[np.ix_(rows_to_keep, *columns_to_keep_list)] - if not dirichlet: - reduced_A = utils.norm_dist_obj_arr(reduced_A) - else: - raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) - else: # in case of one observation modality +# reduced_A[m] = A_i[np.ix_(rows_to_keep, *columns_to_keep_list)] +# if not dirichlet: +# reduced_A = utils.norm_dist_obj_arr(reduced_A) +# else: +# raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) +# else: # in case of one observation modality - assert all([type(o_levels_i) == int for o_levels_i in obs_levels_to_prune]) +# assert all([type(o_levels_i) == int for o_levels_i in obs_levels_to_prune]) - no = A.shape[0] - rows_to_keep = np.array(list(set(range(no)) - set(obs_levels_to_prune)), dtype = np.intp) +# no = A.shape[0] +# rows_to_keep = np.array(list(set(range(no)) - set(obs_levels_to_prune)), dtype = np.intp) - reduced_A = A[np.ix_(rows_to_keep, *columns_to_keep_list)] - - if not dirichlet: - reduced_A = utils.norm_dist(reduced_A) - else: - raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) - - return reduced_A - -def _prune_B(B, state_levels_to_prune, action_levels_to_prune, dirichlet = False): - """ - Function for pruning a transition likelihood model (with potentially multiple hidden state factors) - - Parameters - ----------- - B: ``numpy.ndarray`` of ``ndim == 3`` or ``numpy.ndarray`` of dtype object - Dynamics likelihood mapping or 'transition model', mapping from hidden states at `t` to hidden states at `t+1`, given some control state `u`. - Each element B[f] of this object array stores a 3-D tensor for hidden state factor `f`, whose entries `B[f][s, v, u] store the probability - of hidden state level `s` at the current time, given hidden state level `v` and action `u` at the previous time. - state_levels_to_prune: ``list`` of ``int`` or ``list`` of ``list`` - A ``list`` of the state levels to remove. If the likelihood in question has multiple hidden state factors, - then this will be a ``list`` of ``list``, where each sub-list within ``state_levels_to_prune`` will contain the state levels - to remove for a particular hidden state factor - action_levels_to_prune: ``list`` of ``int`` or ``list`` of ``list`` - A ``list`` of the control state or action levels to remove. If the likelihood in question has multiple control state factors, - then this will be a ``list`` of ``list``, where each sub-list within ``action_levels_to_prune`` will contain the control state levels - to remove for a particular control state factor - dirichlet: ``Bool``, default ``False`` - A Boolean flag indicating whether the input array(s) is/are a Dirichlet distribution, and therefore should not be normalized at the end. - @TODO: Instead, the dirichlet parameters from the pruned rows/columns should somehow be re-distributed among the remaining rows/columns - - Returns - ----------- - reduced_B: ``numpy.ndarray`` of `ndim == 3` or ``numpy.ndarray`` of dtype object - The transition model, after pruning, which lacks the hidden state levels/action levels given by the arguments ``state_levels_to_prune`` and ``action_levels_to_prune`` - """ - - slices_to_keep_list = [] - - if utils.is_obj_array(B): - - num_controls = [B_arr.shape[2] for _, B_arr in enumerate(B)] - - for c, nc in enumerate(num_controls): - indices_c = np.array( list(set(range(nc)) - set(action_levels_to_prune[c])), dtype = np.intp) - slices_to_keep_list.append(indices_c) - else: - num_controls = B.shape[2] - slices_to_keep = np.array( list(set(range(num_controls)) - set(action_levels_to_prune)), dtype = np.intp ) - - if utils.is_obj_array(B): # in case of multiple hidden state factors - - assert all([type(ns_f_levels) == list for ns_f_levels in state_levels_to_prune]) - - num_factors = len(B) - - reduced_B = utils.obj_array(num_factors) +# reduced_A = A[np.ix_(rows_to_keep, *columns_to_keep_list)] + +# if not dirichlet: +# reduced_A = utils.norm_dist(reduced_A) +# else: +# raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) + +# return reduced_A + +# def _prune_B(B, state_levels_to_prune, action_levels_to_prune, dirichlet = False): +# """ +# Function for pruning a transition likelihood model (with potentially multiple hidden state factors) + +# Parameters +# ----------- +# B: ``numpy.ndarray`` of ``ndim == 3`` or ``numpy.ndarray`` of dtype object +# Dynamics likelihood mapping or 'transition model', mapping from hidden states at `t` to hidden states at `t+1`, given some control state `u`. +# Each element B[f] of this object array stores a 3-D tensor for hidden state factor `f`, whose entries `B[f][s, v, u] store the probability +# of hidden state level `s` at the current time, given hidden state level `v` and action `u` at the previous time. +# state_levels_to_prune: ``list`` of ``int`` or ``list`` of ``list`` +# A ``list`` of the state levels to remove. If the likelihood in question has multiple hidden state factors, +# then this will be a ``list`` of ``list``, where each sub-list within ``state_levels_to_prune`` will contain the state levels +# to remove for a particular hidden state factor +# action_levels_to_prune: ``list`` of ``int`` or ``list`` of ``list`` +# A ``list`` of the control state or action levels to remove. If the likelihood in question has multiple control state factors, +# then this will be a ``list`` of ``list``, where each sub-list within ``action_levels_to_prune`` will contain the control state levels +# to remove for a particular control state factor +# dirichlet: ``Bool``, default ``False`` +# A Boolean flag indicating whether the input array(s) is/are a Dirichlet distribution, and therefore should not be normalized at the end. +# @TODO: Instead, the dirichlet parameters from the pruned rows/columns should somehow be re-distributed among the remaining rows/columns + +# Returns +# ----------- +# reduced_B: ``numpy.ndarray`` of `ndim == 3` or ``numpy.ndarray`` of dtype object +# The transition model, after pruning, which lacks the hidden state levels/action levels given by the arguments ``state_levels_to_prune`` and ``action_levels_to_prune`` +# """ + +# slices_to_keep_list = [] + +# if utils.is_obj_array(B): + +# num_controls = [B_arr.shape[2] for _, B_arr in enumerate(B)] + +# for c, nc in enumerate(num_controls): +# indices_c = np.array( list(set(range(nc)) - set(action_levels_to_prune[c])), dtype = np.intp) +# slices_to_keep_list.append(indices_c) +# else: +# num_controls = B.shape[2] +# slices_to_keep = np.array( list(set(range(num_controls)) - set(action_levels_to_prune)), dtype = np.intp ) + +# if utils.is_obj_array(B): # in case of multiple hidden state factors + +# assert all([type(ns_f_levels) == list for ns_f_levels in state_levels_to_prune]) + +# num_factors = len(B) + +# reduced_B = utils.obj_array(num_factors) - for f, B_f in enumerate(B): # loop over modalities +# for f, B_f in enumerate(B): # loop over modalities - ns = B_f.shape[0] - states_to_keep = np.array(list(set(range(ns)) - set(state_levels_to_prune[f])), dtype = np.intp) +# ns = B_f.shape[0] +# states_to_keep = np.array(list(set(range(ns)) - set(state_levels_to_prune[f])), dtype = np.intp) - reduced_B[f] = B_f[np.ix_(states_to_keep, states_to_keep, slices_to_keep_list[f])] +# reduced_B[f] = B_f[np.ix_(states_to_keep, states_to_keep, slices_to_keep_list[f])] - if not dirichlet: - reduced_B = utils.norm_dist_obj_arr(reduced_B) - else: - raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) +# if not dirichlet: +# reduced_B = utils.norm_dist_obj_arr(reduced_B) +# else: +# raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) - else: # in case of one hidden state factor +# else: # in case of one hidden state factor - assert all([type(state_level_i) == int for state_level_i in state_levels_to_prune]) +# assert all([type(state_level_i) == int for state_level_i in state_levels_to_prune]) - ns = B.shape[0] - states_to_keep = np.array(list(set(range(ns)) - set(state_levels_to_prune)), dtype = np.intp) +# ns = B.shape[0] +# states_to_keep = np.array(list(set(range(ns)) - set(state_levels_to_prune)), dtype = np.intp) - reduced_B = B[np.ix_(states_to_keep, states_to_keep, slices_to_keep)] +# reduced_B = B[np.ix_(states_to_keep, states_to_keep, slices_to_keep)] - if not dirichlet: - reduced_B = utils.norm_dist(reduced_B) - else: - raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) +# if not dirichlet: +# reduced_B = utils.norm_dist(reduced_B) +# else: +# raise(NotImplementedError("Need to figure out how to re-distribute concentration parameters from pruned rows/columns, across remaining rows/columns")) - return reduced_B +# return reduced_B \ No newline at end of file diff --git a/pymdp/jax/maths.py b/pymdp/jax/maths.py index 58b34aff..213b519f 100644 --- a/pymdp/jax/maths.py +++ b/pymdp/jax/maths.py @@ -120,7 +120,7 @@ def spm_wnorm(A): Returns Expectation of logarithm of Dirichlet parameters over a set of Categorical distributions, stored in the columns of A. """ - A = jnp.clip(A, a_min=MINVAL) + A = jnp.clip(A, min=MINVAL) norm = 1. / A.sum(axis=0) avg = 1. / A wA = norm - avg @@ -131,7 +131,7 @@ def dirichlet_expected_value(dir_arr): Returns Expectation of Dirichlet parameters over a set of Categorical distributions, stored in the columns of A. """ - dir_arr = jnp.clip(dir_arr, a_min=MINVAL) + dir_arr = jnp.clip(dir_arr, min=MINVAL) expected_val = jnp.divide(dir_arr, dir_arr.sum(axis=0, keepdims=True)) return expected_val diff --git a/requirements.txt b/requirements.txt index de815d0c..99a4adb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,5 @@ jaxlib>=0.3.4 equinox>=0.9 numpyro>=0.1 arviz>=0.13 -optax>=0.1 \ No newline at end of file +optax>=0.1 +multimethod>=1.11 \ No newline at end of file diff --git a/test/test_learning_jax.py b/test/test_learning_jax.py index cdb3b86c..60d6bcb3 100644 --- a/test/test_learning_jax.py +++ b/test/test_learning_jax.py @@ -5,7 +5,6 @@ __author__: Dimitrije Markovic, Conor Heins """ -import os import unittest import numpy as np @@ -15,7 +14,7 @@ from pymdp.learning import update_obs_likelihood_dirichlet as update_pA_numpy from pymdp.learning import update_obs_likelihood_dirichlet_factorized as update_pA_numpy_factorized from pymdp.jax.learning import update_obs_likelihood_dirichlet as update_pA_jax -from pymdp import utils, maths +from pymdp import utils class TestLearningJax(unittest.TestCase): @@ -68,7 +67,15 @@ def test_update_observation_likelihood_fullyconnected(self): obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np)) qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np)) - qA_jax_test = update_pA_jax(pA_jax, obs_jax, qs_jax, A_dependencies, lr=l_rate) + qA_jax_test, E_qA_jax_test = update_pA_jax( + pA_jax, + obs_jax, + qs_jax, + A_dependencies=A_dependencies, + onehot_obs=True, + num_obs=num_obs, + lr=l_rate + ) for modality, obs_dim in enumerate(num_obs): self.assertTrue(np.allclose(qA_jax_test[modality], qA_np_test[modality])) @@ -122,18 +129,18 @@ def test_update_observation_likelihood_factorized(self): obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np)) qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np)) - qA_jax_test = update_pA_jax(pA_jax, obs_jax, qs_jax, A_dependencies, lr=l_rate) + qA_jax_test, E_qA_jax_test = update_pA_jax( + pA_jax, + obs_jax, + qs_jax, + A_dependencies=A_dependencies, + onehot_obs=True, + num_obs=num_obs, + lr=l_rate + ) for modality, obs_dim in enumerate(num_obs): self.assertTrue(np.allclose(qA_jax_test[modality],qA_np_test[modality])) if __name__ == "__main__": - unittest.main() - - - - - - - - + unittest.main() \ No newline at end of file