Skip to content

Commit

Permalink
updated example
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jun 17, 2024
1 parent 397bca4 commit 7553d2a
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-17 17:54:09.793970: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
"2024-06-17 18:05:07.093645: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
}
],
Expand Down Expand Up @@ -120,7 +120,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -135,20 +135,20 @@
" if t < len(obs[0]) - 1:\n",
" action_hist.append(actions)\n",
"\n",
"v_jso = jit(vmap(smoothing_ovf))\n",
"v_jso = jit(vmap(smoothing_ovf), backend='gpu')\n",
"actions_seq = jnp.stack(action_hist, 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63.1 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"65.1 µs ± 838 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
Expand All @@ -166,14 +166,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"87.7 µs ± 8.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"86.8 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
Expand Down

0 comments on commit 7553d2a

Please sign in to comment.