From 71d8fa74e4dcb0bba5783111f6e014a744840629 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Wed, 24 Apr 2024 15:56:52 -0700 Subject: [PATCH] Load MPI and CUDA maybe --- .buildkite/pipeline.yml | 8 +++++ NEWS.md | 24 +++++++++++++ Project.toml | 2 +- docs/Manifest.toml | 77 ++++++++++++++++++++++++++++++++++++++++- docs/Project.toml | 1 + docs/src/index.md | 6 ++++ src/ClimaComms.jl | 1 + src/context.jl | 54 ++++++++++++----------------- src/devices.jl | 51 +++++++++------------------ src/loading.jl | 38 ++++++++++++++++++++ test/runtests.jl | 10 +----- 11 files changed, 194 insertions(+), 78 deletions(-) create mode 100644 NEWS.md create mode 100644 src/loading.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index ce0c8e7..fc0053c 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -37,6 +37,8 @@ steps: - srun julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU + CLIMACOMMS_CONTEXT: MPI + CLIMACOMMS_DEVICE: CPU agents: slurm_ntasks: 2 @@ -46,6 +48,7 @@ steps: - julia --threads 4 --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU + CLIMACOMMS_DEVICE: CPU agents: slurm_cpus_per_task: 4 @@ -55,6 +58,8 @@ steps: - srun julia --threads 4 --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CPU + CLIMACOMMS_CONTEXT: MPI + CLIMACOMMS_DEVICE: CPU agents: slurm_ntasks: 2 slurm_cpus_per_task: 4 @@ -65,6 +70,7 @@ steps: - julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CUDA + CLIMACOMMS_DEVICE: CUDA agents: slurm_gpus: 1 @@ -74,6 +80,8 @@ steps: - srun julia --project=test test/runtests.jl env: CLIMACOMMS_TEST_DEVICE: CUDA + CLIMACOMMS_CONTEXT: MPI + CLIMACOMMS_DEVICE: CUDA agents: slurm_gpus_per_task: 1 slurm_ntasks: 2 diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 0000000..3f26845 --- /dev/null +++ b/NEWS.md @@ -0,0 +1,24 @@ +ClimaComms.jl Release Notes +======================== + +v0.6.0 +------- + +- ![][badge-💥breaking] `ClimaComms` no longer tries to 
guess the correct + compute device: the default is now CPU. To control which device to use, + use the `CLIMACOMMS_DEVICE` environment variable. +- ![][badge-💥breaking] `CUDA` and `MPI` are now extensions in `ClimaComms`. To + use `CUDA`/`MPI`, `CUDA.jl`/`MPI.jl` have to be loaded. A convenience macro + `ClimaComms.@import_required_backends` checks what device/context could be + used and conditionally loads `CUDA.jl`/`MPI.jl`. It is recommended to change + ```julia + import ClimaComms + ``` + to + ```julia + import ClimaComms + ClimaComms.@import_required_backends + ``` + This has to be done before calling `ClimaComms.context()`. + +[badge-💥breaking]: https://img.shields.io/badge/💥BREAKING-red.svg diff --git a/Project.toml b/Project.toml index 7f3a1ec..be50f10 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,7 @@ authors = [ "Jake Bolewski ", "Gabriele Bozzola ", ] -version = "0.5.9" +version = "0.6.0" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 9740e5c..4f64e84 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "2a3f6f2093cc9e00b435b32de577d1b8ccb4f7ac" +project_hash = "c5b9e727593a1bc35ccae9b71e346465d8a7803c" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -43,10 +43,19 @@ git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.4" +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.0+0" + [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -85,6 +94,12 @@ 
git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" version = "2.44.0+2" +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.10.0+0" + [[deps.IOCapture]] deps = ["Logging", "Random"] git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" @@ -112,6 +127,10 @@ git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" version = "1.2.2" +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -148,6 +167,38 @@ version = "1.17.0+0" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[deps.MPI]] +deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"] +git-tree-sha1 = "4e3136db3735924f96632a5b40a5979f1f53fa07" +uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" +version = "0.20.19" + + [deps.MPI.extensions] + AMDGPUExt = "AMDGPU" + CUDAExt = "CUDA" + + [deps.MPI.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "d8a7bf80c88326ebc98b7d38437208c3a0f20725" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.1+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "8f6af051b9e8ec597fa09d8885ed79fd582f33c9" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.10" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", 
"JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "3f884417b47a96d87e7c6219f8f7b30ce67f4f2c" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.3.3+0" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -163,6 +214,12 @@ deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+1" +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+2" + [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -174,6 +231,12 @@ version = "2023.1.10" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.6+0" + [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" @@ -196,6 +259,12 @@ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", " uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" version = "1.10.0" +[[deps.PkgVersion]] +deps = ["Pkg"] +git-tree-sha1 = "f9501cc0430a26bc3d156ae1b5b0c1b47af4d6da" +uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" +version = "0.3.3" + [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" @@ -226,6 +295,12 @@ git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" version = "0.1.0" +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" 
diff --git a/docs/Project.toml b/docs/Project.toml index f85258f..435d745 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,4 @@ [deps] ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" diff --git a/docs/src/index.md b/docs/src/index.md index 26f69cd..d158f3e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,6 +8,12 @@ CurrentModule = ClimaComms ClimaComms ``` +## Loading + +```@docs +ClimaComms.@import_required_backends +``` + ## Devices ```@docs diff --git a/src/ClimaComms.jl b/src/ClimaComms.jl index 3d4084c..a860523 100644 --- a/src/ClimaComms.jl +++ b/src/ClimaComms.jl @@ -15,5 +15,6 @@ include("devices.jl") include("context.jl") include("singleton.jl") include("mpi.jl") +include("loading.jl") end # module diff --git a/src/context.jl b/src/context.jl index 6ff1332..7cf2b51 100644 --- a/src/context.jl +++ b/src/context.jl @@ -1,14 +1,24 @@ import ..ClimaComms -""" - ClimaComms.mpi_ext_available() - -Returns true when the `ClimaComms` `ClimaCommsMPIExt` extension was loaded. - -To load `ClimaCommsMPIExt`, just load `ClimaComms` and `MPI`. -""" -function mpi_ext_available() - return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsMPIExt)) +function context_type() + name = get(ENV, "CLIMACOMMS_CONTEXT", nothing) + if !isnothing(name) + if name == "MPI" + return :MPICommsContext + elseif name == "SINGLETON" + return :SingletonCommsContext + else + error("Invalid context: $name") + end + end + # detect common environment variables used by MPI launchers + # PMI_RANK appears to be used by MPICH and srun + # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI + if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK") + return :MPICommsContext + else + return :SingletonCommsContext + end end """ @@ -23,29 +33,9 @@ it will return a [`SingletonCommsContext`](@ref). 
Behavior can be overridden by setting the `CLIMACOMMS_CONTEXT` environment variable to either `MPI` or `SINGLETON`. """ -function context(device = device()) - if !(mpi_ext_available()) - return SingletonCommsContext(device) - else - name = get(ENV, "CLIMACOMMS_CONTEXT", nothing) - if !isnothing(name) - if name == "MPI" - return MPICommsContext() - elseif name == "SINGLETON" - return SingletonCommsContext() - else - error("Invalid context: $name") - end - end - # detect common environment variables used by MPI launchers - # PMI_RANK appears to be used by MPICH and srun - # OMPI_COMM_WORLD_RANK appears to be used by OpenMPI - if haskey(ENV, "PMI_RANK") || haskey(ENV, "OMPI_COMM_WORLD_RANK") - return MPICommsContext(device) - else - return SingletonCommsContext(device) - end - end +function context(device = device(); target_context = context_type()) + ContextConstructor = getproperty(ClimaComms, target_context) + return ContextConstructor(device) end """ diff --git a/src/devices.jl b/src/devices.jl index 2cd7748..c270949 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -36,17 +36,6 @@ Use NVIDIA GPU accelarator """ struct CUDADevice <: AbstractDevice end -""" - ClimaComms.cuda_ext_available() - -Returns true when the `ClimaComms` `ClimaCommsCUDAExt` extension was loaded. - -To load `ClimaCommsCUDAExt`, just load `ClimaComms` and `CUDA`. -""" -function cuda_ext_available() - return !isnothing(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt)) -end - """ ClimaComms.device_functional(device) @@ -60,36 +49,28 @@ device_functional(::CPUMultiThreaded) = true """ ClimaComms.device() -Automatically determine the appropriate device to use, returning one of - - [`AbstractCPUDevice()`](@ref) - - [`CUDADevice()`](@ref) +Determine the device to use depending on the `CLIMACOMMS_DEVICE` environment variable. -By default, it will check if a functional CUDA installation exists, using CUDA if possible. 
+Allowed values: +- `CPU`, single-threaded or multi-threaded depending on the number of threads; +- `CPUSingleThreaded`, +- `CPUMultiThreaded`, +- `CUDA`. -Behavior can be overridden by setting the `CLIMACOMMS_DEVICE` environment variable to either `CPU` or `CUDA`. +The default is `CPU`. """ function device() - env_var = get(ENV, "CLIMACOMMS_DEVICE", nothing) - if !isnothing(env_var) - if env_var == "CPU" - return Threads.nthreads() > 1 ? CPUMultiThreaded() : - CPUSingleThreaded() - elseif env_var == "CPUSingleThreaded" - return CPUSingleThreaded() - elseif env_var == "CPUMultiThreaded" - return CPUMultiThreaded() - elseif env_var == "CUDA" - cuda_ext_available() || error("CUDA was not loaded") - return CUDADevice() - else - error("Invalid CLIMACOMMS_DEVICE: $env_var") - end - end - if cuda_ext_available() && device_functional(CUDADevice()) + env_var = get(ENV, "CLIMACOMMS_DEVICE", "CPU") + if env_var == "CPU" + return Threads.nthreads() > 1 ? CPUMultiThreaded() : CPUSingleThreaded() + elseif env_var == "CPUSingleThreaded" + return CPUSingleThreaded() + elseif env_var == "CPUMultiThreaded" + return CPUMultiThreaded() + elseif env_var == "CUDA" return CUDADevice() else - return Threads.nthreads() == 1 ? CPUSingleThreaded() : - CPUMultiThreaded() + error("Invalid CLIMACOMMS_DEVICE: $env_var") end end diff --git a/src/loading.jl b/src/loading.jl new file mode 100644 index 0000000..ce6c0c1 --- /dev/null +++ b/src/loading.jl @@ -0,0 +1,38 @@ +export import_required_backends + +function mpi_required() + return context_type() == :MPICommsContext +end + +function cuda_required() + return device() isa CUDADevice +end + +""" + ClimaComms.@import_required_backends + +If the desired context is MPI (as determined by `ClimaComms.context()`), try loading MPI.jl. +If the desired device is CUDA (as determined by `ClimaComms.device()`), try loading CUDA.jl. 
+""" +macro import_required_backends() + return quote + @static if $mpi_required() + try + import MPI + catch + error( + "Cannot load MPI.jl. Make sure it is included in your environment stack.", + ) + end + end + @static if $cuda_required() + try + import CUDA + catch + error( + "Cannot load CUDA.jl. Make sure it is included in your environment stack.", + ) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e0a085d..5574fdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,6 @@ using Test using ClimaComms - -if haskey(ENV, "CLIMACOMMS_TEST_DEVICE") && - ENV["CLIMACOMMS_TEST_DEVICE"] == "CUDA" - import CUDA - @test ClimaComms.cuda_ext_available() -end - -import MPI -@test ClimaComms.mpi_ext_available() +ClimaComms.@import_required_backends context = ClimaComms.context() pid, nprocs = ClimaComms.init(context)