Skip to content

Commit

Permalink
Add a simpler CuRefValue. (#2645)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Feb 11, 2025
1 parent 3250f1e commit 5461475
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 69 deletions.
1 change: 1 addition & 0 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ include("device/quirks.jl")
# array essentials
include("memory.jl")
include("array.jl")
include("refpointer.jl")

# compiler libraries
include("../lib/cupti/CUPTI.jl")
Expand Down
8 changes: 4 additions & 4 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ end
# these are stored with a selector at the end (handled by Julia).
# 3. bitstype unions (`Union{Int, Float32}`, etc)
# these are stored contiguously and require a selector array (handled by us)
function check_eltype(T)
@inline function check_eltype(name, T)
if !Base.allocatedinline(T)
explanation = explain_eltype(T)
error("""
CuArray only supports element types that are allocated inline.
$name only supports element types that are allocated inline.
$explanation""")
end
end
Expand All @@ -63,7 +63,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}
dims::Dims{N}

function CuArray{T,N,M}(::UndefInitializer, dims::Dims{N}) where {T,N,M}
check_eltype(T)
check_eltype("CuArray", T)
maxsize = prod(dims) * sizeof(T)
bufsize = if Base.isbitsunion(T)
# type tag array past the data
Expand All @@ -82,7 +82,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}

function CuArray{T,N}(data::DataRef{Managed{M}}, dims::Dims{N};
maxsize::Int=prod(dims) * sizeof(T), offset::Int=0) where {T,N,M}
check_eltype(T)
check_eltype("CuArray", T)
obj = new{T,N,M}(data, maxsize, offset, dims)
finalizer(unsafe_free!, obj)
return obj
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ end
# Note that it isn't safe to use unified or heterogeneous memory to support a
# mutable Ref, because there's no guarantee that the memory would be kept alive
# long enough (especially with broadcast using ephemeral Refs for scalar args).
struct CuRefValue{T} <: Ref{T}
struct KernelRefValue{T} <: Ref{T}
val::T
end
Base.getindex(r::CuRefValue{T}) where T = r.val
Base.getindex(r::KernelRefValue{T}) where T = r.val
Adapt.adapt_structure(to::KernelAdaptor, ref::Base.RefValue) =
CuRefValue(adapt(to, ref[]))
KernelRefValue(adapt(to, ref[]))

# broadcast sometimes passes a ref(type), resulting in a GPU-incompatible DataType box.
# avoid that by using a special kind of ref that knows about the boxed type.
Expand Down
59 changes: 1 addition & 58 deletions src/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,68 +207,11 @@ Base.:(+)(x::Integer, y::CuArrayPtr) = y + x


#
# CUDA reference objects
# CUDA reference objects (forward declaration)
#

# Declare `CuRef{T}` as a primitive (bits) type that is exactly as wide as a host
# pointer: 64 bits when `sizeof(Ptr{Cvoid}) == 8`, 32 bits otherwise.
# Only the type is declared here; its methods are defined elsewhere.
if sizeof(Ptr{Cvoid}) == 8
primitive type CuRef{T} 64 end
else
primitive type CuRef{T} 32 end
end

# general methods for CuRef{T} type
Base.eltype(x::Type{<:CuRef{T}}) where {T} = @isdefined(T) ? T : Any

Base.convert(::Type{CuRef{T}}, x::CuRef{T}) where {T} = x

# conversion or the actual ccall
Base.unsafe_convert(::Type{CuRef{T}}, x::CuRef{T}) where {T} = Base.bitcast(CuRef{T}, Base.unsafe_convert(CuPtr{T}, x))
Base.unsafe_convert(::Type{CuRef{T}}, x) where {T} = Base.bitcast(CuRef{T}, Base.unsafe_convert(CuPtr{T}, x))
## `@gcsafe_ccall` results in "double conversions" (remove this once `ccall` does `gcsafe`)
Base.unsafe_convert(::Type{CuPtr{T}}, x::CuRef{T}) where {T} = x

# CuRef from literal pointer
Base.convert(::Type{CuRef{T}}, x::CuPtr{T}) where {T} = x

# indirect constructors using CuRef
CuRef(x::Any) = CuRefArray(CuArray([x]))
CuRef{T}(x) where {T} = CuRefArray{T}(CuArray(T[x]))
CuRef{T}() where {T} = CuRefArray(CuArray{T}(undef, 1))
Base.convert(::Type{CuRef{T}}, x) where {T} = CuRef{T}(x)


## CuRef object backed by a CUDA array at index i

struct CuRefArray{T,A<:AbstractArray{T}} <: Ref{T}
x::A
i::Int
CuRefArray{T,A}(x,i) where {T,A<:AbstractArray{T}} = new(x,i)
end
CuRefArray{T}(x::AbstractArray{T}, i::Int=1) where {T} = CuRefArray{T,typeof(x)}(x, i)
CuRefArray(x::AbstractArray{T}, i::Int=1) where {T} = CuRefArray{T}(x, i)
Base.convert(::Type{CuRef{T}}, x::AbstractArray{T}) where {T} = CuRefArray(x, 1)
Base.convert(::Type{CuRef{T}}, x::CuRefArray{T}) where {T} = x

function Base.unsafe_convert(P::Type{CuPtr{T}}, b::CuRefArray{T}) where T
return pointer(b.x, b.i)
end
function Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefArray{Any})
return convert(P, pointer(b.x, b.i))
end
Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefArray{T}) where {T} =
convert(CuPtr{Cvoid}, Base.unsafe_convert(CuPtr{T}, b))

function Base.getindex(gpu::CuRefArray{T}) where {T}
cpu = Ref{T}()
GC.@preserve cpu begin
cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
gpu_ptr = pointer(gpu.x, gpu.i)
unsafe_copyto!(cpu_ptr, gpu_ptr, 1)
end
cpu[]
end


## Union with all CuRef 'subtypes'

const CuRefs{T} = Union{CuPtr{T}, CuRefArray{T}}
141 changes: 141 additions & 0 deletions src/refpointer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# reference objects

# Common supertype of the reference types declared in this file (`CuRefValue` and
# `CuRefArray`), so that methods such as `convert` to `CuRef{T}` can be written once
# for all of them. It is itself a subtype of `Ref{T}`.
abstract type AbstractCuRef{T} <: Ref{T} end

## opaque reference type
##
## we use a concrete CuRef type that actual references can be (no-op) converted to, without
## actually being a subtype of CuRef. This is necessary so that `CuRef` can be used in
## `ccall` signatures; which Base solves by special-casing `Ref` handling in `ccall.cpp`.
# forward declaration in pointer.jl

# general methods for the CuRef{T} type

# Element type of a CuRef type; `Any` is reported when the parameter `T` is unbound.
function Base.eltype(x::Type{<:CuRef{T}}) where {T}
    return @isdefined(T) ? T : Any
end

# Converting a CuRef{T} to its own type hands back the very same value.
function Base.convert(::Type{CuRef{T}}, x::CuRef{T}) where {T}
    return x
end

# conversion for the actual ccall: first obtain a `CuPtr{T}` from the argument, then
# reinterpret its bits as the opaque `CuRef{T}`
function Base.unsafe_convert(::Type{CuRef{T}}, x::CuRef{T}) where {T}
    ptr = Base.unsafe_convert(CuPtr{T}, x)
    return Base.bitcast(CuRef{T}, ptr)
end
function Base.unsafe_convert(::Type{CuRef{T}}, x) where {T}
    ptr = Base.unsafe_convert(CuPtr{T}, x)
    return Base.bitcast(CuRef{T}, ptr)
end
## `@gcsafe_ccall` results in "double conversions" (remove this once `ccall` does `gcsafe`)
function Base.unsafe_convert(::Type{CuPtr{T}}, x::CuRef{T}) where {T}
    return x
end

# a literal `CuPtr{T}` is passed through unchanged wherever a `CuRef{T}` is requested
function Base.convert(::Type{CuRef{T}}, x::CuPtr{T}) where {T}
    return x
end

# indirect constructors: calling `CuRef` builds a `CuRefValue` rather than an instance
# of the opaque `CuRef` type itself
function CuRef(x::Any)
    return CuRefValue(x)
end
function CuRef{T}(x) where {T}
    return CuRefValue{T}(x)
end
function CuRef{T}() where {T}
    return CuRefValue{T}()
end

# generic conversion to `CuRef{T}` goes through the typed constructor above
function Base.convert(::Type{CuRef{T}}, x) where {T}
    return CuRef{T}(x)
end

# idempotency: an existing reference object is returned as-is instead of being re-wrapped
function Base.convert(::Type{CuRef{T}}, x::AbstractCuRef{T}) where {T}
    return x
end


## reference backed by a single allocation

# TODO: maintain a small global cache of reference boxes

# Mutable reference whose only field is the `Managed{DeviceMemory}` buffer obtained from
# `pool_alloc`. A finalizer hands that buffer back through `pool_free`.
mutable struct CuRefValue{T} <: AbstractCuRef{T}
    buf::Managed{DeviceMemory}

    # Validate `T` (errors are reported under the name "CuRef"), allocate `sizeof(T)`
    # bytes, and register the finalizer. No value is stored by this constructor.
    function CuRefValue{T}() where {T}
        check_eltype("CuRef", T)
        mem = pool_alloc(DeviceMemory, sizeof(T))
        ref = new(mem)
        # the closure captures the buffer, not the object that is being finalized
        finalizer(_ -> pool_free(mem), ref)
        return ref
    end
end
# Create a fresh reference and store `x` in it.
function CuRefValue{T}(x::T) where {T}
    box = CuRefValue{T}()
    box[] = x
    return box
end

# A value of another type is converted to `T` before being stored.
function CuRefValue{T}(x) where {T}
    return CuRefValue{T}(convert(T, x))
end

# Without an explicit parameter, the element type is the type of `x`.
function CuRefValue(x::T) where {T}
    return CuRefValue{T}(x)
end

# Pointer conversions for `CuRefValue`: each one converts the backing buffer `b.buf` to
# the requested `CuPtr` type.
function Base.unsafe_convert(::Type{CuPtr{T}}, b::CuRefValue{T}) where {T}
    return convert(CuPtr{T}, b.buf)
end
function Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefValue{Any})
    return convert(P, b.buf)
end
function Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefValue{T}) where {T}
    return convert(CuPtr{Cvoid}, b.buf)
end

# Store `x` in the reference by copying one element from a temporary host box to the
# reference's buffer (the copy is issued with `async=true`). Returns the reference.
function Base.setindex!(gpu::CuRefValue{T}, x::T) where {T}
    # `Ref{T}(x)` rather than `Ref(x)`: the host box has to be a `RefValue{T}` for the
    # `Ptr{T}` conversion below, also when `typeof(x)` is narrower than `T`.
    cpu = Ref{T}(x)
    # keep the host box rooted while raw pointers to it are in use
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = Base.unsafe_convert(CuPtr{T}, gpu)
        unsafe_copyto!(gpu_ptr, cpu_ptr, 1; async=true)
    end
    return gpu
end

# Values of any other type are converted to `T` first, consistent with the converting
# constructor `CuRefValue{T}(x)` and with `setindex!` on `Base.RefValue`. The type
# assertion guarantees that the call dispatches to the `x::T` method above.
Base.setindex!(gpu::CuRefValue{T}, x) where {T} = setindex!(gpu, convert(T, x)::T)

# Read the referenced value back: one element is copied from the reference's buffer
# into a temporary host box (blocking copy, `async=false`) and the box is unwrapped.
function Base.getindex(gpu::CuRefValue{T}) where {T}
    # synchronize up front, so as to maximize the time spent executing Julia code
    synchronize(gpu.buf)

    host = Ref{T}()
    # keep the host box rooted while a raw pointer to it is in use
    GC.@preserve host begin
        unsafe_copyto!(Base.unsafe_convert(Ptr{T}, host),
                       Base.unsafe_convert(CuPtr{T}, gpu),
                       1; async=false)
    end
    return host[]
end

# Print as `CuRefValue{T}(value)`. The value is fetched with `x[]` only after the
# opening part has been written.
function Base.show(io::IO, x::CuRefValue{T}) where {T}
    print(io, "CuRefValue{$T}(")
    contents = x[]
    print(io, contents, ")")
    return nothing
end


## reference backed by a CUDA array at index i

# Reference to element `i` of the array `x`.
struct CuRefArray{T,A<:AbstractArray{T}} <: AbstractCuRef{T}
    x::A
    i::Int

    # fully-parameterized inner constructor; the arguments are forwarded to `new`
    function CuRefArray{T,A}(x, i) where {T,A<:AbstractArray{T}}
        return new(x, i)
    end
end
# Outer constructors: the array type parameter is taken from `x`, and the index
# defaults to 1.
function CuRefArray{T}(x::AbstractArray{T}, i::Int=1) where {T}
    return CuRefArray{T,typeof(x)}(x, i)
end
function CuRefArray(x::AbstractArray{T}, i::Int=1) where {T}
    return CuRefArray{T}(x, i)
end

# An array converts to a reference to its first element; an existing `CuRefArray` is
# returned unchanged.
function Base.convert(::Type{CuRef{T}}, x::AbstractArray{T}) where {T}
    return CuRefArray(x, 1)
end
function Base.convert(::Type{CuRef{T}}, x::CuRefArray{T}) where {T}
    return x
end

# Pointer conversions for `CuRefArray`: all are derived from `pointer(b.x, b.i)`.
function Base.unsafe_convert(P::Type{CuPtr{T}}, b::CuRefArray{T}) where {T}
    return pointer(b.x, b.i)
end
function Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefArray{Any})
    return convert(P, pointer(b.x, b.i))
end
function Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefArray{T}) where {T}
    typed = Base.unsafe_convert(CuPtr{T}, b)
    return convert(CuPtr{Cvoid}, typed)
end

# Store `x` in the referenced array slot by copying one element from a temporary host
# box to `pointer(gpu.x, gpu.i)` (the copy is issued with `async=true`). Returns the
# reference.
function Base.setindex!(gpu::CuRefArray{T}, x::T) where {T}
    # `Ref{T}(x)` rather than `Ref(x)`: the host box has to be a `RefValue{T}` for the
    # `Ptr{T}` conversion below, also when `typeof(x)` is narrower than `T`.
    cpu = Ref{T}(x)
    # keep the host box rooted while a raw pointer to it is in use
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = pointer(gpu.x, gpu.i)
        unsafe_copyto!(gpu_ptr, cpu_ptr, 1; async=true)
    end
    return gpu
end

# Values of any other type are converted to `T` first, consistent with `setindex!` on
# `Base.RefValue`. The type assertion guarantees that the call dispatches to the `x::T`
# method above.
Base.setindex!(gpu::CuRefArray{T}, x) where {T} = setindex!(gpu, convert(T, x)::T)

# Read the referenced element back: one element is copied from `pointer(gpu.x, gpu.i)`
# into a temporary host box (blocking copy, `async=false`) and the box is unwrapped.
function Base.getindex(gpu::CuRefArray{T}) where {T}
    # synchronize up front, so as to maximize the time spent executing Julia code
    synchronize(gpu.x)

    host = Ref{T}()
    # keep the host box rooted while a raw pointer to it is in use
    GC.@preserve host begin
        unsafe_copyto!(Base.unsafe_convert(Ptr{T}, host),
                       pointer(gpu.x, gpu.i),
                       1; async=false)
    end
    return host[]
end

# Print as `CuRefArray{T}(value)`. The value is fetched with `x[]` only after the
# opening part has been written.
function Base.show(io::IO, x::CuRefArray{T}) where {T}
    print(io, "CuRefArray{$T}(")
    contents = x[]
    print(io, contents, ")")
    return nothing
end
8 changes: 4 additions & 4 deletions test/libraries/cublas/level1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ k = 13
@test CUBLAS.iamin(ca) == 3
result_type = CUBLAS.version() >= v"12.0" ? Int64 : Cint
result = CuRef{result_type}(0)
result = CUBLAS.iamax(ca, result)
@test BLAS.iamax(a) == only(Array(result.x))
CUBLAS.iamax(ca, result)
@test BLAS.iamax(a) == result[]
end
@testset "nrm2 with result" begin
x = rand(T, m)
dx = CuArray(x)
result = CuRef{real(T)}(zero(real(T)))
result = CUBLAS.nrm2(dx, result)
@test norm(x) ≈ only(Array(result.x))
CUBLAS.nrm2(dx, result)
@test norm(x) ≈ result[]
end
end # level 1 testset
@testset for T in [Float16, ComplexF16]
Expand Down

0 comments on commit 5461475

Please sign in to comment.