Skip to content

Commit

Permalink
Merge pull request #315 from willow-ahrens/fix-#313
Browse files Browse the repository at this point in the history
Fix #313
  • Loading branch information
willow-ahrens authored Nov 17, 2023
2 parents 10f788d + 4f0bd64 commit cf89dca
Showing 3 changed files with 38 additions and 19 deletions.
36 changes: 18 additions & 18 deletions src/symbolic/simplify_program.jl
Original file line number Diff line number Diff line change
@@ -115,7 +115,7 @@ function get_program_rules(alg, shash)
body_contain_idx = idx getunbound(body)
if !body_contain_idx
decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init)) tns
elseif @capture(node, define(~var, ~val, ~body)) var
elseif @capture(node, define(~var, ~val, ~body_2)) var
end, PostOrderDFS(body)))
Postwalk(@rule assign(access(~lhs, updater, ~j...), ~f, ~rhs) => begin
access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader, ~k...)) tns # TODO add getroot here?
@@ -144,23 +144,23 @@ function get_program_rules(alg, shash)

## Bottom-up reduction2
(@rule loop(~idx, ~ext::isvirtual, block(~s1..., assign(access(~lhs, updater, ~j...), ~f, ~rhs), ~s2...)) => begin
if ortho(getroot(lhs), s1) && ortho(getroot(lhs), s2)
if idx j && idx getunbound(rhs)
body = block(s1..., assign(access(lhs, updater, j...), f, rhs), s2...)
decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init)) tns
elseif @capture(node, define(~var, ~val, ~body)) var
end, PostOrderDFS(body)))

access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader, ~k...)) tns
elseif @capture(node, ~var::isvariable) var
end, PostOrderDFS(rhs)))

if !(lhs in decl_in_scope) && isempty(intersect(access_in_rhs, decl_in_scope))
collapsed_body = collapsed(alg, idx, ext.val, access(lhs, updater, j...), f, rhs)
block(collapsed_body, loop(idx, ext, block(s1..., s2...)))
end
end
end
if ortho(getroot(lhs), s1) && ortho(getroot(lhs), s2)
if idx j && idx getunbound(rhs)
body = block(s1..., assign(access(lhs, updater, j...), f, rhs), s2...)
decl_in_scope = filter(!isnothing, map(node-> if @capture(node, declare(~tns, ~init)) tns
elseif @capture(node, define(~var, ~val, ~body_2)) var
end, PostOrderDFS(body)))

access_in_rhs = filter(!isnothing, map(node-> if @capture(node, access(~tns, reader, ~k...)) tns
elseif @capture(node, ~var::isvariable) var
end, PostOrderDFS(rhs)))
if !(lhs in decl_in_scope) && isempty(intersect(access_in_rhs, decl_in_scope))
collapsed_body = collapsed(alg, idx, ext.val, access(lhs, updater, j...), f, rhs)
block(collapsed_body, loop(idx, ext, block(s1..., s2...)))
end
end
end
end),
(@rule block(~s1..., thaw(~a::isvariable), ~s2..., freeze(~a), ~s3...) => if ortho(a, s2)
block(s1..., s2..., s3...)
2 changes: 1 addition & 1 deletion src/tensors/levels/sparsehashlevels.jl
Original file line number Diff line number Diff line change
@@ -277,7 +277,7 @@ function thaw_level!(lvl::VirtualSparseHashLevel, ctx::AbstractCompiler, pos)
$(lvl.ptr)[1] = 1
$(lvl.qos_fill) = length($(lvl.tbl))
end)
lvl.lvl = thaw_level!(lvl.lvl, ctx, call(*, pos, lvl.shape))
lvl.lvl = thaw_level!(lvl.lvl, ctx, value(lvl.qos_fill, Tp))
return lvl
end

19 changes: 19 additions & 0 deletions test/test_issues.jl
Original file line number Diff line number Diff line change
@@ -470,4 +470,23 @@ using CIndices
end
end)
end

#https://github.com/willow-ahrens/Finch.jl/issues/313
let
edge_matrix = Fiber!(SparseList(SparseList(Element(0.0), 254), 254))
edge_values = fsprand((254, 254), .001)
@finch (edge_matrix .= 0; for j=_, i=_; edge_matrix[i,j] = edge_values[i,j]; end)
output_matrix = Fiber!(SparseHash{1}(SparseHash{1}(Element(0.0), (254,)), (254,)))
@finch (for v_4=_, v_3=_, v_2=_, v_5=_; output_matrix[v_2,v_5] += edge_matrix[v_5, v_4]*edge_matrix[v_2, v_3]*edge_matrix[v_3, v_4]; end)

a_matrix = [1 0; 0 1]
a_fiber = Fiber!(SparseList(SparseList(Element(0.0), 2), 2))
copyto!(a_fiber, a_matrix)
b_matrix = [0 1; 1 0]
b_fiber = Fiber!(SparseList(SparseList(Element(0.0), 2), 2))
copyto!(b_fiber, b_matrix)
output_tensor = Fiber!(SparseHash{1}(SparseHash{1}(Element(0.0), (2,)), (2,)))

@finch (output_tensor .=0; for j=_,i=_,k=_; output_tensor[i,k] += a_fiber[i,j] * b_fiber[k,j]; end)
end
end

0 comments on commit cf89dca

Please sign in to comment.