diff --git a/src/array.jl b/src/array.jl index 6cb4855a..cade38c5 100644 --- a/src/array.jl +++ b/src/array.jl @@ -6,16 +6,11 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N} function ROCArray{T, N, B}(::UndefInitializer, dims::Dims{N}) where {T, N, B <: Mem.AbstractAMDBuffer} @assert isbitstype(T) "ROCArray only supports bits types" sz::Int64 = prod(dims) * sizeof(T) - x = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), T, B, sz)) do + ref = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), B, sz)) do @debug "Allocate `T=$T`, `dims=$dims`: $(Base.format_bytes(sz))" - data = DataRef(pool_free, pool_alloc(B, sz)) - return finalizer(unsafe_free!, new{T, N, B}(data, dims, 0)) + DataRef(pool_free, pool_alloc(B, sz)) end - return if size(x) != dims - reshape(x, dims) - else - x - end::ROCArray{T, N, B} + return finalizer(unsafe_free!, new{T, N, B}(ref, dims, 0)) end function ROCArray{T, N}(buf::DataRef{Managed{B}}, dims::Dims{N}; offset::Integer = 0) where {T, N, B <: Mem.AbstractAMDBuffer} diff --git a/src/memory.jl b/src/memory.jl index 524087f8..59de3b68 100644 --- a/src/memory.jl +++ b/src/memory.jl @@ -217,6 +217,8 @@ function synchronize(m::Managed) return end +Base.sizeof(m::Managed) = sizeof(m.mem) + function Base.convert(::Type{Ptr{T}}, managed::Managed{M}) where {T, M} strm = AMDGPU.stream() diff --git a/src/runtime/memory/hip.jl b/src/runtime/memory/hip.jl index 33241676..ba1af9a9 100644 --- a/src/runtime/memory/hip.jl +++ b/src/runtime/memory/hip.jl @@ -69,6 +69,8 @@ function HIPBuffer(ptr::Ptr{Cvoid}, bytesize::Int) HIPBuffer(s.device, s.ctx, ptr, bytesize, false) end +Base.sizeof(b::HIPBuffer) = UInt64(b.bytesize) + Base.convert(::Type{Ptr{T}}, buf::HIPBuffer) where T = convert(Ptr{T}, buf.ptr) function view(buf::HIPBuffer, bytesize::Int) @@ -137,6 +139,8 @@ function HostBuffer( HostBuffer(stream.device, stream.ctx, ptr, dev_ptr, sz, false) end +Base.sizeof(b::HostBuffer) = UInt64(b.bytesize) + function view(buf::HostBuffer, bytesize::Int) bytesize > buf.bytesize && throw(BoundsError(buf, bytesize)) HostBuffer(