From 520cf84477bec10c45e3249ffcccf796fe551b6f Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Tue, 23 Apr 2024 14:53:08 -0400 Subject: [PATCH] Fix `square_view` perf regression --- src/ptr_array.jl | 3 ++- src/views.jl | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/ptr_array.jl b/src/ptr_array.jl index d083c8c..eca6308 100644 --- a/src/ptr_array.jl +++ b/src/ptr_array.jl @@ -527,7 +527,8 @@ ArrayInterface.device( ::Type{<:AbstractStrideArray{<:Any,<:Any,R,<:Any,X}} ) where {R,X} i = findfirst(isone, R) - C = i === nothing ? -1 : (X.parameters[i] === Nothing ? i : -1) + C = i === nothing ? -1 : (((X.parameters[i] === Nothing) || (X.parameters[i] === One)) ? i : -1) + # C = i === nothing ? -1 : ((X.parameters[i] === Nothing) ? i : -1) StaticInt{C}() end ArrayInterface.contiguous_batch_size(::Type{<:AbstractStrideArray}) = diff --git a/src/views.jl b/src/views.jl index 00d3599..c0667b2 100644 --- a/src/views.jl +++ b/src/views.jl @@ -96,11 +96,13 @@ function rank_to_sortperm(R::NTuple{N,Int}) where {N} sp end rank_to_sortperm(R) = sortperm(R) - -Base.@propagate_inbounds function square_view(A::PtrMatrix, i) +_one_to_nothing(x::I) where {I} = StrideReset(x) +_one_to_nothing(::One) = nothing +Base.@propagate_inbounds function square_view(A::PtrMatrix{T,R,S,X,O}, i::I) where {T,R,S,X,O,I} # sizes = size(A) # @boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i))) - SquarePtrMatrix(pointer(A), i, static_strides(A), offsets(A)) + x = map(_one_to_nothing, static_strides(A)) + SquarePtrMatrix{T,R,I,typeof(x),O}(pointer(A), i, x, offsets(A)) end # Base.@propagate_inbounds function square_view(A::AbstractMatrix, i) # StrideArray(square_view(PtrArray(A), i), preserve_buffer(A))