Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[docker] Upgrade to CUDA 12.6 #160

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions docker/gcp-base-image.dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# default base image: ghcr.io/actions/actions-runner:latest
# base image: Ubuntu 22.04 jammy
# Prune CUDA to only keep gencode >= A100
ARG BASE_IMAGE=ghcr.io/actions/actions-runner:latest
FROM ${BASE_IMAGE}

ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
ARG OVERRIDE_GENCODE="-gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_90a,code=sm_90a"
ARG OVERRIDE_GENCODE_CUDNN="-gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_90a,code=sm_90a"

RUN sudo apt-get -y update && sudo apt -y update
# fontconfig: required by model doctr_det_predictor
# libjpeg and libpng: optionally required by torchvision (vision#8342)
RUN sudo apt-get install -y git jq gcc g++ \
vim wget curl ninja-build cmake \
libgl1-mesa-glx libsndfile1-dev kmod libxml2-dev libxslt1-dev \
fontconfig libfontconfig1-dev \
libpango-1.0-0 libpangoft2-1.0-0 \
libsdl2-dev libsdl2-2.0-0 \
libjpeg-dev libpng-dev zlib1g-dev

# get switch-cuda utility
RUN sudo wget -q https://raw.githubusercontent.com/phohenecker/switch-cuda/master/switch-cuda.sh -O /usr/bin/switch-cuda.sh
RUN sudo chmod +x /usr/bin/switch-cuda.sh

RUN sudo mkdir -p /workspace; sudo chown runner:runner /workspace

# GKE version: 1.28.5-gke.1217000
# NVIDIA driver version: 535.104.05
# NVIDIA drivers list available at gs://ubuntu_nvidia_packages/
# We assume that the host NVIDIA driver binaries and libraries are mapped to the docker filesystem

# Use the CUDA installation scripts from pytorch/builder
# Install CUDA 12.4 and 12.6, and default to 12.4
RUN cd /workspace; mkdir -p pytorch-ci; cd pytorch-ci; wget https://raw.githubusercontent.com/pytorch/pytorch/main/.ci/docker/common/install_cuda.sh
RUN sudo bash -c "set -x;export OVERRIDE_GENCODE=\"${OVERRIDE_GENCODE}\" OVERRIDE_GENCODE_CUDNN=\"${OVERRIDE_GENCODE_CUDNN}\"; bash /workspace/pytorch-ci/install_cuda.sh 12.4"
RUN sudo bash -c "set -x;export OVERRIDE_GENCODE=\"${OVERRIDE_GENCODE}\" OVERRIDE_GENCODE_CUDNN=\"${OVERRIDE_GENCODE_CUDNN}\"; bash /workspace/pytorch-ci/install_cuda.sh 12.6"

# Install miniconda
RUN wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /workspace/Miniconda3-latest-Linux-x86_64.sh
RUN cd /workspace && \
chmod +x Miniconda3-latest-Linux-x86_64.sh && \
bash ./Miniconda3-latest-Linux-x86_64.sh -b -u

# Test activate miniconda
RUN . ${HOME}/miniconda3/etc/profile.d/conda.sh && \
conda activate base && \
conda init

RUN echo "\
. \${HOME}/miniconda3/etc/profile.d/conda.sh\n\
conda activate base\n\
export CONDA_HOME=\${HOME}/miniconda3\n\
export CUDA_HOME=/usr/local/cuda\n\
export PATH=\${CUDA_HOME}/bin\${PATH:+:\${PATH}}\n\
export LD_LIBRARY_PATH=\${CUDA_HOME}/lib64\${LD_LIBRARY_PATH:+:\${LD_LIBRARY_PATH}}\n\
export LIBRARY_PATH=\${CUDA_HOME}/lib64\${LIBRARY_PATHPATH:+:\${LIBRARY_PATHPATH}}\n" >> /workspace/setup_instance.sh

RUN echo ". /workspace/setup_instance.sh\n" >> ${HOME}/.bashrc
5 changes: 5 additions & 0 deletions tools/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
"magma": "magma-cuda124",
"jax": "jax[cuda12]",
},
"12.6": {
"pytorch_url": "cu126",
"magma": "magma-cuda126",
"jax": "jax[cuda12]",
},
}


Expand Down
Loading