diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index ce0c8e71..fc0053c4 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -37,6 +37,8 @@ steps: - srun julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU + CLIMACOMMS_CONTEXT: MPI + CLIMACOMMS_DEVICE: CPU agents: slurm_ntasks: 2 @@ -46,6 +48,7 @@ steps: - julia --threads 4 --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU + CLIMACOMMS_DEVICE: CPU agents: slurm_cpus_per_task: 4 @@ -55,6 +58,8 @@ steps: - srun julia --threads 4 --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU + CLIMACOMMS_CONTEXT: MPI + CLIMACOMMS_DEVICE: CPU agents: slurm_ntasks: 2 slurm_cpus_per_task: 4 @@ -65,6 +70,7 @@ steps: - julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CUDA + CLIMACOMMS_DEVICE: CUDA agents: slurm_gpus: 1 @@ -74,6 +80,8 @@ steps: - srun julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CUDA + CLIMACOMMS_CONTEXT: MPI + CLIMACOMMS_DEVICE: CUDA agents: slurm_gpus_per_task: 1 slurm_ntasks: 2 diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 00000000..3f268454 --- /dev/null +++ b/NEWS.md @@ -0,0 +1,24 @@ +ClimaComms.jl Release Notes +======================== + +v0.6.0 +------- + +- ![][badge-💥breaking] `ClimaComms` does no longer try to guess the correct + compute device: the default is now CPU. To control which device to use, + use the `CLIMACOMMS_DEVICE` environment variable. +- ![][badge-💥breaking] `CUDA` and `MPI` are now extensions in `ClimaComms`. To + use `CUDA`/`MPI`, `CUDA.jl`/`MPI.jl` have to be loaded. A convenience macro + `ClimaComms.@import_required_backends` checks what device/context could be + used and conditionally loads `CUDA.jl`/`MPI.jl`. It is recommended to change + ```julia + import ClimaComms + ``` + to + ```julia + import ClimaComms + ClimaComms.@import_required_backends + ``` + This has to be done before calling `ClimaComms.context()`. + +[badge-💥breaking]: https://img.shields.io/badge/💥BREAKING-red.svg diff --git a/docs/src/index.md b/docs/src/index.md index 26f69cdd..d158f3e7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,6 +8,12 @@ CurrentModule = ClimaComms ClimaComms ``` +## Loading + +```@docs +ClimaComms.import_required_backends +``` + ## Devices ```@docs diff --git a/src/ClimaComms.jl b/src/ClimaComms.jl index 3d4084c3..a8605233 100644 --- a/src/ClimaComms.jl +++ b/src/ClimaComms.jl @@ -15,5 +15,6 @@ include("devices.jl") include("context.jl") include("singleton.jl") include("mpi.jl") +include("loading.jl") end # module diff --git a/src/context.jl b/src/context.jl index 6ff13322..7cf2b514 100644 --- a/src/context.jl +++ b/src/context.jl @@ -1,14 +1,24 @@ import ..ClimaComms -""" - ClimaComms.mpi_ext_available() - -Returns true when the `ClimaComms` `ClimaCommsMPIExt` extension was loaded. - -To load `ClimaCommsMPIExt`, just load `ClimaComms` and `MPI`. -""" -function mpi_ext_available() - return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt)) +function context_type() + name = get(ENV, "CLIMACOMMS_CONTEXT", nothing) + if !isnothing(name) + if name == "MPI" + return :MPICommsContext + elseif name == "SINGLETON" + return :SingletonCommsContext + else + error("Invalid context: $name") + end + end + # detect common environment variables used by MPI launchers + # PMI_RANK appears to be used by MPICH and srun + # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI + if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK") + return :MPICommsContext + else + return :SingletonCommsContext + end end """ @@ -23,29 +33,9 @@ it will return a [`SingletonCommsContext`](@ref). Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment variable to either `MPI` or `SINGLETON`. """ -function context(device = device()) - if !(mpi_ext_available()) - return SingletonCommsContext(device) - else - name = get(ENV, "CLIMACOMMS_CONTEXT", nothing) - if !isnothing(name) - if name == "MPI" - return MPICommsContext() - elseif name == "SINGLETON" - return SingletonCommsContext() - else - error("Invalid context: $name") - end - end - # detect common environment variables used by MPI launchers - # PMI_RANK appears to be used by MPICH and srun - # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI - if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK") - return MPICommsContext(device) - else - return SingletonCommsContext(device) - end - end +function context(device = device(); target_context = context_type()) + ContextConstructor = getproperty(ClimaComms, target_context) + return ContextConstructor(device) end """ diff --git a/src/devices.jl b/src/devices.jl index 2cd77480..c2709496 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -36,17 +36,6 @@ Use NVIDIA GPU accelarator """ struct CUDADevice <: AbstractDevice end -""" - ClimaComms.cuda_ext_available() - -Returns true when the `ClimaComms` `ClimaCommsCUDAExt` extension was loaded. - -To load `ClimaCommsCUDAExt`, just load `ClimaComms` and `CUDA`. -""" -function cuda_ext_available() - return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt)) -end - """ ClimaComms.device_functional(device) @@ -60,36 +49,28 @@ device_functional(::CPUMultiThreaded) = true """ ClimaComms.device() -Automatically determine the appropriate device to use, returning one of - - [`AbstractCPUDevice()`](@ref) - - [`CUDADevice()`](@ref) +Determine the device to use depending on the `CLIMACOMMS_DEVICE` environment variable. -By default, it will check if a functional CUDA installation exists, using CUDA if possible. +Allowed values: +- `CPU`, single-threaded or multi-threaded depending on the number of threads; +- `CPUSingleThreaded`, +- `CPUMultiThreaded`, +- `CUDA`. -Behavior can be overridden by setting the `CLIMACOMMS_DEVICE` environment variable to either `CPU` or `CUDA`. +The default is `CPU`. """ function device() - env_var = get(ENV, "CLIMACOMMS_DEVICE", nothing) - if !isnothing(env_var) - if env_var == "CPU" - return Threads.nthreads() > 1 ? CPUMultiThreaded() : - CPUSingleThreaded() - elseif env_var == "CPUSingleThreaded" - return CPUSingleThreaded() - elseif env_var == "CPUMultiThreaded" - return CPUMultiThreaded() - elseif env_var == "CUDA" - cuda_ext_available() || error("CUDA was not loaded") - return CUDADevice() - else - error("Invalid CLIMACOMMS_DEVICE: $env_var") - end - end - if cuda_ext_available() && device_functional(CUDADevice()) + env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU") + if env_var == "CPU" + return Threads.nthreads() > 1 ? CPUMultiThreaded() : CPUSingleThreaded() + elseif env_var == "CPUSingleThreaded" + return CPUSingleThreaded() + elseif env_var == "CPUMultiThreaded" + return CPUMultiThreaded() + elseif env_var == "CUDA" return CUDADevice() else - return Threads.nthreads() == 1 ? CPUSingleThreaded() : - CPUMultiThreaded() + error("Invalid CLIMACOMMS_DEVICE: $env_var") end end diff --git a/src/loading.jl b/src/loading.jl new file mode 100644 index 00000000..ce6c0c1f --- /dev/null +++ b/src/loading.jl @@ -0,0 +1,38 @@ +export import_required_backends + +function mpi_required() + return context_type() == :MPICommsContext +end + +function cuda_required() + return device() isa CUDADevice +end + +""" + ClimaComms.@import_required_backends + +If the desired context is MPI (as determined by `ClimaComms.context()`), try loading MPI.jl. +If the desired device is CUDA (as determined by `ClimaComms.device()`), try loading CUDA.jl. +""" +macro import_required_backends() + return quote + @static if $mpi_required() + try + import MPI + catch + error( + "Cannot load MPI.jl. Make sure it is included in your environment stack.", + ) + end + end + @static if $cuda_required() + try + import CUDA + catch + error( + "Cannot load CUDA.jl. Make sure it is included in your environment stack.", + ) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e0a085d8..5574fdd3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,6 @@ using Test using ClimaComms - -if haskey(ENV, "CLIMACOMMS_TEST_DEVICE") && - ENV["CLIMACOMMS_TEST_DEVICE"] == "CUDA" - import CUDA - @test ClimaComms.cuda_ext_available() -end - -import MPI -@test ClimaComms.mpi_ext_available() +ClimaComms.@import_required_backends context = ClimaComms.context() pid, nprocs = ClimaComms.init(context)