From 6a7527ecd195b323840d2e67c740122035923165 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 6 Dec 2023 13:00:43 -0500 Subject: [PATCH 1/2] some utility functions --- ext/SparseArraysExt.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index beab269ae..01ce65184 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -34,6 +34,24 @@ function Finch.fiber!(arr::SparseVector{Tv, Ti}; default=zero(Tv)) where {Tv, Ti return Fiber(SparseList{Ti}(Element{zero(Tv)}(arr.nzval), n, [1, length(arr.nzind) + 1], arr.nzind)) end +function SparseArrays.SparseMatrixCSC(arr::Union{Fiber, Swizzle}) + return SparseMatrixCSC(Fiber!(Dense(SparseList(Element(0.0))), arr)) +end + +function SparseArrays.SparseMatrixCSC(arr::Fiber{<:Dense{Ti, <:SparseList{Ti, Ptr, Idx, <:Element{Tv}}}}) where {Ti, Ptr, Idx, Tv} + return SparseMatrixCSC{Ti, Tv}(size(arr)..., arr.lvl.lvl.ptr, arr.lvl.lvl.idx, arr.lvl.lvl.lvl.val) +end + +function SparseArrays.sparse(fbr::Fiber) + if ndims(fbr) == 1 + return SparseVector(fbr) + elseif ndims(fbr) == 2 + return SparseMatrixCSC(fbr) + else + throw(ArgumentError("SparseArrays only supports 1D and 2D arrays")) + end +end + @kwdef mutable struct VirtualSparseMatrixCSC ex Tv @@ -131,6 +149,14 @@ Finch.FinchNotation.finch_leaf(x::VirtualSparseMatrixCSC) = virtual(x) Finch.virtual_default(arr::VirtualSparseMatrixCSC, ctx) = zero(arr.Tv) Finch.virtual_eltype(tns::VirtualSparseMatrixCSC, ctx) = tns.Tv + +function SparseArrays.SparseVector(arr::Union{Fiber, Swizzle}) + return SparseVector(Fiber!(SparseList(Element(0.0)), arr)) +end + +function SparseArrays.SparseVector(arr::Fiber{<:SparseList{Ti, Ptr, Idx, <:Element{Tv}}}) where {Ti, Ptr, Idx, Tv} + return SparseVector{Ti, Tv}(size(arr)..., arr.lvl.ptr, arr.lvl.idx, arr.lvl.lvl.val) +end @kwdef mutable struct VirtualSparseVector ex Tv From 91eb280e75cabb2735293511c08f9a510dc1445c Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 6 Dec 2023 13:51:30 -0500 Subject: [PATCH 2/2] should be fixed now --- ext/SparseArraysExt.jl | 36 ++++++++++++++++++++++++--------- src/interface/abstractarrays.jl | 10 +++++++++ test/test_issues.jl | 22 ++++++++++++++++++++ 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index 01ce65184..6ee5ca5bf 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -3,7 +3,7 @@ module SparseArraysExt using Finch using Finch: AbstractCompiler, DefaultStyle, Extent using Finch: Unfurled, Furlable, Stepper, Jumper, Run, Fill, Lookup, Simplify, Sequence, Phase, Thunk, Spike -using Finch: virtual_size, virtual_default, getstart, getstop, freshen +using Finch: virtual_size, virtual_default, getstart, getstop, freshen, SwizzleArray using Finch.FinchNotation using Base: @kwdef @@ -34,21 +34,33 @@ function Finch.fiber!(arr::SparseVector{Tv, Ti}; default=zero(Tv)) where {Tv, Ti return Fiber(SparseList{Ti}(Element{zero(Tv)}(arr.nzval), n, [1, length(arr.nzind) + 1], arr.nzind)) end -function SparseArrays.SparseMatrixCSC(arr::Union{Fiber, Swizzle}) +""" + SparseMatrixCSC(arr::Union{Fiber, SwizzleArray}) + +Construct a sparse matrix from a fiber or swizzle. May reuse the underlying storage if possible. +""" +function SparseArrays.SparseMatrixCSC(arr::Union{Fiber, SwizzleArray}) + default(arr) === zero(eltype(arr)) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $(default(arr)) as default")) return SparseMatrixCSC(Fiber!(Dense(SparseList(Element(0.0))), arr)) end -function SparseArrays.SparseMatrixCSC(arr::Fiber{<:Dense{Ti, <:SparseList{Ti, Ptr, Idx, <:Element{Tv}}}}) where {Ti, Ptr, Idx, Tv} - return SparseMatrixCSC{Ti, Tv}(size(arr)..., arr.lvl.lvl.ptr, arr.lvl.lvl.idx, arr.lvl.lvl.lvl.val) +function SparseArrays.SparseMatrixCSC(arr::Fiber{<:Dense{Ti, <:SparseList{Ti, Ptr, Idx, <:Element{D, Tv}}}}) where {D, Ti, Ptr, Idx, Tv} + D === zero(Tv) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $D as default")) + return SparseMatrixCSC{Tv, Ti}(size(arr)..., arr.lvl.lvl.ptr, arr.lvl.lvl.idx, arr.lvl.lvl.lvl.val) end -function SparseArrays.sparse(fbr::Fiber) +""" + sparse(arr::Union{Fiber, SwizzleArray}) + +Construct a SparseArray from a Fiber or Swizzle. May reuse the underlying storage if possible. +""" +function SparseArrays.sparse(fbr::Union{Fiber, SwizzleArray}) if ndims(fbr) == 1 return SparseVector(fbr) elseif ndims(fbr) == 2 return SparseMatrixCSC(fbr) else - throw(ArgumentError("SparseArrays only supports 1D and 2D arrays")) + throw(ArgumentError("SparseArrays, a Julia stdlib, only supports 1-D and 2-D arrays, was given a $(ndims(fbr))-D array")) end end @@ -149,13 +161,19 @@ Finch.FinchNotation.finch_leaf(x::VirtualSparseMatrixCSC) = virtual(x) Finch.virtual_default(arr::VirtualSparseMatrixCSC, ctx) = zero(arr.Tv) Finch.virtual_eltype(tns::VirtualSparseMatrixCSC, ctx) = tns.Tv +""" + SparseVector(arr::Union{Fiber, SwizzleArray}) -function SparseArrays.SparseVector(arr::Union{Fiber, Swizzle}) +Construct a sparse matrix from a fiber or swizzle. May reuse the underlying storage if possible. +""" +function SparseArrays.SparseVector(arr::Union{Fiber, SwizzleArray}) + default(arr) === zero(eltype(arr)) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $(default(arr)) as default")) return SparseVector(Fiber!(SparseList(Element(0.0)), arr)) end -function SparseArrays.SparseVector(arr::Fiber{<:SparseList{Ti, Ptr, Idx, <:Element{Tv}}}) where {Ti, Ptr, Idx, Tv} - return SparseVector{Ti, Tv}(size(arr)..., arr.lvl.ptr, arr.lvl.idx, arr.lvl.lvl.val) +function SparseArrays.SparseVector(arr::Fiber{<:SparseList{Ti, Ptr, Idx, <:Element{D, Tv}}}) where {Ti, Ptr, Idx, Tv, D} + D === zero(Tv) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $D as default")) + return SparseVector{Tv, Ti}(size(arr)..., arr.lvl.idx, arr.lvl.lvl.val) end @kwdef mutable struct VirtualSparseVector ex diff --git a/src/interface/abstractarrays.jl b/src/interface/abstractarrays.jl index 854141222..e020f9418 100644 --- a/src/interface/abstractarrays.jl +++ b/src/interface/abstractarrays.jl @@ -53,6 +53,16 @@ end default(a::AbstractArray) = default(typeof(a)) default(T::Type{<:AbstractArray}) = zero(eltype(T)) +""" + Array(arr::Union{Fiber, SwizzleArray}) + +Construct an array from a fiber or swizzle. May reuse memory, will usually densify the fiber. +""" +function Base.Array(fbr::Union{Fiber, SwizzleArray}) + arr = Array{eltype(fbr)}(undef, size(fbr)...) + return copyto!(arr, fbr) +end + struct AsArray{T, N, Fbr} <: AbstractArray{T, N} fbr::Fbr function AsArray{T, N, Fbr}(fbr::Fbr) where {T, N, Fbr} diff --git a/test/test_issues.jl b/test/test_issues.jl index 44f96b3bc..b22f6ff54 100644 --- a/test/test_issues.jl +++ b/test/test_issues.jl @@ -489,4 +489,26 @@ using CIndices @finch (output_tensor .=0; for j=_,i=_,k=_; output_tensor[i,k] += a_fiber[i,j] * b_fiber[k,j]; end) end + + #https://github.com/willow-ahrens/Finch.jl/issues/321 + let + A = fsprand((10, 10), 0.1) + B = sparse(A) + @test B isa SparseMatrixCSC + @test B == A + B = SparseMatrixCSC(A) + @test B isa SparseMatrixCSC + @test B == A + A = fsprand((10,), 0.1) + B = sparse(A) + @test B isa SparseVector + @test B == A + B = SparseVector(A) + @test B isa SparseVector + @test B == A + A = fsprand((10, 10), 0.1) + B = Array(A) + @test B isa Array + @test B == A + end end