From 20dd5cf295f794d3f89621a238ba069ac4290b2f Mon Sep 17 00:00:00 2001 From: mmcky Date: Fri, 2 Feb 2024 06:55:55 +1100 Subject: [PATCH] install torch torchvision and pyro-ppl before jax --- cuda-12.3.1-anaconda-2023-09-py311/Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda-12.3.1-anaconda-2023-09-py311/Dockerfile b/cuda-12.3.1-anaconda-2023-09-py311/Dockerfile index 5d52a24..9a2139b 100644 --- a/cuda-12.3.1-anaconda-2023-09-py311/Dockerfile +++ b/cuda-12.3.1-anaconda-2023-09-py311/Dockerfile @@ -42,6 +42,7 @@ ENV PATH /opt/conda/envs/quantecon/bin:$PATH # Install JAX RUN nvcc --version +RUN pip install torch torchvision pyro-ppl RUN pip install "numpyro[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # Note: always install jax[cuda] last to ensure proper CUDA+CUDANN version linking RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html