diff --git a/Project.toml b/Project.toml
index 4644639..9a82eb1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,7 +8,7 @@ authors = [
     "Jake Bolewski ",
     "Gabriele Bozzola ",
 ]
-version = "0.6.4"
+version = "0.6.5"
 
 [weakdeps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/ext/ClimaCommsMPIExt.jl b/ext/ClimaCommsMPIExt.jl
index d04732b..d75e958 100644
--- a/ext/ClimaCommsMPIExt.jl
+++ b/ext/ClimaCommsMPIExt.jl
@@ -1,10 +1,26 @@
 module ClimaCommsMPIExt
 
 import MPI
+import ClimaComms: mpicomm, set_mpicomm!
 import ClimaComms
 
-ClimaComms.MPICommsContext(device = ClimaComms.device()) =
-    ClimaComms.MPICommsContext(device, MPI.COMM_WORLD)
+const CLIMA_COMM_WORLD = Ref{typeof(MPI.COMM_WORLD)}()
+
+function set_mpicomm!()
+    CLIMA_COMM_WORLD[] = MPI.COMM_WORLD
+    return CLIMA_COMM_WORLD[]
+end
+
+ClimaComms.mpicomm(::ClimaComms.MPICommsContext) = CLIMA_COMM_WORLD[]
+
+# For backwards compatibility
+function Base.getproperty(ctx::ClimaComms.MPICommsContext, sym::Symbol)
+    if sym === :mpicomm
+        return ClimaComms.mpicomm(ctx)
+    else
+        return getfield(ctx, sym)
+    end
+end
 
 function ClimaComms.init(ctx::ClimaComms.MPICommsContext)
     if !MPI.Initialized()
@@ -19,9 +35,9 @@ function ClimaComms.init(ctx::ClimaComms.MPICommsContext)
     end
     # assign GPUs based on local rank
     local_comm = MPI.Comm_split_type(
-        ctx.mpicomm,
+        mpicomm(ctx),
         MPI.COMM_TYPE_SHARED,
-        MPI.Comm_rank(ctx.mpicomm),
+        MPI.Comm_rank(mpicomm(ctx)),
     )
     ClimaComms._assign_device(ctx.device, MPI.Comm_rank(local_comm))
     MPI.free(local_comm)
@@ -32,36 +48,36 @@ end
 ClimaComms.device(ctx::ClimaComms.MPICommsContext) = ctx.device
 
 ClimaComms.mypid(ctx::ClimaComms.MPICommsContext) =
-    MPI.Comm_rank(ctx.mpicomm) + 1
+    MPI.Comm_rank(mpicomm(ctx)) + 1
 ClimaComms.iamroot(ctx::ClimaComms.MPICommsContext) = ClimaComms.mypid(ctx) == 1
-ClimaComms.nprocs(ctx::ClimaComms.MPICommsContext) = MPI.Comm_size(ctx.mpicomm)
+ClimaComms.nprocs(ctx::ClimaComms.MPICommsContext) = MPI.Comm_size(mpicomm(ctx))
 
-ClimaComms.barrier(ctx::ClimaComms.MPICommsContext) = MPI.Barrier(ctx.mpicomm)
+ClimaComms.barrier(ctx::ClimaComms.MPICommsContext) = MPI.Barrier(mpicomm(ctx))
 
 ClimaComms.reduce(ctx::ClimaComms.MPICommsContext, val, op) =
-    MPI.Reduce(val, op, 0, ctx.mpicomm)
+    MPI.Reduce(val, op, 0, mpicomm(ctx))
 
 ClimaComms.reduce!(ctx::ClimaComms.MPICommsContext, sendbuf, recvbuf, op) =
-    MPI.Reduce!(sendbuf, recvbuf, op, ctx.mpicomm; root = 0)
+    MPI.Reduce!(sendbuf, recvbuf, op, mpicomm(ctx); root = 0)
 
 ClimaComms.reduce!(ctx::ClimaComms.MPICommsContext, sendrecvbuf, op) =
-    MPI.Reduce!(sendrecvbuf, op, ctx.mpicomm; root = 0)
+    MPI.Reduce!(sendrecvbuf, op, mpicomm(ctx); root = 0)
 
 ClimaComms.allreduce(ctx::ClimaComms.MPICommsContext, sendbuf, op) =
-    MPI.Allreduce(sendbuf, op, ctx.mpicomm)
+    MPI.Allreduce(sendbuf, op, mpicomm(ctx))
 
 ClimaComms.allreduce!(ctx::ClimaComms.MPICommsContext, sendbuf, recvbuf, op) =
-    MPI.Allreduce!(sendbuf, recvbuf, op, ctx.mpicomm)
+    MPI.Allreduce!(sendbuf, recvbuf, op, mpicomm(ctx))
 
 ClimaComms.allreduce!(ctx::ClimaComms.MPICommsContext, sendrecvbuf, op) =
-    MPI.Allreduce!(sendrecvbuf, op, ctx.mpicomm)
+    MPI.Allreduce!(sendrecvbuf, op, mpicomm(ctx))
 
 ClimaComms.bcast(ctx::ClimaComms.MPICommsContext, object) =
-    MPI.bcast(object, ctx.mpicomm; root = 0)
+    MPI.bcast(object, mpicomm(ctx); root = 0)
 
 function ClimaComms.gather(ctx::ClimaComms.MPICommsContext, array)
     dims = size(array)
-    lengths = MPI.Gather(dims[end], 0, ctx.mpicomm)
+    lengths = MPI.Gather(dims[end], 0, mpicomm(ctx))
     if ClimaComms.iamroot(ctx)
         dimsout = (dims[1:(end - 1)]..., sum(lengths))
         arrayout = similar(array, dimsout)
@@ -69,11 +85,11 @@ function ClimaComms.gather(ctx::ClimaComms.MPICommsContext, array)
     else
         recvbuf = nothing
     end
-    MPI.Gatherv!(array, recvbuf, 0, ctx.mpicomm)
+    MPI.Gatherv!(array, recvbuf, 0, mpicomm(ctx))
 end
 
 ClimaComms.abort(ctx::ClimaComms.MPICommsContext, status::Int) =
-    MPI.Abort(ctx.mpicomm, status)
+    MPI.Abort(mpicomm(ctx), status)
 
 # We could probably do something fancier here?
 
@@ -171,7 +187,7 @@ function graph_context(
     for n in 1:length(recv_bufs)
         MPI.Recv_init(
             recv_bufs[n],
-            ctx.mpicomm,
+            mpicomm(ctx),
             recv_reqs[n];
             source = recv_ranks[n],
             tag = tag,
@@ -181,7 +197,7 @@
     for n in 1:length(send_bufs)
         MPI.Send_init(
             send_bufs[n],
-            ctx.mpicomm,
+            mpicomm(ctx),
             send_reqs[n];
             dest = send_ranks[n],
             tag = tag,
@@ -226,7 +242,7 @@ function ClimaComms.start(
             ghost.recv_bufs[n],
             ghost.recv_ranks[n],
             ghost.tag,
-            ghost.ctx.mpicomm,
+            mpicomm(ghost.ctx),
             ghost.recv_reqs[n],
         )
     end
@@ -236,7 +252,7 @@
             ghost.send_bufs[n],
             ghost.send_ranks[n],
             ghost.tag,
-            ghost.ctx.mpicomm,
+            mpicomm(ghost.ctx),
             ghost.send_reqs[n],
         )
     end
@@ -254,9 +270,9 @@ function ClimaComms.progress(
     ghost::Union{MPISendRecvGraphContext, MPIPersistentSendRecvGraphContext},
 )
     if isdefined(MPI, :MPI_ANY_SOURCE) # < v0.20
-        MPI.Iprobe(MPI.MPI_ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm)
+        MPI.Iprobe(MPI.MPI_ANY_SOURCE, ghost.tag, mpicomm(ghost.ctx))
     else # >= v0.20
-        MPI.Iprobe(MPI.ANY_SOURCE, ghost.tag, ghost.ctx.mpicomm)
+        MPI.Iprobe(MPI.ANY_SOURCE, ghost.tag, mpicomm(ghost.ctx))
     end
 end
diff --git a/src/mpi.jl b/src/mpi.jl
index 524cc2c..189c6fa 100644
--- a/src/mpi.jl
+++ b/src/mpi.jl
@@ -6,9 +6,15 @@
 A MPI communications context, used for distributed runs.
 [`AbstractCPUDevice`](@ref) and [`CUDADevice`](@ref) device options are
 currently supported.
 """
-struct MPICommsContext{D <: AbstractDevice, C} <: AbstractCommsContext
+struct MPICommsContext{D <: AbstractDevice} <: AbstractCommsContext
     device::D
-    mpicomm::C
+    function MPICommsContext(dev::AbstractDevice = device())
+        set_mpicomm!()
+        return new{typeof(dev)}(dev)
+    end
 end
 function MPICommsContext end
+
+function set_mpicomm! end
+function mpicomm end
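
The diff above replaces a stored field read (ctx.mpicomm) with a function call (mpicomm(ctx)) backed by a module-level Ref that set_mpicomm!() populates when the context is constructed, while a Base.getproperty overload keeps the old ctx.mpicomm syntax working for downstream code. Below is a minimal, self-contained sketch of that pattern under stated assumptions: every name in it (HandleDemo, Handle, Context, handle, set_handle!) is an illustrative stand-in, not part of ClimaComms or MPI.jl.

# sketch.jl -- illustrative only; none of these names exist in ClimaComms.
module HandleDemo

# Stand-in for a resource that is only valid at runtime, like MPI.COMM_WORLD.
struct Handle
    id::Int
end

# Module-level storage, filled in at runtime instead of being captured
# inside the context struct (compare CLIMA_COMM_WORLD above).
const GLOBAL_HANDLE = Ref{Handle}()

function set_handle!(h::Handle)
    GLOBAL_HANDLE[] = h
    return GLOBAL_HANDLE[]
end

struct Context
    device::Symbol
    # Inner constructor populates the global, mirroring how the new
    # MPICommsContext inner constructor calls set_mpicomm!().
    function Context(device::Symbol = :cpu)
        set_handle!(Handle(42))  # arbitrary illustrative value
        return new(device)
    end
end

# Accessor: look the handle up when needed rather than storing it.
handle(::Context) = GLOBAL_HANDLE[]

# Backwards compatibility: old `ctx.handle` field syntax still works.
Base.getproperty(ctx::Context, sym::Symbol) =
    sym === :handle ? handle(ctx) : getfield(ctx, sym)

end # module

ctx = HandleDemo.Context()
@assert HandleDemo.handle(ctx) === ctx.handle  # accessor and property agree
@assert ctx.device === :cpu                    # ordinary fields still resolve

Note that the getproperty shim only covers property syntax: with the field removed from the struct, getfield(ctx, :mpicomm) now throws, and propertynames(ctx) still reports only the real fields unless it is overloaded as well.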