Skip to content

Commit

Permalink
Load MPI and CUDA maybe
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Apr 25, 2024
1 parent b0d754e commit 3bf2e6f
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 76 deletions.
8 changes: 8 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ steps:
- srun julia --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CPU
CLIMACOMMS_CONTEXT: MPI
CLIMACOMMS_DEVICE: CPU
agents:
slurm_ntasks: 2

Expand All @@ -46,6 +48,7 @@ steps:
- julia --threads 4 --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CPU
CLIMACOMMS_DEVICE: CPU
agents:
slurm_cpus_per_task: 4

Expand All @@ -55,6 +58,8 @@ steps:
- srun julia --threads 4 --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CPU
CLIMACOMMS_CONTEXT: MPI
CLIMACOMMS_DEVICE: CPU
agents:
slurm_ntasks: 2
slurm_cpus_per_task: 4
Expand All @@ -65,6 +70,7 @@ steps:
- julia --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CUDA
CLIMACOMMS_DEVICE: CUDA
agents:
slurm_gpus: 1

Expand All @@ -74,6 +80,8 @@ steps:
- srun julia --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CUDA
CLIMACOMMS_CONTEXT: MPI
CLIMACOMMS_DEVICE: CUDA
agents:
slurm_gpus_per_task: 1
slurm_ntasks: 2
24 changes: 24 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
ClimaComms.jl Release Notes
========================

v0.6.0
-------

- ![][badge-💥breaking] `ClimaComms` does no longer try to guess the correct
compute device: the default is now CPU. To control which device to use,
use the `CLIMACOMMS_DEVICE` environment variable.
- ![][badge-💥breaking] `CUDA` and `MPI` are now extensions in `ClimaComms`. To
use `CUDA`/`MPI`, `CUDA.jl`/`MPI.jl` have to be loaded. A convenience macro
`ClimaComms.@import_required_backends` checks what device/context could be
used and conditionally loads `CUDA.jl`/`MPI.jl`. It is recommended to change
```julia
import ClimaComms
```
to
```julia
import ClimaComms
ClimaComms.@import_required_backends
```
This has to be done before calling `ClimaComms.context()`.

[badge-💥breaking]: https://img.shields.io/badge/💥BREAKING-red.svg
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ CurrentModule = ClimaComms
ClimaComms
```

## Loading

```@docs
ClimaComms.import_required_backends
```

## Devices

```@docs
Expand Down
1 change: 1 addition & 0 deletions src/ClimaComms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ include("devices.jl")
include("context.jl")
include("singleton.jl")
include("mpi.jl")
include("loading.jl")

end # module
54 changes: 22 additions & 32 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import ..ClimaComms

"""
ClimaComms.mpi_ext_available()
Returns true when the `ClimaComms` `ClimaCommsMPIExt` extension was loaded.
To load `ClimaCommsMPIExt`, just load `ClimaComms` and `MPI`.
"""
function mpi_ext_available()
return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt))
function context_type()
name = get(ENV, "CLIMACOMMS_CONTEXT", nothing)
if !isnothing(name)
if name == "MPI"
return :MPICommsContext
elseif name == "SINGLETON"
return :SingletonCommsContext
else
error("Invalid context: $name")
end
end
# detect common environment variables used by MPI launchers
# PMI_RANK appears to be used by MPICH and srun
# OMPI_COMM_WORLD_RANK appears to be used by OpenMPI
if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK")
return :MPICommsContext
else
return :SingletonCommsContext
end
end

"""
Expand All @@ -23,29 +33,9 @@ it will return a [`SingletonCommsContext`](@ref).
Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment variable
to either `MPI` or `SINGLETON`.
"""
function context(device = device())
if !(mpi_ext_available())
return SingletonCommsContext(device)
else
name = get(ENV, "CLIMACOMMS_CONTEXT", nothing)
if !isnothing(name)
if name == "MPI"
return MPICommsContext()
elseif name == "SINGLETON"
return SingletonCommsContext()
else
error("Invalid context: $name")
end
end
# detect common environment variables used by MPI launchers
# PMI_RANK appears to be used by MPICH and srun
# OMPI_COMM_WORLD_RANK appears to be used by OpenMPI
if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK")
return MPICommsContext(device)
else
return SingletonCommsContext(device)
end
end
function context(device = device(); target_context = context_type())
ContextConstructor = getproperty(ClimaComms, target_context)
return ContextConstructor(device)
end

"""
Expand Down
51 changes: 16 additions & 35 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,6 @@ Use NVIDIA GPU accelarator
"""
struct CUDADevice <: AbstractDevice end

"""
ClimaComms.cuda_ext_available()
Returns true when the `ClimaComms` `ClimaCommsCUDAExt` extension was loaded.
To load `ClimaCommsCUDAExt`, just load `ClimaComms` and `CUDA`.
"""
function cuda_ext_available()
return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt))
end

"""
ClimaComms.device_functional(device)
Expand All @@ -60,36 +49,28 @@ device_functional(::CPUMultiThreaded) = true
"""
ClimaComms.device()
Automatically determine the appropriate device to use, returning one of
- [`AbstractCPUDevice()`](@ref)
- [`CUDADevice()`](@ref)
Determine the device to use depending on the `CLIMACOMMS_DEVICE` environment variable.
By default, it will check if a functional CUDA installation exists, using CUDA if possible.
Allowed values:
- `CPU`, single-threaded or multi-threaded depending on the number of threads;
- `CPUSingleThreaded`,
- `CPUMultiThreaded`,
- `CUDA`.
Behavior can be overridden by setting the `CLIMACOMMS_DEVICE` environment variable to either `CPU` or `CUDA`.
The default is `CPU`.
"""
function device()
env_var = get(ENV, "CLIMACOMMS_DEVICE", nothing)
if !isnothing(env_var)
if env_var == "CPU"
return Threads.nthreads() > 1 ? CPUMultiThreaded() :
CPUSingleThreaded()
elseif env_var == "CPUSingleThreaded"
return CPUSingleThreaded()
elseif env_var == "CPUMultiThreaded"
return CPUMultiThreaded()
elseif env_var == "CUDA"
cuda_ext_available() || error("CUDA was not loaded")
return CUDADevice()
else
error("Invalid CLIMACOMMS_DEVICE: $env_var")
end
end
if cuda_ext_available() && device_functional(CUDADevice())
env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU")
if env_var == "CPU"
return Threads.nthreads() > 1 ? CPUMultiThreaded() : CPUSingleThreaded()
elseif env_var == "CPUSingleThreaded"
return CPUSingleThreaded()
elseif env_var == "CPUMultiThreaded"
return CPUMultiThreaded()
elseif env_var == "CUDA"
return CUDADevice()
else
return Threads.nthreads() == 1 ? CPUSingleThreaded() :
CPUMultiThreaded()
error("Invalid CLIMACOMMS_DEVICE: $env_var")
end
end

Expand Down
38 changes: 38 additions & 0 deletions src/loading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
export import_required_backends

function mpi_required()
return context_type() == :MPICommsContext
end

function cuda_required()
return device() isa CUDADevice
end

"""
ClimaComms.@import_required_backends
If the desired context is MPI (as determined by `ClimaComms.context()`), try loading MPI.jl.
If the desired device is CUDA (as determined by `ClimaComms.device()`), try loading CUDA.jl.
"""
macro import_required_backends()
return quote
@static if $mpi_required()
try
import MPI
catch
error(
"Cannot load MPI.jl. Make sure it is included in your environment stack.",
)
end
end
@static if $cuda_required()
try
import CUDA
catch
error(
"Cannot load CUDA.jl. Make sure it is included in your environment stack.",
)
end
end
end
end
10 changes: 1 addition & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
using Test
using ClimaComms

if haskey(ENV, "CLIMACOMMS_TEST_DEVICE") &&
ENV["CLIMACOMMS_TEST_DEVICE"] == "CUDA"
import CUDA
@test ClimaComms.cuda_ext_available()
end

import MPI
@test ClimaComms.mpi_ext_available()
ClimaComms.@import_required_backends

context = ClimaComms.context()
pid, nprocs = ClimaComms.init(context)
Expand Down

0 comments on commit 3bf2e6f

Please sign in to comment.