From f2fb46b7f371d9842946ea7dd8ed1baf7aa62d46 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Tue, 16 Apr 2024 16:59:40 -0400 Subject: [PATCH] Add unsafe_getindex and unsafe_setindex! --- Project.toml | 2 +- src/ptr_array.jl | 86 ++++++++++++++++++++++++++++++++++++++++-------- src/views.jl | 4 +-- 3 files changed, 75 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 7ce69c8..6712f65 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StrideArraysCore" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" authors = ["Chris Elrod and contributors"] -version = "0.5.5" +version = "0.5.6" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/ptr_array.jl b/src/ptr_array.jl index f057865..1d892fe 100644 --- a/src/ptr_array.jl +++ b/src/ptr_array.jl @@ -1,3 +1,4 @@ +using Base: @propagate_inbounds @generated function permtuple(x::Tuple, ::Val{R}) where {R} t = Expr(:tuple) @@ -898,6 +899,17 @@ Base.@propagate_inbounds function Base.getindex( PtrArray(A)[i...] end end +@inline unsafe_getindex(A::AbstractStrideVector, i::Integer, ::Integer) = + unsafe_getindex(A, i) +@inline function unsafe_getindex( + A::AbstractStrideArray, + i::Vararg{Union{Integer,StaticInt},K} +) where {K} + b = preserve_buffer(A) + GC.@preserve b begin + unsafe_getindex(PtrArray(A), i...) + end +end Base.@propagate_inbounds function Base.getindex( A::AbstractStrideArray, i::Vararg{Union{Integer,StaticInt,Colon,AbstractRange},K} @@ -914,8 +926,18 @@ Base.@propagate_inbounds function Base.setindex!( PtrArray(A)[i...] = v end end +@inline function unsafe_setindex!( + A::AbstractStrideArray, + v, + i::Vararg{Union{Integer,StaticInt},K} +) where {K} + b = preserve_buffer(A) + GC.@preserve b begin + unsafe_setindex!(PtrArray(A), v, i...) + end +end -@inline _offset_dense(i::Tuple{}, s::Tuple{}) = Zero() +@inline _offset_dense(::Tuple{}, ::Tuple{}) = Zero() @inline _offset_dense(i::Tuple{I}, s::Tuple{S}) where {I,S} = i[1] * s[1] @inline _offset_dense( i::Tuple{I,J,Vararg}, @@ -953,34 +975,31 @@ end _offset_ptr_padded(p, j, strides(A)) end end - -@inline function Base.getindex(A::PtrArray, i::Vararg{Integer}) - @boundscheck checkbounds(A, i...) +@inline function unsafe_getindex(A::PtrArray, i::Vararg{Integer}) pload(_offset_ptr(A, i)) end -@inline function Base.setindex!(A::PtrArray, v, i::Vararg{Integer,K}) where {K} - @boundscheck checkbounds(A, i...) +@inline function unsafe_setindex!( + A::PtrArray, + v, + i::Vararg{Integer,K} +) where {K} pstore!(_offset_ptr(A, i), v) v end -@inline function Base.getindex(A::PtrArray{T}, i::Integer) where {T} - @boundscheck checkbounds(A, i) +@inline function unsafe_getindex(A::PtrArray{T}, i::Integer) where {T} pload(pointer(A) + (i - oneunit(i)) * static_sizeof(T)) end -@inline function Base.setindex!(A::PtrArray{T}, v, i::Integer) where {T} - @boundscheck checkbounds(A, i) +@inline function unsafe_setindex!(A::PtrArray{T}, v, i::Integer) where {T} pstore!(pointer(A) + (i - oneunit(i)) * static_sizeof(T), v) v end -@inline function Base.getindex(A::PtrVector{T}, i::Integer) where {T} - @boundscheck checkbounds(A, i) +@inline function unsafe_getindex(A::PtrVector{T}, i::Integer) where {T} pload( pointer(A) + (i - ArrayInterface.offset1(A)) * only(LayoutPointers.bytestrides(A)) ) end -@inline function Base.setindex!(A::PtrVector{T}, v, i::Integer) where {T} - @boundscheck checkbounds(A, i) +@inline function unsafe_setindex!(A::PtrVector{T}, v, i::Integer) where {T} pstore!( pointer(A) + (i - ArrayInterface.offset1(A)) * only(LayoutPointers.bytestrides(A)), @@ -988,6 +1007,45 @@ end ) v end +@propagate_inbounds function Base.getindex(A::PtrArray, i::Vararg{Integer}) + @boundscheck checkbounds(A, i...) + unsafe_getindex(A, i...) +end +@propagate_inbounds function Base.setindex!( + A::PtrArray, + v, + i::Vararg{Integer,K} +) where {K} + @boundscheck checkbounds(A, i...) + unsafe_setindex!(A, v, i...) +end +@propagate_inbounds function Base.getindex(A::PtrArray{T}, i::Integer) where {T} + @boundscheck checkbounds(A, i) + unsafe_getindex(A, i) +end +@propagate_inbounds function Base.setindex!( + A::PtrArray{T}, + v, + i::Integer +) where {T} + @boundscheck checkbounds(A, i) + unsafe_setindex!(A, v, i) +end +@propagate_inbounds function Base.getindex( + A::PtrVector{T}, + i::Integer +) where {T} + @boundscheck checkbounds(A, i) + unsafe_getindex(A, i) +end +@propagate_inbounds function Base.setindex!( + A::PtrVector{T}, + v, + i::Integer +) where {T} + @boundscheck checkbounds(A, i) + unsafe_setindex!(A, v, i) +end _scale(::False, x, _, __) = x @inline function _scale(::True, x, num, denom) diff --git a/src/views.jl b/src/views.jl index 38c911f..00d3599 100644 --- a/src/views.jl +++ b/src/views.jl @@ -98,8 +98,8 @@ end rank_to_sortperm(R) = sortperm(R) Base.@propagate_inbounds function square_view(A::PtrMatrix, i) - sizes = size(A) - @boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i))) + # sizes = size(A) + # @boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i))) SquarePtrMatrix(pointer(A), i, static_strides(A), offsets(A)) end # Base.@propagate_inbounds function square_view(A::AbstractMatrix, i)