Skip to content

Commit

Permalink
handle complex datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
nikopj committed Jul 19, 2024
1 parent e88e268 commit 6967c52
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 130 deletions.
1 change: 1 addition & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,4 @@ ncclDataType_t(::Type{UInt64}) = ncclUint64
# Floating-point element types each map onto their own NCCL data type.
function ncclDataType_t(::Type{Float16})
    return ncclFloat16
end
function ncclDataType_t(::Type{Float32})
    return ncclFloat32
end
function ncclDataType_t(::Type{Float64})
    return ncclFloat64
end
# A complex element type reuses the NCCL data type of its component type `T`.
function ncclDataType_t(::Type{Complex{T}}) where {T}
    return ncclDataType_t(T)
end
29 changes: 16 additions & 13 deletions src/collective.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Number of scalar elements that NCCL must transfer for the buffer `X`.
#
# `ncclDataType_t` maps `Complex{T}` to the data type of `T`, so every complex entry
# is counted as two scalars. The methods accept any `AbstractArray` so that buffers
# which are not plain `CuArray`s do not hit a `MethodError` here.
count(X::AbstractArray) = length(X)
count(X::AbstractArray{<:Complex}) = 2 * length(X)

"""
NCCL.Allreduce!(
sendbuf, recvbuf, op, comm::Communicator;
Expand All @@ -11,11 +14,11 @@ or [`NCCL.avg`](@ref)), writing the result to `recvbuf` to all ranks.
"""
function Allreduce!(sendbuf, recvbuf, op, comm::Communicator;
                    stream::CuStream=default_device_stream(comm))
    # `count` returns the number of scalar elements (two per complex entry), which
    # matches the component data type selected by `ncclDataType_t`. The local is
    # deliberately not named `count`: that would shadow the helper being called.
    a_count = count(recvbuf)
    @assert count(sendbuf) == a_count
    data_type = ncclDataType_t(eltype(recvbuf))
    _op = ncclRedOp_t(op)
    ncclAllReduce(sendbuf, recvbuf, a_count, data_type, _op, comm, stream)
    return recvbuf
end

Expand Down Expand Up @@ -47,8 +50,8 @@ Copies the array `sendbuf` on rank `root` to `recvbuf` on all ranks.
function Broadcast!(sendbuf, recvbuf, comm::Communicator; root::Integer=0,
                    stream::CuStream=default_device_stream(comm))
    data_type = ncclDataType_t(eltype(recvbuf))
    # Scalar element count (two per complex entry); the local is not named `count`
    # so that the `count` helper stays callable inside this function.
    a_count = count(recvbuf)
    ncclBroadcast(sendbuf, recvbuf, a_count, data_type, root, comm, stream)
    return recvbuf
end
function Broadcast!(sendrecvbuf, comm::Communicator; root::Integer=0,
Expand All @@ -72,9 +75,9 @@ or [`NCCL.avg`](@ref)), writing the result to `recvbuf` on rank `root`.
function Reduce!(sendbuf, recvbuf, op, comm::Communicator; root::Integer=0,
                 stream::CuStream=default_device_stream(comm))
    data_type = ncclDataType_t(eltype(recvbuf))
    # Scalar element count (two per complex entry); the local is not named `count`
    # so that the `count` helper stays callable inside this function.
    a_count = count(recvbuf)
    _op = ncclRedOp_t(op)
    ncclReduce(sendbuf, recvbuf, a_count, data_type, _op, root, comm, stream)
    return recvbuf
end
function Reduce!(sendrecvbuf, op, comm::Communicator; root::Integer=0,
Expand All @@ -96,9 +99,9 @@ Concatenate `sendbuf` from each rank into `recvbuf` on all ranks.
function Allgather!(sendbuf, recvbuf, comm::Communicator;
                    stream::CuStream=default_device_stream(comm))
    data_type = ncclDataType_t(eltype(recvbuf))
    # Counts are in scalar elements (two per complex entry), consistent with the
    # component data type chosen above.
    sendcount = count(sendbuf)
    # The receive buffer must hold one send-sized chunk per rank of `comm`.
    @assert count(recvbuf) == sendcount * size(comm)
    ncclAllGather(sendbuf, recvbuf, sendcount, data_type, comm, stream)
    return recvbuf
end

Expand All @@ -117,10 +120,10 @@ scattered over the devices such that `recvbuf` on each rank will contain the
"""
function ReduceScatter!(sendbuf, recvbuf, op, comm::Communicator;
                        stream::CuStream=default_device_stream(comm))
    # Counts are in scalar elements (two per complex entry), consistent with the
    # component data type chosen by `ncclDataType_t`.
    recvcount = count(recvbuf)
    # The send buffer must hold one receive-sized chunk per rank of `comm`.
    @assert count(sendbuf) == recvcount * size(comm)
    data_type = ncclDataType_t(eltype(recvbuf))
    _op = ncclRedOp_t(op)
    ncclReduceScatter(sendbuf, recvbuf, recvcount, data_type, _op, comm, stream)
    return recvbuf
end
250 changes: 133 additions & 117 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,171 +26,187 @@ end
devs = CUDA.devices()
comms = NCCL.Communicators(devs)

# Run every Allreduce! check once for a real and once for a complex element type.
@testset "$T" for T in (Float64, ComplexF64)
    @testset "sum" begin
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        N = 512
        # Device ii sends a buffer filled with the value ii.
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = CuArray(fill(T(ii), N))
            recvbuf[ii] = CUDA.zeros(T, N)
        end
        NCCL.group() do
            for ii in 1:length(devs)
                NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii])
            end
        end
        # Every device should end up with the sum of all the fill values.
        answer = sum(1:length(devs))
        for (ii, dev) in enumerate(devs)
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .== answer)
        end
    end

    @testset "NCCL.avg" begin
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        N = 512
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = CuArray(fill(T(ii), N))
            recvbuf[ii] = CUDA.zeros(T, N)
        end
        NCCL.group() do
            for ii in 1:length(devs)
                NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii])
            end
        end
        # Averaging involves a division, so compare approximately.
        answer = sum(1:length(devs)) / length(devs)
        for (ii, dev) in enumerate(devs)
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .≈ answer)
        end
    end
end
end

@testset "Broadcast!" begin
    devs = CUDA.devices()
    comms = NCCL.Communicators(devs)

    # Exercise the broadcast for a real and a complex element type.
    @testset "$T" for T in (Float64, ComplexF64)
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        root = 0
        # Only the root rank holds ones; every other rank starts from zeros.
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = (ii - 1) == root ? CuArray(fill(T(1.0), 512)) : CUDA.zeros(T, 512)
            recvbuf[ii] = CUDA.zeros(T, 512)
        end
        NCCL.group() do
            for ii in 1:length(devs)
                NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root)
            end
        end
        # After the broadcast every rank should hold the root's data.
        answer = 1.0
        for (ii, dev) in enumerate(devs)
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .== answer)
        end
    end
end

@testset "Reduce!" begin
    devs = CUDA.devices()
    comms = NCCL.Communicators(devs)
    # Exercise the reduction for a real and a complex element type.
    @testset "$T" for T in (Float64, ComplexF64)
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        root = 0
        # Device ii sends a buffer filled with the value ii.
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = CuArray(fill(T(ii), 512))
            recvbuf[ii] = CUDA.zeros(T, 512)
        end
        NCCL.group() do
            for ii in 1:length(devs)
                NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root)
            end
        end
        for (ii, dev) in enumerate(devs)
            # Only the root receives the sum; the other receive buffers stay zero.
            answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .== answer)
        end
    end
end

@testset "Allgather!" begin
    devs = CUDA.devices()
    comms = NCCL.Communicators(devs)
    # Exercise the gather for a real and a complex element type.
    @testset "$T" for T in (Float64, ComplexF64)
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        # Each receive buffer has room for one 512-element chunk per device.
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = CuArray(fill(T(ii), 512))
            recvbuf[ii] = CUDA.zeros(T, length(devs)*512)
        end
        NCCL.group() do
            for ii in 1:length(devs)
                NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii])
            end
        end
        # Expected layout: 512 copies of 1, then 512 copies of 2, and so on.
        answer = vec(repeat(1:length(devs), inner=512))
        for (ii, dev) in enumerate(devs)
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .== answer)
        end
    end
end

@testset "ReduceScatter!" begin
    devs = CUDA.devices()
    comms = NCCL.Communicators(devs)

    # Exercise the reduce-scatter for a real and a complex element type.
    @testset "$T" for T in (Float64, ComplexF64)
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        # Every device sends [1, 1, 2, 2, ...]; each receives a 2-element slice.
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2)))
            recvbuf[ii] = CUDA.zeros(T, 2)
        end
        NCCL.group() do
            for ii in 1:length(devs)
                NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii])
            end
        end
        for (ii, dev) in enumerate(devs)
            # Slice ii holds the value ii summed over all devices.
            answer = length(devs)*ii
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .== answer)
        end
    end
end

@testset "Send/Recv" begin
    devs = CUDA.devices()
    comms = NCCL.Communicators(devs)
    # Exercise point-to-point transfers for a real and a complex element type.
    @testset "$T" for T in (Float64, ComplexF64)
        recvbuf = Vector{CuVector{T}}(undef, length(devs))
        sendbuf = Vector{CuVector{T}}(undef, length(devs))
        N = 512
        # Device ii sends a buffer filled with the value ii.
        for (ii, dev) in enumerate(devs)
            CUDA.device!(ii - 1)
            sendbuf[ii] = CuArray(fill(T(ii), N))
            recvbuf[ii] = CUDA.zeros(T, N)
        end

        NCCL.group() do
            for ii in 1:length(devs)
                comm = comms[ii]
                # Ring exchange: send to the next rank, receive from the previous one.
                dest = mod(NCCL.rank(comm)+1, NCCL.size(comm))
                source = mod(NCCL.rank(comm)-1, NCCL.size(comm))
                NCCL.Send(sendbuf[ii], comm; dest)
                NCCL.Recv!(recvbuf[ii], comm; source)
            end
        end
        for (ii, dev) in enumerate(devs)
            # Device ii receives the fill value of its predecessor in the ring.
            answer = mod1(ii - 1, length(devs))
            device!(ii - 1)
            crecv = collect(recvbuf[ii])
            @test all(crecv .== answer)
        end
    end
end

Expand Down

0 comments on commit 6967c52

Please sign in to comment.