Skip to content

Commit

Permalink
Define adapt for contexts and devices
Browse files Browse the repository at this point in the history
Update envs
  • Loading branch information
charleskawczynski committed Jan 9, 2025
1 parent 06d94a5 commit f4071f3
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 51 deletions.
14 changes: 14 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
ClimaComms.jl Release Notes
========================

main
-------

v0.6.5
-------

- Support for adapt was added, so that users can convert between CPU and GPU
devices, and contexts containing cpu and gpu devices [PR 103](https://github.com/CliMA/ClimaComms.jl/pull/103).

v0.6.4
-------

- Add device-flexible `@assert` was added [PR 86](https://github.com/CliMA/ClimaComms.jl/pull/86).

v0.6.3
-------

Expand Down
15 changes: 6 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
name = "ClimaComms"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
authors = [
"Kiran Pamnany <[email protected]>",
"Simon Byrne <[email protected]>",
"Charles Kawczynski <[email protected]>",
"Sriharsha Kandala <[email protected]>",
"Jake Bolewski <[email protected]>",
"Gabriele Bozzola <[email protected]>",
]
version = "0.6.4"
authors = ["Kiran Pamnany <[email protected]>", "Simon Byrne <[email protected]>", "Charles Kawczynski <[email protected]>", "Sriharsha Kandala <[email protected]>", "Jake Bolewski <[email protected]>", "Gabriele Bozzola <[email protected]>"]
version = "0.6.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -20,5 +16,6 @@ ClimaCommsMPIExt = "MPI"

[compat]
CUDA = "3, 4, 5"
Adapt = "3, 4"
MPI = "0.20.18"
julia = "1.9"
107 changes: 65 additions & 42 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.2"
julia_version = "1.10.7"
manifest_format = "2.0"
project_hash = "c5b9e727593a1bc35ccae9b71e346465d8a7803c"

Expand All @@ -14,6 +14,18 @@ git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.4.5"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.1.1"

[deps.Adapt.extensions]
AdaptStaticArraysExt = "StaticArrays"

[deps.Adapt.weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"
Expand All @@ -25,9 +37,10 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.ClimaComms]]
deps = ["Adapt"]
path = ".."
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.0"
version = "0.6.5"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand All @@ -39,14 +52,14 @@ version = "0.6.0"

[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73"
git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.4"
version = "0.7.6"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"
version = "1.1.1+0"

[[deps.Dates]]
deps = ["Printf"]
Expand All @@ -64,9 +77,9 @@ version = "0.9.3"

[[deps.Documenter]]
deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"]
git-tree-sha1 = "f15a91e6e3919055efa4f206f942a73fedf5dfe6"
git-tree-sha1 = "d0ea2c044963ed6f37703cead7e29f70cba13d7e"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "1.4.0"
version = "1.8.0"

[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
Expand All @@ -75,9 +88,9 @@ version = "1.6.0"

[[deps.Expat_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c"
git-tree-sha1 = "e51db81749b0777b2147fbe7b783ee79045b8e99"
uuid = "2e619515-83b5-522b-bb60-26c02a35a201"
version = "2.5.0+0"
version = "2.6.4+3"

[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
Expand All @@ -90,31 +103,31 @@ version = "1.3.1"

[[deps.Git_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"]
git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809"
git-tree-sha1 = "399f4a308c804b446ae4c91eeafadb2fe2c54ff9"
uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb"
version = "2.44.0+2"
version = "2.47.1+0"

[[deps.Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114"
git-tree-sha1 = "50aedf345a709ab75872f80a2779568dc0bb461b"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
version = "2.10.0+0"
version = "2.11.2+3"

[[deps.IOCapture]]
deps = ["Logging", "Random"]
git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c"
git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.2.4"
version = "0.2.5"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"
version = "1.7.0"

[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
Expand All @@ -123,9 +136,9 @@ uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.4"

[[deps.LazilyInitializedFields]]
git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612"
git-tree-sha1 = "0f2da712350b020bc3957f269c9caad516383ee0"
uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf"
version = "1.2.2"
version = "1.3.0"

[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -160,18 +173,22 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.Libiconv_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175"
git-tree-sha1 = "61dfdba58e585066d8bce214c5a51eaa0539f269"
uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531"
version = "1.17.0+0"
version = "1.17.0+1"

[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.MPI]]
deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"]
git-tree-sha1 = "4e3136db3735924f96632a5b40a5979f1f53fa07"
git-tree-sha1 = "892676019c58f34e38743bc989b0eca5bce5edc5"
uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195"
version = "0.20.19"
version = "0.20.22"

[deps.MPI.extensions]
AMDGPUExt = "AMDGPU"
Expand All @@ -183,21 +200,21 @@ version = "0.20.19"

[[deps.MPICH_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "d8a7bf80c88326ebc98b7d38437208c3a0f20725"
git-tree-sha1 = "7715e65c47ba3941c502bffb7f266a41a7f54423"
uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4"
version = "4.2.1+0"
version = "4.2.3+0"

[[deps.MPIPreferences]]
deps = ["Libdl", "Preferences"]
git-tree-sha1 = "8f6af051b9e8ec597fa09d8885ed79fd582f33c9"
git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07"
uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
version = "0.1.10"
version = "0.1.11"

[[deps.MPItrampoline_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "3f884417b47a96d87e7c6219f8f7b30ce67f4f2c"
git-tree-sha1 = "70e830dab5d0775183c99fc75e4c24c614ed7142"
uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748"
version = "5.3.3+0"
version = "5.5.1+2"

[[deps.Markdown]]
deps = ["Base64"]
Expand All @@ -216,9 +233,9 @@ version = "2.28.2+1"

[[deps.MicrosoftMPI_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01"
git-tree-sha1 = "bc95bf4149bf535c09602e3acdf950d9b4376227"
uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf"
version = "10.1.4+2"
version = "10.1.4+3"

[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
Expand All @@ -231,17 +248,22 @@ version = "2023.1.10"
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4"

[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762"
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"]
git-tree-sha1 = "2dace87e14256edb1dd0724ab7ba831c779b96bd"
uuid = "fe0851c0-eecd-5654-98d4-656369965a5c"
version = "4.1.6+0"
version = "5.0.6+0"

[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046"
git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10"
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
version = "3.0.13+1"
version = "3.0.15+3"

[[deps.PCRE2_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -326,13 +348,9 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[deps.TranscodingStreams]]
git-tree-sha1 = "71509f04d045ec714c4748c785a59045c3736349"
git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.10.7"
weakdeps = ["Random", "Test"]

[deps.TranscodingStreams.extensions]
TestExt = ["Test", "Random"]
version = "0.11.3"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
Expand All @@ -346,6 +364,11 @@ deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.11.0+0"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/apis.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ClimaComms.@elapsed
ClimaComms.@assert
ClimaComms.@sync
ClimaComms.@cuda_sync
Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractDevice)
```

## Contexts
Expand All @@ -45,6 +46,7 @@ ClimaComms.MPICommsContext
ClimaComms.AbstractGraphContext
ClimaComms.context
ClimaComms.graph_context
Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractCommsContext)
```

## Context operations
Expand Down
13 changes: 13 additions & 0 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ClimaCommsCUDAExt

import CUDA

import Adapt
import ClimaComms
import ClimaComms: CUDADevice

Expand All @@ -14,6 +15,18 @@ function ClimaComms.device_functional(::CUDADevice)
return CUDA.functional()
end

function Adapt.adapt_structure(
to::Type{<:CUDA.CuArray},
context::ClimaComms.AbstractCommsContext,
)
return context(adapt(to, ClimaComms.device(context)))
end

Adapt.adapt_structure(
::Type{<:CUDA.CuArray},
device::ClimaComms.AbstractCPUDevice,
) = ClimaComms.CUDADevice()

ClimaComms.array_type(::CUDADevice) = CUDA.CuArray
ClimaComms.allowscalar(f, ::CUDADevice, args...; kwargs...) =
CUDA.@allowscalar f(args...; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions src/ClimaComms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ include("context.jl")
include("singleton.jl")
include("mpi.jl")
include("loading.jl")
include("adapt.jl")

end # module
37 changes: 37 additions & 0 deletions src/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import Adapt

"""
Adapt.adapt_structure(::Type{<:AbstractArray}, context::AbstractCommsContext)
Adapt a given context to a context with a device associated with the given array type.
# Example
```julia
Adapt.adapt_structure(Array, ClimaComms.context(ClimaComms.CUDADevice())) -> ClimaComms.CPUSingleThreaded()
```
!!! note
By default, adapting to `Array` creates a `CPUSingleThreaded` device, and
there is currently no way to conver to a CPUMultiThreaded device.
"""
Adapt.adapt_structure(to::Type{<:AbstractArray}, ctx::AbstractCommsContext) =
context(Adapt.adapt(to, device(ctx)))

"""
Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice)
Adapt a given device to a device associated with the given array type.
# Example
```julia
Adapt.adapt_structure(Array, ClimaComms.CUDADevice()) -> ClimaComms.CPUSingleThreaded()
```
!!! note
By default, adapting to `Array` creates a `CPUSingleThreaded` device, and
there is currently no way to conver to a CPUMultiThreaded device.
"""
Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice) =
CPUSingleThreaded()
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
Loading

0 comments on commit f4071f3

Please sign in to comment.