Skip to content

Commit

Permalink
Add unsafe_getindex and unsafe_setindex!
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 16, 2024
1 parent fc656ee commit f2fb46b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StrideArraysCore"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.5.5"
version = "0.5.6"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
86 changes: 72 additions & 14 deletions src/ptr_array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Base: @propagate_inbounds

@generated function permtuple(x::Tuple, ::Val{R}) where {R}
t = Expr(:tuple)
Expand Down Expand Up @@ -898,6 +899,17 @@ Base.@propagate_inbounds function Base.getindex(
PtrArray(A)[i...]
end
end
@inline unsafe_getindex(A::AbstractStrideVector, i::Integer, ::Integer) =
unsafe_getindex(A, i)
@inline function unsafe_getindex(
A::AbstractStrideArray,
i::Vararg{Union{Integer,StaticInt},K}
) where {K}
b = preserve_buffer(A)
GC.@preserve b begin
unsafe_getindex(PtrArray(A), i...)
end
end
Base.@propagate_inbounds function Base.getindex(
A::AbstractStrideArray,
i::Vararg{Union{Integer,StaticInt,Colon,AbstractRange},K}
Expand All @@ -914,8 +926,18 @@ Base.@propagate_inbounds function Base.setindex!(
PtrArray(A)[i...] = v
end
end
@inline function unsafe_setindex!(
A::AbstractStrideArray,
v,
i::Vararg{Union{Integer,StaticInt},K}
) where {K}
b = preserve_buffer(A)
GC.@preserve b begin
unsafe_setindex!(PtrArray(A), v, i...)
end
end

@inline _offset_dense(i::Tuple{}, s::Tuple{}) = Zero()
@inline _offset_dense(::Tuple{}, ::Tuple{}) = Zero()
@inline _offset_dense(i::Tuple{I}, s::Tuple{S}) where {I,S} = i[1] * s[1]
@inline _offset_dense(
i::Tuple{I,J,Vararg},
Expand Down Expand Up @@ -953,41 +975,77 @@ end
_offset_ptr_padded(p, j, strides(A))
end
end

@inline function Base.getindex(A::PtrArray, i::Vararg{Integer})
@boundscheck checkbounds(A, i...)
@inline function unsafe_getindex(A::PtrArray, i::Vararg{Integer})
pload(_offset_ptr(A, i))
end
@inline function Base.setindex!(A::PtrArray, v, i::Vararg{Integer,K}) where {K}
@boundscheck checkbounds(A, i...)
@inline function unsafe_setindex!(
A::PtrArray,
v,
i::Vararg{Integer,K}
) where {K}
pstore!(_offset_ptr(A, i), v)
v
end
@inline function Base.getindex(A::PtrArray{T}, i::Integer) where {T}
@boundscheck checkbounds(A, i)
@inline function unsafe_getindex(A::PtrArray{T}, i::Integer) where {T}
pload(pointer(A) + (i - oneunit(i)) * static_sizeof(T))
end
@inline function Base.setindex!(A::PtrArray{T}, v, i::Integer) where {T}
@boundscheck checkbounds(A, i)
@inline function unsafe_setindex!(A::PtrArray{T}, v, i::Integer) where {T}
pstore!(pointer(A) + (i - oneunit(i)) * static_sizeof(T), v)
v
end
@inline function Base.getindex(A::PtrVector{T}, i::Integer) where {T}
@boundscheck checkbounds(A, i)
@inline function unsafe_getindex(A::PtrVector{T}, i::Integer) where {T}
pload(
pointer(A) +
(i - ArrayInterface.offset1(A)) * only(LayoutPointers.bytestrides(A))
)
end
@inline function Base.setindex!(A::PtrVector{T}, v, i::Integer) where {T}
@boundscheck checkbounds(A, i)
@inline function unsafe_setindex!(A::PtrVector{T}, v, i::Integer) where {T}
pstore!(
pointer(A) +
(i - ArrayInterface.offset1(A)) * only(LayoutPointers.bytestrides(A)),
v
)
v
end
@propagate_inbounds function Base.getindex(A::PtrArray, i::Vararg{Integer})
@boundscheck checkbounds(A, i...)
unsafe_getindex(A, i...)
end
@propagate_inbounds function Base.setindex!(
A::PtrArray,
v,
i::Vararg{Integer,K}
) where {K}
@boundscheck checkbounds(A, i...)
unsafe_setindex!(A, v, i...)
end
@propagate_inbounds function Base.getindex(A::PtrArray{T}, i::Integer) where {T}
@boundscheck checkbounds(A, i)
unsafe_getindex(A, i)
end
@propagate_inbounds function Base.setindex!(
A::PtrArray{T},
v,
i::Integer
) where {T}
@boundscheck checkbounds(A, i)
unsafe_setindex!(A, v, i)
end
@propagate_inbounds function Base.getindex(
A::PtrVector{T},
i::Integer
) where {T}
@boundscheck checkbounds(A, i)
unsafe_getindex(A, i)
end
@propagate_inbounds function Base.setindex!(
A::PtrVector{T},
v,
i::Integer
) where {T}
@boundscheck checkbounds(A, i)
unsafe_setindex!(A, v, i)
end

_scale(::False, x, _, __) = x
@inline function _scale(::True, x, num, denom)
Expand Down
4 changes: 2 additions & 2 deletions src/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ end
rank_to_sortperm(R) = sortperm(R)

Base.@propagate_inbounds function square_view(A::PtrMatrix, i)
sizes = size(A)
@boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i)))
# sizes = size(A)
# @boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i)))
SquarePtrMatrix(pointer(A), i, static_strides(A), offsets(A))
end
# Base.@propagate_inbounds function square_view(A::AbstractMatrix, i)
Expand Down

0 comments on commit f2fb46b

Please sign in to comment.