From 7553d2a5a7d06a3108b0a4b460567559a8ff81dd Mon Sep 17 00:00:00 2001 From: dimarkov <5038100+dimarkov@users.noreply.github.com> Date: Mon, 17 Jun 2024 18:07:17 +0200 Subject: [PATCH] updated example --- .../inference_methods_comparison.ipynb | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/inference_and_learning/inference_methods_comparison.ipynb b/examples/inference_and_learning/inference_methods_comparison.ipynb index 0d78f77e..a2bc4b0e 100644 --- a/examples/inference_and_learning/inference_methods_comparison.ipynb +++ b/examples/inference_and_learning/inference_methods_comparison.ipynb @@ -32,7 +32,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-06-17 17:54:09.793970: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" + "2024-06-17 18:05:07.093645: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" ] } ], @@ -120,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -135,20 +135,20 @@ " if t < len(obs[0]) - 1:\n", " action_hist.append(actions)\n", "\n", - "v_jso = jit(vmap(smoothing_ovf))\n", + "v_jso = jit(vmap(smoothing_ovf), backend='gpu')\n", "actions_seq = jnp.stack(action_hist, 1)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "63.1 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "65.1 µs ± 838 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], @@ -166,14 +166,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "87.7 µs ± 8.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "86.8 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ],