diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index beab269ae..6ee5ca5bf 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -3,7 +3,7 @@ module SparseArraysExt using Finch using Finch: AbstractCompiler, DefaultStyle, Extent using Finch: Unfurled, Furlable, Stepper, Jumper, Run, Fill, Lookup, Simplify, Sequence, Phase, Thunk, Spike -using Finch: virtual_size, virtual_default, getstart, getstop, freshen +using Finch: virtual_size, virtual_default, getstart, getstop, freshen, SwizzleArray using Finch.FinchNotation using Base: @kwdef @@ -34,6 +34,36 @@ function Finch.fiber!(arr::SparseVector{Tv, Ti}; default=zero(Tv)) where {Tv, Ti return Fiber(SparseList{Ti}(Element{zero(Tv)}(arr.nzval), n, [1, length(arr.nzind) + 1], arr.nzind)) end +""" + SparseMatrixCSC(arr::Union{Fiber, SwizzleArray}) + +Construct a sparse matrix from a fiber or swizzle. May reuse the underlying storage if possible. +""" +function SparseArrays.SparseMatrixCSC(arr::Union{Fiber, SwizzleArray}) + default(arr) === zero(eltype(arr)) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $(default(arr)) as default")) + return SparseMatrixCSC(Fiber!(Dense(SparseList(Element(0.0))), arr)) +end + +function SparseArrays.SparseMatrixCSC(arr::Fiber{<:Dense{Ti, <:SparseList{Ti, Ptr, Idx, <:Element{D, Tv}}}}) where {D, Ti, Ptr, Idx, Tv} + D === zero(Tv) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $D as default")) + return SparseMatrixCSC{Tv, Ti}(size(arr)..., arr.lvl.lvl.ptr, arr.lvl.lvl.idx, arr.lvl.lvl.lvl.val) +end + +""" + sparse(arr::Union{Fiber, SwizzleArray}) + +Construct a SparseArray from a Fiber or Swizzle. May reuse the underlying storage if possible. +""" +function SparseArrays.sparse(fbr::Union{Fiber, SwizzleArray}) + if ndims(fbr) == 1 + return SparseVector(fbr) + elseif ndims(fbr) == 2 + return SparseMatrixCSC(fbr) + else + throw(ArgumentError("SparseArrays, a Julia stdlib, only supports 1-D and 2-D arrays, was given a $(ndims(fbr))-D array")) + end +end + @kwdef mutable struct VirtualSparseMatrixCSC ex Tv @@ -131,6 +161,20 @@ Finch.FinchNotation.finch_leaf(x::VirtualSparseMatrixCSC) = virtual(x) Finch.virtual_default(arr::VirtualSparseMatrixCSC, ctx) = zero(arr.Tv) Finch.virtual_eltype(tns::VirtualSparseMatrixCSC, ctx) = tns.Tv +""" + SparseVector(arr::Union{Fiber, SwizzleArray}) + +Construct a sparse matrix from a fiber or swizzle. May reuse the underlying storage if possible. +""" +function SparseArrays.SparseVector(arr::Union{Fiber, SwizzleArray}) + default(arr) === zero(eltype(arr)) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $(default(arr)) as default")) + return SparseVector(Fiber!(SparseList(Element(0.0)), arr)) +end + +function SparseArrays.SparseVector(arr::Fiber{<:SparseList{Ti, Ptr, Idx, <:Element{D, Tv}}}) where {Ti, Ptr, Idx, Tv, D} + D === zero(Tv) || throw(ArgumentError("SparseArrays, a Julia stdlib, only supports zero default values, was given $D as default")) + return SparseVector{Tv, Ti}(size(arr)..., arr.lvl.idx, arr.lvl.lvl.val) +end @kwdef mutable struct VirtualSparseVector ex Tv diff --git a/src/interface/abstractarrays.jl b/src/interface/abstractarrays.jl index 854141222..e020f9418 100644 --- a/src/interface/abstractarrays.jl +++ b/src/interface/abstractarrays.jl @@ -53,6 +53,16 @@ end default(a::AbstractArray) = default(typeof(a)) default(T::Type{<:AbstractArray}) = zero(eltype(T)) +""" + Array(arr::Union{Fiber, SwizzleArray}) + +Construct an array from a fiber or swizzle. May reuse memory, will usually densify the fiber. +""" +function Base.Array(fbr::Union{Fiber, SwizzleArray}) + arr = Array{eltype(fbr)}(undef, size(fbr)...) + return copyto!(arr, fbr) +end + struct AsArray{T, N, Fbr} <: AbstractArray{T, N} fbr::Fbr function AsArray{T, N, Fbr}(fbr::Fbr) where {T, N, Fbr} diff --git a/test/test_issues.jl b/test/test_issues.jl index 44f96b3bc..b22f6ff54 100644 --- a/test/test_issues.jl +++ b/test/test_issues.jl @@ -489,4 +489,26 @@ using CIndices @finch (output_tensor .=0; for j=_,i=_,k=_; output_tensor[i,k] += a_fiber[i,j] * b_fiber[k,j]; end) end + + #https://github.com/willow-ahrens/Finch.jl/issues/321 + let + A = fsprand((10, 10), 0.1) + B = sparse(A) + @test B isa SparseMatrixCSC + @test B == A + B = SparseMatrixCSC(A) + @test B isa SparseMatrixCSC + @test B == A + A = fsprand((10,), 0.1) + B = sparse(A) + @test B isa SparseVector + @test B == A + B = SparseVector(A) + @test B isa SparseVector + @test B == A + A = fsprand((10, 10), 0.1) + B = Array(A) + @test B isa Array + @test B == A + end end