Skip to content

Commit

Permalink
Load MPI and CUDA maybe
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Apr 25, 2024
1 parent b0d754e commit 71d8fa7
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 78 deletions.
8 changes: 8 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ steps:
- srun julia --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CPU
CLIMACOMMS_CONTEXT: MPI
CLIMACOMMS_DEVICE: CPU
agents:
slurm_ntasks: 2

Expand All @@ -46,6 +48,7 @@ steps:
- julia --threads 4 --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CPU
CLIMACOMMS_DEVICE: CPU
agents:
slurm_cpus_per_task: 4

Expand All @@ -55,6 +58,8 @@ steps:
- srun julia --threads 4 --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CPU
CLIMACOMMS_CONTEXT: MPI
CLIMACOMMS_DEVICE: CPU
agents:
slurm_ntasks: 2
slurm_cpus_per_task: 4
Expand All @@ -65,6 +70,7 @@ steps:
- julia --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CUDA
CLIMACOMMS_DEVICE: CUDA
agents:
slurm_gpus: 1

Expand All @@ -74,6 +80,8 @@ steps:
- srun julia --project=test test/runtests.jl
env:
CLIMACOMMS_TEST_DEVICE: CUDA
CLIMACOMMS_CONTEXT: MPI
CLIMACOMMS_DEVICE: CUDA
agents:
slurm_gpus_per_task: 1
slurm_ntasks: 2
24 changes: 24 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
ClimaComms.jl Release Notes
========================

v0.6.0
-------

- ![][badge-💥breaking] `ClimaComms` no longer tries to guess the correct
compute device: the default is now CPU. To control which device to use,
use the `CLIMACOMMS_DEVICE` environment variable.
- ![][badge-💥breaking] `CUDA` and `MPI` are now extensions in `ClimaComms`. To
use `CUDA`/`MPI`, `CUDA.jl`/`MPI.jl` have to be loaded. A convenience macro
`ClimaComms.@import_required_backends` checks what device/context could be
used and conditionally loads `CUDA.jl`/`MPI.jl`. It is recommended to change
```julia
import ClimaComms
```
to
```julia
import ClimaComms
ClimaComms.@import_required_backends
```
This has to be done before calling `ClimaComms.context()`.

[badge-💥breaking]: https://img.shields.io/badge/💥BREAKING-red.svg
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.5.9"
version = "0.6.0"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
77 changes: 76 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "2a3f6f2093cc9e00b435b32de577d1b8ccb4f7ac"
project_hash = "c5b9e727593a1bc35ccae9b71e346465d8a7803c"

[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
Expand Down Expand Up @@ -43,10 +43,19 @@ git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.4"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"

[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
Expand Down Expand Up @@ -85,6 +94,12 @@ git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809"
uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb"
version = "2.44.0+2"

[[deps.Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
version = "2.10.0+0"

[[deps.IOCapture]]
deps = ["Logging", "Random"]
git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c"
Expand Down Expand Up @@ -112,6 +127,10 @@ git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612"
uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf"
version = "1.2.2"

[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
Expand Down Expand Up @@ -148,6 +167,38 @@ version = "1.17.0+0"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.MPI]]
deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"]
git-tree-sha1 = "4e3136db3735924f96632a5b40a5979f1f53fa07"
uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195"
version = "0.20.19"

[deps.MPI.extensions]
AMDGPUExt = "AMDGPU"
CUDAExt = "CUDA"

[deps.MPI.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[[deps.MPICH_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "d8a7bf80c88326ebc98b7d38437208c3a0f20725"
uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4"
version = "4.2.1+0"

[[deps.MPIPreferences]]
deps = ["Libdl", "Preferences"]
git-tree-sha1 = "8f6af051b9e8ec597fa09d8885ed79fd582f33c9"
uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
version = "0.1.10"

[[deps.MPItrampoline_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "3f884417b47a96d87e7c6219f8f7b30ce67f4f2c"
uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748"
version = "5.3.3+0"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand All @@ -163,6 +214,12 @@ deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+1"

[[deps.MicrosoftMPI_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01"
uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf"
version = "10.1.4+2"

[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

Expand All @@ -174,6 +231,12 @@ version = "2023.1.10"
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762"
uuid = "fe0851c0-eecd-5654-98d4-656369965a5c"
version = "4.1.6+0"

[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046"
Expand All @@ -196,6 +259,12 @@ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0"

[[deps.PkgVersion]]
deps = ["Pkg"]
git-tree-sha1 = "f9501cc0430a26bc3d156ae1b5b0c1b47af4d6da"
uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688"
version = "0.3.3"

[[deps.PrecompileTools]]
deps = ["Preferences"]
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
Expand Down Expand Up @@ -226,6 +295,12 @@ git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51"
uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3"
version = "0.1.0"

[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ CurrentModule = ClimaComms
ClimaComms
```

## Loading

```@docs
ClimaComms.import_required_backends
```

## Devices

```@docs
Expand Down
1 change: 1 addition & 0 deletions src/ClimaComms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ include("devices.jl")
include("context.jl")
include("singleton.jl")
include("mpi.jl")
include("loading.jl")

end # module
54 changes: 22 additions & 32 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import ..ClimaComms

"""
ClimaComms.mpi_ext_available()
Returns true when the `ClimaComms` `ClimaCommsMPIExt` extension was loaded.
To load `ClimaCommsMPIExt`, just load `ClimaComms` and `MPI`.
"""
function mpi_ext_available()
return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt))
"""
    context_type()

Return, as a `Symbol`, the name of the communication context type to use.

The `CLIMACOMMS_CONTEXT` environment variable takes precedence when it is set:
`MPI` selects `:MPICommsContext`, `SINGLETON` selects `:SingletonCommsContext`,
and any other value raises an error.

When the variable is not set, `:MPICommsContext` is returned if either `PMI_RANK`
or `OMPI_COMM_WORLD_RANK` is present in the environment; otherwise
`:SingletonCommsContext` is returned.
"""
function context_type()
    name = get(ENV, "CLIMACOMMS_CONTEXT", nothing)

    if isnothing(name)
        # No explicit request: look for environment variables commonly set by MPI
        # launchers. PMI_RANK appears to be used by MPICH and srun,
        # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI.
        under_mpi_launcher =
            haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK")
        return under_mpi_launcher ? :MPICommsContext : :SingletonCommsContext
    end

    # Explicit request: only the two known spellings are accepted.
    name == "MPI" && return :MPICommsContext
    name == "SINGLETON" && return :SingletonCommsContext
    error("Invalid context: $name")
end

"""
Expand All @@ -23,29 +33,9 @@ it will return a [`SingletonCommsContext`](@ref).
Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment variable
to either `MPI` or `SINGLETON`.
"""
function context(device = device())
if !(mpi_ext_available())
return SingletonCommsContext(device)
else
name = get(ENV, "CLIMACOMMS_CONTEXT", nothing)
if !isnothing(name)
if name == "MPI"
return MPICommsContext()
elseif name == "SINGLETON"
return SingletonCommsContext()
else
error("Invalid context: $name")
end
end
# detect common environment variables used by MPI launchers
# PMI_RANK appears to be used by MPICH and srun
# OMPI_COMM_WORLD_RANK appears to be used by OpenMPI
if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK")
return MPICommsContext(device)
else
return SingletonCommsContext(device)
end
end
function context(device = device(); target_context = context_type())
    # `target_context` is the name (a `Symbol`) of a context type; it is looked up
    # as a property of the `ClimaComms` module and then called with `device`.
    return getproperty(ClimaComms, target_context)(device)
end

"""
Expand Down
51 changes: 16 additions & 35 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,6 @@ Use NVIDIA GPU accelerator
"""
struct CUDADevice <: AbstractDevice
    # Fieldless type: instances carry no data.
end

"""
ClimaComms.cuda_ext_available()
Returns true when the `ClimaComms` `ClimaCommsCUDAExt` extension was loaded.
To load `ClimaCommsCUDAExt`, just load `ClimaComms` and `CUDA`.
"""
function cuda_ext_available()
return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt))
end

"""
ClimaComms.device_functional(device)
Expand All @@ -60,36 +49,28 @@ device_functional(::CPUMultiThreaded) = true
"""
ClimaComms.device()
Automatically determine the appropriate device to use, returning one of
- [`AbstractCPUDevice()`](@ref)
- [`CUDADevice()`](@ref)
Determine the device to use depending on the `CLIMACOMMS_DEVICE` environment variable.
By default, it will check if a functional CUDA installation exists, using CUDA if possible.
Allowed values:
- `CPU`, single-threaded or multi-threaded depending on the number of threads;
- `CPUSingleThreaded`,
- `CPUMultiThreaded`,
- `CUDA`.
Behavior can be overridden by setting the `CLIMACOMMS_DEVICE` environment variable to either `CPU` or `CUDA`.
The default is `CPU`.
"""
function device()
env_var = get(ENV, "CLIMACOMMS_DEVICE", nothing)
if !isnothing(env_var)
if env_var == "CPU"
return Threads.nthreads() > 1 ? CPUMultiThreaded() :
CPUSingleThreaded()
elseif env_var == "CPUSingleThreaded"
return CPUSingleThreaded()
elseif env_var == "CPUMultiThreaded"
return CPUMultiThreaded()
elseif env_var == "CUDA"
cuda_ext_available() || error("CUDA was not loaded")
return CUDADevice()
else
error("Invalid CLIMACOMMS_DEVICE: $env_var")
end
end
if cuda_ext_available() && device_functional(CUDADevice())
env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU")
if env_var == "CPU"
return Threads.nthreads() > 1 ? CPUMultiThreaded() : CPUSingleThreaded()
elseif env_var == "CPUSingleThreaded"
return CPUSingleThreaded()
elseif env_var == "CPUMultiThreaded"
return CPUMultiThreaded()
elseif env_var == "CUDA"
return CUDADevice()
else
return Threads.nthreads() == 1 ? CPUSingleThreaded() :
CPUMultiThreaded()
error("Invalid CLIMACOMMS_DEVICE: $env_var")
end
end

Expand Down
Loading

0 comments on commit 71d8fa7

Please sign in to comment.