From fc656eeab950a9913dddc2d028bdaa5d1f01a6bc Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Mon, 15 Apr 2024 12:06:44 -0400 Subject: [PATCH] `square_view` fallback method and `adjoint`&`transpose` support --- Project.toml | 2 +- src/adjoints.jl | 21 +++++++++++---------- src/ptr_array.jl | 13 ++++--------- src/views.jl | 12 ++++++++++++ 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index b07692d..7ce69c8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StrideArraysCore" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" authors = ["Chris Elrod and contributors"] -version = "0.5.4" +version = "0.5.5" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/adjoints.jl b/src/adjoints.jl index c387341..53946aa 100644 --- a/src/adjoints.jl +++ b/src/adjoints.jl @@ -8,30 +8,31 @@ q = Expr( :block, Expr(:meta, :inline), - :(sz = $getfield(A, :sizes)), :(sx = $getfield(A, :strides)), :(o = $getfield(A, :offsets)) ) - sz_expr = Expr(:tuple) + squaretype = A <: SquarePtrMatrix + if squaretype + sz_expr = :($getfield(A, :size)) + PA = SquarePtrMatrix + else + push!(q.args, :(sz = $getfield(A, :sizes))) + sz_expr = Expr(:tuple) + PA = AbstractPtrArray + end sx_expr = Expr(:tuple) o_expr = Expr(:tuple) rv_expr = Expr(:tuple) for n = 1:N j = P[n] - push!(sz_expr.args, :($getfield(sz, $j))) + squaretype || push!(sz_expr.args, :($getfield(sz, $j))) push!(sx_expr.args, :($getfield(sx, $j))) push!(o_expr.args, :($getfield(o, $j))) push!(rv_expr.args, R[j]) end push!( q.args, - :(AbstractPtrArray( - pointer(A), - $sz_expr, - $sx_expr, - $o_expr, - Val{$rv_expr}() - )) + :($PA(pointer(A), $sz_expr, $sx_expr, $o_expr, Val{$rv_expr}())) ) q end diff --git a/src/ptr_array.jl b/src/ptr_array.jl index 8feffc4..f057865 100644 --- a/src/ptr_array.jl +++ b/src/ptr_array.jl @@ -243,18 +243,13 @@ end p::Ptr{T}, s::S, strides::Tuple{X0,X1} = (nothing, nothing), - offsets::Tuple{O0,O1} = (One(), One()) -) where {T,S,X0,X1,O0,O1} = - SquarePtrMatrix{T,(1, 2),S,Tuple{X0,X1},Tuple{O0,O1}}(p, s, strides, offsets) + offsets::Tuple{O0,O1} = (One(), One()), + ::Val{R} = Val((1, 2)) +) where {T,S,X0,X1,O0,O1,R} = + SquarePtrMatrix{T,R,S,Tuple{X0,X1},Tuple{O0,O1}}(p, s, strides, offsets) import LinearAlgebra LinearAlgebra.checksquare(A::SquarePtrMatrix) = getfield(A, :size) -Base.@propagate_inbounds function square_view(A::PtrMatrix, i) - sizes = size(A) - @boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i))) - SquarePtrMatrix(pointer(A), i, static_strides(A), offsets(A)) -end - @inline valisbit(::AbstractPtrArray{<:Any,<:Any,<:Any,<:Any,<:Any,<:Any,Bit}) = Val(true) @inline valisbit( diff --git a/src/views.jl b/src/views.jl index c2474ff..38c911f 100644 --- a/src/views.jl +++ b/src/views.jl @@ -96,3 +96,15 @@ function rank_to_sortperm(R::NTuple{N,Int}) where {N} sp end rank_to_sortperm(R) = sortperm(R) + +Base.@propagate_inbounds function square_view(A::PtrMatrix, i) + sizes = size(A) + @boundscheck i <= min(sizes[1], sizes[2]) || throw(BoundsError(A, (i, i))) + SquarePtrMatrix(pointer(A), i, static_strides(A), offsets(A)) +end +# Base.@propagate_inbounds function square_view(A::AbstractMatrix, i) +# StrideArray(square_view(PtrArray(A), i), preserve_buffer(A)) +# end +Base.@propagate_inbounds function square_view(A::AbstractMatrix, i) + @view(A[begin:begin-1+i, begin:begin-1+i]) +end