diff --git a/Project.toml b/Project.toml index e7b9e41..92b1266 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "StrideArraysCore" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" authors = ["Chris Elrod and contributors"] -version = "0.5.2" +version = "0.5.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LayoutPointers = "10f19ff3-798f-405d-979b-55457f8fc047" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ManualMemory = "d125e4d3-2237-4719-b19c-fa641b8a4667" SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" @@ -19,6 +20,7 @@ ArrayInterface = "7" CloseOpenIntervals = "0.1.2" IfElse = "0.1" LayoutPointers = "0.1.1" +LinearAlgebra = "1" ManualMemory = "0.1.6" SIMDTypes = "0.1" Static = "0.7, 0.8" diff --git a/src/ptr_array.jl b/src/ptr_array.jl index 3dcaa9b..2b1a58e 100644 --- a/src/ptr_array.jl +++ b/src/ptr_array.jl @@ -232,11 +232,29 @@ const BitPtrVector1{R,S,X} = AbstractPtrArray{Bool,1,R,S,X,NTuple{1,One},Bit} const BitPtrMatrix0{R,S,X} = AbstractPtrArray{Bool,2,R,S,X,NTuple{2,Zero},Bit} const BitPtrMatrix1{R,S,X} = AbstractPtrArray{Bool,2,R,S,X,NTuple{2,One},Bit} +struct SquarePtrMatrix{T,R,S,X,O} <: + AbstractPtrStrideArray{T,2,R,Tuple{S,S},X,O} + ptr::Ptr{T} + size::S + strides::X + offsets::O +end +SquarePtrMatrix(p::Ptr{T}, s::S) where {T,S} = + SquarePtrMatrix{T,(1, 2),S,Tuple{Nothing,Nothing},Tuple{One,One}}( + p, + s, + (nothing, nothing), + (One(), One()) + ) +import LinearAlgebra +LinearAlgebra.checksquare(A::SquarePtrMatrix) = getfield(A, :size) + @inline valisbit(::AbstractPtrArray{<:Any,<:Any,<:Any,<:Any,<:Any,<:Any,Bit}) = Val(true) @inline valisbit( ::AbstractPtrArray{<:Any,<:Any,<:Any,<:Any,<:Any,<:Any,<:Any} ) = Val(false) +@inline valisbit(::SquarePtrMatrix) = Val(false) # function PtrArray( # ptr::Ptr{T}, sizes::S, strides::X, offsets::O, ::Val{R} @@ -489,6 +507,10 @@ end @inline ArrayInterface.static_size(A::AbstractPtrStrideArray) = getfield(A, :sizes) +@inline function ArrayInterface.static_size(A::SquarePtrMatrix) + s = getfield(A, :size) + (s, s) +end @inline function ArrayInterface.static_strides( A::AbstractPtrStrideArray{<:Any,<:Any,R} ) where {R} diff --git a/test/runtests.jl b/test/runtests.jl index 5010fe7..e3cc784 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,7 +74,7 @@ end @testset "StrideArraysCore.jl" begin # Currently StrideArraysCore commits piracy with zero_offsets(A::AbstractArray) and preserve_buffer(A::MemoryBuffer) - Aqua.test_all(StrideArraysCore; piracies=false) + Aqua.test_all(StrideArraysCore; piracies = false) @testset "StrideArrays Basic" begin @@ -442,4 +442,11 @@ end @test C ≈ Ca end + @testset "square" begin + A = rand(5, 5) + GC.@preserve A begin + B = StrideArraysCore.SquarePtrMatrix(pointer(A), 5) + @test A == B + end + end end