Skip to content

Commit

Permalink
Merge pull request #675 from finch-tensor/wma/pointfix
Browse files Browse the repository at this point in the history
  • Loading branch information
willow-ahrens authored Dec 28, 2024
2 parents f8e0e54 + 97065a1 commit 53ae4fb
Show file tree
Hide file tree
Showing 45 changed files with 454 additions and 824 deletions.
23 changes: 19 additions & 4 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Finch
using BenchmarkTools
using MatrixDepot
using SparseArrays
using Random
include(joinpath(@__DIR__, "../docs/examples/bfs.jl"))
include(joinpath(@__DIR__, "../docs/examples/pagerank.jl"))
include(joinpath(@__DIR__, "../docs/examples/shortest_paths.jl"))
Expand Down Expand Up @@ -245,6 +246,17 @@ function spmv_serial(A, x)
end
end

function spmv_noinit(y, A, x)
@finch begin
for i=_
for j=_
y[i] += A[j, i] * x[j]
end
end
return y
end
end

function spmv_threaded(A, x)
y = Tensor(Dense{Int64}(Element{0.0, Float64}()))
@finch begin
Expand Down Expand Up @@ -302,25 +314,28 @@ end

SUITE["structure"] = BenchmarkGroup()

N = 100_000
N = 1_000_000

SUITE["structure"]["permutation"] = BenchmarkGroup()

A_ref = Tensor(Dense(SparseList(Element(0.0))), fsparse(collect(1:N), collect(1:N), ones(N)))
perm = randperm(N)

A_ref = Tensor(Dense(SparseList(Element(0.0))), fsparse(collect(1:N), perm, ones(N)))

A = Tensor(Dense(SparsePoint(Element(0.0))), A_ref)

x = rand(N)

SUITE["structure"]["permutation"]["SparseList"] = @benchmarkable spmv_serial($A_ref, $x)
SUITE["structure"]["permutation"]["SparsePoint"] = @benchmarkable spmv_serial($A, $x)
SUITE["structure"]["permutation"]["baseline"] = @benchmarkable $x[$perm]

SUITE["structure"]["banded"] = BenchmarkGroup()

A_ref = Tensor(Dense(Sparse(Element(0.0))), N, N)

@finch for i = _, j = _
if abs(i - j) < 2
@finch for j = _, i = _
if j - 2 < i < j + 2
A_ref[i, j] = 1.0
end
end
Expand Down
124 changes: 44 additions & 80 deletions src/tensors/levels/sparse_band_levels.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
SparseBandLevel{[Ti=Int], [Ptr, Idx, Ofs]}(lvl, [dim])
SparseBandLevel{[Ti=Int], [Idx, Ofs]}(lvl, [dim])
Like the [`SparseBlockListLevel`](@ref), but stores only a single block, and fills in zeros.
Expand All @@ -14,10 +14,9 @@ Dense [:,1:3]
│ ├─[1]: 20.0
│ ├─[3]: 40.0
"""
struct SparseBandLevel{Ti, Ptr<:AbstractVector, Idx<:AbstractVector, Ofs<:AbstractVector, Lvl} <: AbstractLevel
struct SparseBandLevel{Ti, Idx<:AbstractVector, Ofs<:AbstractVector, Lvl} <: AbstractLevel
lvl::Lvl
shape::Ti
ptr::Ptr
idx::Idx
ofs::Ofs
end
Expand All @@ -26,40 +25,39 @@ const SparseBand = SparseBandLevel
SparseBandLevel(lvl::Lvl) where {Lvl} = SparseBandLevel{Int}(lvl)
SparseBandLevel(lvl, shape, args...) = SparseBandLevel{typeof(shape)}(lvl, shape, args...)
SparseBandLevel{Ti}(lvl) where {Ti} = SparseBandLevel{Ti}(lvl, zero(Ti))
SparseBandLevel{Ti}(lvl, shape) where {Ti} = SparseBandLevel{Ti}(lvl, shape, postype(lvl)[1], Ti[], postype(lvl)[])
SparseBandLevel{Ti}(lvl::Lvl, shape, ptr::Ptr, idx::Idx, ofs::Ofs) where {Ti, Lvl, Ptr, Idx, Ofs} =
SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}(lvl, Ti(shape), ptr, idx, ofs)
SparseBandLevel{Ti}(lvl, shape) where {Ti} = SparseBandLevel{Ti}(lvl, shape, Ti[], postype(lvl)[])
SparseBandLevel{Ti}(lvl::Lvl, shape, idx::Idx, ofs::Ofs) where {Ti, Lvl, Idx, Ofs} =
SparseBandLevel{Ti, Idx, Ofs, Lvl}(lvl, Ti(shape), idx, ofs)

function postype(::Type{SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl}
function postype(::Type{SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl}
return postype(Lvl)
end

function moveto(lvl::SparseBandLevel{Ti}, device) where {Ti}
lvl_2 = moveto(lvl.lvl, device)
ptr_2 = moveto(lvl.ptr, device)
idx_2 = moveto(lvl.idx, device)
ofs_2 = moveto(lvl.ofs, device)
return SparseBandLevel{Ti}(lvl_2, lvl.shape, ptr_2, idx_2, ofs_2)
return SparseBandLevel{Ti}(lvl_2, lvl.shape, idx_2, ofs_2)
end

Base.summary(lvl::SparseBandLevel) = "SparseBand($(summary(lvl.lvl)))"
similar_level(lvl::SparseBandLevel, fill_value, eltype::Type, dim, tail...) =
SparseBand(similar_level(lvl.lvl, fill_value, eltype, tail...), dim)

pattern!(lvl::SparseBandLevel{Ti}) where {Ti} =
SparseBandLevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.ptr, lvl.idx, lvl.ofs)
SparseBandLevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.idx, lvl.ofs)

function countstored_level(lvl::SparseBandLevel, pos)
countstored_level(lvl.lvl, lvl.ofs[lvl.ptr[pos + 1]]-1)
countstored_level(lvl.lvl, lvl.ofs[pos + 1]-1)
end

set_fill_value!(lvl::SparseBandLevel{Ti}, init) where {Ti} =
SparseBandLevel{Ti}(set_fill_value!(lvl.lvl, init), lvl.shape, lvl.ptr, lvl.idx, lvl.ofs)
SparseBandLevel{Ti}(set_fill_value!(lvl.lvl, init), lvl.shape, lvl.idx, lvl.ofs)

Base.resize!(lvl::SparseBandLevel{Ti}, dims...) where {Ti} =
SparseBandLevel{Ti}(resize!(lvl.lvl, dims[1:end-1]...), dims[end], lvl.ptr, lvl.idx, lvl.ofs)
SparseBandLevel{Ti}(resize!(lvl.lvl, dims[1:end-1]...), dims[end], lvl.idx, lvl.ofs)

function Base.show(io::IO, lvl::SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}) where {Ti, Ptr, Idx, Ofs, Lvl}
function Base.show(io::IO, lvl::SparseBandLevel{Ti, Idx, Ofs, Lvl}) where {Ti, Idx, Ofs, Lvl}
if get(io, :compact, false)
print(io, "SparseBand(")
else
Expand All @@ -72,8 +70,6 @@ function Base.show(io::IO, lvl::SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}) where {
if get(io, :compact, false)
print(io, "")
else
show(io, lvl.ptr)
print(io, ", ")
show(io, lvl.idx)
print(io, ", ")
show(io, lvl.ofs)
Expand All @@ -87,37 +83,34 @@ labelled_show(io::IO, fbr::SubFiber{<:SparseBandLevel}) =
function labelled_children(fbr::SubFiber{<:SparseBandLevel})
lvl = fbr.lvl
pos = fbr.pos
pos + 1 > length(lvl.ptr) && return []
res = []
for r = lvl.ptr[pos]:lvl.ptr[pos + 1] - 1
i = lvl.idx[r]
qos = lvl.ofs[r]
l = lvl.ofs[r + 1] - lvl.ofs[r]
for qos = lvl.ofs[r]:lvl.ofs[r + 1] - 1
push!(res, LabelledTree(cartesian_label([range_label() for _ = 1:ndims(fbr) - 1]..., i - (lvl.ofs[r + 1] - 1) + qos), SubFiber(lvl.lvl, qos)))
end
for qos = lvl.ofs[pos]:lvl.ofs[pos + 1] - 1
i = lvl.idx[pos] - lvl.ofs[pos + 1] + qos + 1
push!(res, LabelledTree(cartesian_label([range_label() for _ = 1:ndims(fbr) - 1]..., i), SubFiber(lvl.lvl, qos)))
end
res
end

@inline level_ndims(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = 1 + level_ndims(Lvl)
@inline level_ndims(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = 1 + level_ndims(Lvl)
@inline level_size(lvl::SparseBandLevel) = (level_size(lvl.lvl)..., lvl.shape)
@inline level_axes(lvl::SparseBandLevel) = (level_axes(lvl.lvl)..., Base.OneTo(lvl.shape))
@inline level_eltype(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = level_eltype(Lvl)
@inline level_fill_value(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = SparseData(data_rep_level(Lvl))
@inline level_eltype(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = level_eltype(Lvl)
@inline level_fill_value(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = SparseData(data_rep_level(Lvl))

(fbr::AbstractFiber{<:SparseBandLevel})() = fbr
function (fbr::SubFiber{<:SparseBandLevel})(idxs...)
isempty(idxs) && return fbr
lvl = fbr.lvl
p = fbr.pos
r = lvl.ptr[p] + searchsortedfirst(@view(lvl.idx[lvl.ptr[p]:lvl.ptr[p + 1] - 1]), idxs[end]) - 1
r < lvl.ptr[p + 1] || return fill_value(fbr)
q = lvl.ofs[r + 1] - 1 - lvl.idx[r] + idxs[end]
q >= lvl.ofs[r] || return fill_value(fbr)
fbr_2 = SubFiber(lvl.lvl, q)
return fbr_2(idxs[1:end-1]...)
pos = fbr.pos
start = lvl.idx[pos] - lvl.ofs[pos + 1] + lvl.ofs[pos] + 1
stop = lvl.idx[pos]
if start <= idxs[end] <= stop
qos = lvl.ofs[pos] + idxs[end] - start
fbr_2 = SubFiber(lvl.lvl, qos)
return fbr_2(idxs[1:end-1]...)
end
return fill_value(fbr)
end

mutable struct VirtualSparseBandLevel <: AbstractVirtualLevel
Expand All @@ -130,7 +123,6 @@ mutable struct VirtualSparseBandLevel <: AbstractVirtualLevel
ros_fill
ros_stop
dirty
ptr
idx
ofs
prev_pos
Expand All @@ -149,33 +141,30 @@ end
postype(lvl::VirtualSparseBandLevel) = postype(lvl.lvl)


function virtualize(ctx, ex, ::Type{SparseBandLevel{Ti, Ptr, Idx, Ofs, Lvl}}, tag=:lvl) where {Ti, Ptr, Idx, Ofs, Lvl}
function virtualize(ctx, ex, ::Type{SparseBandLevel{Ti, Idx, Ofs, Lvl}}, tag=:lvl) where {Ti, Idx, Ofs, Lvl}
sym = freshen(ctx, tag)
shape = value(:($sym.shape), Int)
qos_fill = freshen(ctx, sym, :_qos_fill)
qos_stop = freshen(ctx, sym, :_qos_stop)
ros_fill = freshen(ctx, sym, :_ros_fill)
ros_stop = freshen(ctx, sym, :_ros_stop)
dirty = freshen(ctx, sym, :_dirty)
ptr = freshen(ctx, tag, :_ptr)
idx = freshen(ctx, tag, :_idx)
ofs = freshen(ctx, tag, :_ofs)
push_preamble!(ctx, quote
$sym = $ex
$ptr = $sym.ptr
$idx = $sym.idx
$ofs = $sym.ofs
end)
prev_pos = freshen(ctx, sym, :_prev_pos)
lvl_2 = virtualize(ctx, :($sym.lvl), Lvl, sym)
VirtualSparseBandLevel(lvl_2, sym, Ti, shape, qos_fill, qos_stop, ros_fill, ros_stop, dirty, ptr, idx, ofs, prev_pos)
VirtualSparseBandLevel(lvl_2, sym, Ti, shape, qos_fill, qos_stop, ros_fill, ros_stop, dirty, idx, ofs, prev_pos)
end
function lower(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, ::DefaultStyle)
quote
$SparseBandLevel{$(lvl.Ti)}(
$(ctx(lvl.lvl)),
$(ctx(lvl.shape)),
$(lvl.ptr),
$(lvl.idx),
$(lvl.ofs),
)
Expand All @@ -199,19 +188,15 @@ virtual_level_eltype(lvl::VirtualSparseBandLevel) = virtual_level_eltype(lvl.lvl
virtual_level_fill_value(lvl::VirtualSparseBandLevel) = virtual_level_fill_value(lvl.lvl)

function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, arch)
ptr_2 = freshen(ctx, lvl.ptr)
tbl_2 = freshen(ctx, lvl.tbl)
ofs_2 = freshen(ctx, lvl.ofs)
push_preamble!(ctx, quote
$ptr_2 = $(lvl.ptr)
$tbl_2 = $(lvl.tbl)
$ofs_2 = $(lvl.ofs)
$(lvl.ptr) = $moveto($(lvl.ptr), $(ctx(arch)))
$(lvl.tbl) = $moveto($(lvl.tbl), $(ctx(arch)))
$(lvl.ofs) = $moveto($(lvl.ofs), $(ctx(arch)))
end)
push_epilogue!(ctx, quote
$(lvl.ptr) = $ptr_2
$(lvl.tbl) = $tbl_2
$(lvl.ofs) = $ofs_2
end)
Expand All @@ -224,8 +209,6 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos,
push_preamble!(ctx, quote
$(lvl.qos_fill) = $(Tp(0))
$(lvl.qos_stop) = $(Tp(0))
$(lvl.ros_fill) = $(Tp(0))
$(lvl.ros_stop) = $(Tp(0))
Finch.resize_if_smaller!($(lvl.ofs), 1)
$(lvl.ofs)[1] = 1
end)
Expand All @@ -242,26 +225,25 @@ function assemble_level!(ctx, lvl::VirtualSparseBandLevel, pos_start, pos_stop)
pos_start = ctx(cache!(ctx, :p_start, pos_start))
pos_stop = ctx(cache!(ctx, :p_start, pos_stop))
return quote
Finch.resize_if_smaller!($(lvl.ptr), $pos_stop + 1)
Finch.fill_range!($(lvl.ptr), 0, $pos_start + 1, $pos_stop + 1)
Finch.resize_if_smaller!($(lvl.idx), $pos_stop)
Finch.fill_range!($(lvl.idx), 1, $pos_start, $pos_stop)
Finch.resize_if_smaller!($(lvl.ofs), $pos_stop + 1)
Finch.fill_range!($(lvl.ofs), 0, $pos_start + 1, $pos_stop + 1)
end
end

function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos_stop)
p = freshen(ctx, :p)
Tp = postype(lvl)
pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop)))
ros_stop = freshen(ctx, :ros_stop)
qos_stop = freshen(ctx, :qos_stop)
push_preamble!(ctx, quote
resize!($(lvl.ptr), $pos_stop + 1)
resize!($(lvl.idx), $pos_stop)
resize!($(lvl.ofs), $pos_stop + 1)
for $p = 2:($pos_stop + 1)
$(lvl.ptr)[$p] += $(lvl.ptr)[$p - 1]
$(lvl.ofs)[$p] += $(lvl.ofs)[$p - 1]
end
$ros_stop = $(lvl.ptr)[$pos_stop + 1] - 1
resize!($(lvl.idx), $ros_stop)
resize!($(lvl.ofs), $ros_stop + 1)
$qos_stop = $(lvl.ofs)[$ros_stop + 1] - $(Tp(1))
$qos_stop = $(lvl.ofs)[$pos_stop + 1] - $(Tp(1))
end)
lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop))
return lvl
Expand All @@ -285,19 +267,10 @@ function unfurl(ctx, fbr::VirtualSubFiber{VirtualSparseBandLevel}, ext, mode::Re
arr = fbr,
body = Thunk(
preamble = quote
$my_r = $(lvl.ptr)[$(ctx(pos))]
$my_r_stop = $(lvl.ptr)[$(ctx(pos)) + $(Tp(1))] - 1
if $my_r <= $my_r_stop
$my_i1 = $(lvl.idx)[$my_r]
$my_q_stop = $(lvl.ofs)[$my_r + $(Tp(1))]
$my_i_start = $my_i1 - ($my_q_stop - $(lvl.ofs)[$my_r] - 1)
$my_q_ofs = $my_q_stop - $my_i1 - $(Tp(1))
else
$my_i_start = $(Ti(1))
$my_i1 = $(Ti(0))
$my_q_stop = $(Ti(0))
$my_q = $(Ti(0))
end
$my_i1 = $(lvl.idx)[$(ctx(pos))]
$my_q_stop = $(lvl.ofs)[$(ctx(pos)) + $(Tp(1))]
$my_i_start = $my_i1 - ($my_q_stop - $(lvl.ofs)[$(ctx(pos))] - 1)
$my_q_ofs = $my_q_stop - $my_i1 - $(Tp(1))
end,
body = (ctx) -> Sequence([
Phase(
Expand Down Expand Up @@ -346,7 +319,6 @@ function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseBandLevel}, ext, mo
arr = fbr,
body = Thunk(
preamble = quote
$ros = $ros_fill
$qos = $qos_fill + 1
$qos_set = $qos_fill
$my_i_prev = $(Ti(-1))
Expand Down Expand Up @@ -394,24 +366,16 @@ function unfurl(ctx, fbr::VirtualHollowSubFiber{VirtualSparseBandLevel}, ext, mo
),
epilogue = quote
if $my_i_prev > 0
$ros += 1
if $ros > $ros_stop
$ros_stop = max($ros_stop << 1, 1)
Finch.resize_if_smaller!($(lvl.idx), $ros_stop)
Finch.resize_if_smaller!($(lvl.ofs), $ros_stop + 1)
end
$qos = $qos_set
$(lvl.idx)[$(ros)] = $my_i_set
$(lvl.ofs)[$(ros) + 1] = $qos + 1
$(lvl.idx)[$(ctx(pos))] = $my_i_set
$(lvl.ofs)[$(ctx(pos)) + 1] = $my_i_set - $my_i_prev + 1
$(if issafe(get_mode_flag(ctx))
quote
$(lvl.prev_pos) = $(ctx(pos))
end
end)
$qos_fill = $qos
end
$(lvl.ptr)[$(ctx(pos)) + 1] += $ros - $ros_fill
$ros_fill = $ros
end
)
)
Expand Down
1 change: 1 addition & 0 deletions src/tensors/levels/sparse_dict_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ end
function virtual_moveto_level(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, arch)
ptr_2 = freshen(ctx, lvl.ptr)
idx_2 = freshen(ctx, lvl.idx)
tbl_2 = freshen(ctx, lvl.tbl_2)
push_preamble!(ctx, quote
$tbl_2 = $(lvl.tbl)
$(lvl.tbl) = $moveto($(lvl.tbl), $(ctx(arch)))
Expand Down
Loading

0 comments on commit 53ae4fb

Please sign in to comment.