Skip to content

Commit

Permalink
Merge pull request #355 from willow-ahrens/jaeyeon/fixdataflow
Browse files Browse the repository at this point in the history
Dataflow Bug
  • Loading branch information
willow-ahrens authored Dec 20, 2023
2 parents 88c0976 + 7e3dcfa commit df553b8
Show file tree
Hide file tree
Showing 33 changed files with 107 additions and 26 deletions.
3 changes: 3 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ quote
@warn "Performance Warning: non-concordant traversal of A[i, j] (hint: most arrays prefer column major or first index fast, run in fast mode to ignore this warning)"
for i_3 = 1:A_mode1_stop
for j_3 = 1:A_mode2_stop
sugar_3 = size(A)
A_mode1_stop = sugar_3[1]
A_mode2_stop = sugar_3[2]
val = A[i_3, j_3]
s_val = val + s_val
end
Expand Down
2 changes: 1 addition & 1 deletion docs/src/interactive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
{
"output_type": "execute_result",
"data": {
"text/plain": "quote\n y = ex.body.body.lhs.tns.bind\n A_lvl = (ex.body.body.rhs.args[1]).tns.bind.lvl\n A_lvl_2 = A_lvl.lvl\n A_lvl_ptr = A_lvl_2.ptr\n A_lvl_idx = A_lvl_2.idx\n A_lvl_2_val = A_lvl_2.lvl.val\n x = (ex.body.body.rhs.args[2]).tns.bind\n sugar_1 = size(y)\n y_mode1_stop = sugar_1[1]\n A_lvl_2.shape == y_mode1_stop || throw(DimensionMismatch(\"mismatched dimension limits ($(A_lvl_2.shape) != $(y_mode1_stop))\"))\n sugar_2 = size(x)\n x_mode1_stop = sugar_2[1]\n x_mode1_stop == A_lvl.shape || throw(DimensionMismatch(\"mismatched dimension limits ($(x_mode1_stop) != $(A_lvl.shape))\"))\n for j_4 = 1:x_mode1_stop\n val = x[j_4]\n A_lvl_q = (1 - 1) * A_lvl.shape + j_4\n A_lvl_2_q = A_lvl_ptr[A_lvl_q]\n A_lvl_2_q_stop = A_lvl_ptr[A_lvl_q + 1]\n if A_lvl_2_q < A_lvl_2_q_stop\n A_lvl_2_i1 = A_lvl_idx[A_lvl_2_q_stop - 1]\n else\n A_lvl_2_i1 = 0\n end\n phase_stop = min(A_lvl_2.shape, A_lvl_2_i1)\n if phase_stop >= 1\n if A_lvl_idx[A_lvl_2_q] < 1\n A_lvl_2_q = Finch.scansearch(A_lvl_idx, 1, A_lvl_2_q, A_lvl_2_q_stop - 1)\n end\n while true\n A_lvl_2_i = A_lvl_idx[A_lvl_2_q]\n if A_lvl_2_i < phase_stop\n A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]\n y[A_lvl_2_i] = val * A_lvl_3_val + y[A_lvl_2_i]\n A_lvl_2_q += 1\n else\n phase_stop_3 = min(A_lvl_2_i, phase_stop)\n if A_lvl_2_i == phase_stop_3\n A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]\n y[phase_stop_3] = val * A_lvl_3_val + y[phase_stop_3]\n A_lvl_2_q += 1\n end\n break\n end\n end\n end\n end\n (y = y,)\nend"
"text/plain": "quote\n y = ex.body.body.lhs.tns.bind\n A_lvl = (ex.body.body.rhs.args[1]).tns.bind.lvl\n A_lvl_2 = A_lvl.lvl\n A_lvl_ptr = A_lvl_2.ptr\n A_lvl_idx = A_lvl_2.idx\n A_lvl_2_val = A_lvl_2.lvl.val\n x = (ex.body.body.rhs.args[2]).tns.bind\n sugar_1 = size(y)\n y_mode1_stop = sugar_1[1]\n A_lvl_2.shape == y_mode1_stop || throw(DimensionMismatch(\"mismatched dimension limits ($(A_lvl_2.shape) != $(y_mode1_stop))\"))\n sugar_2 = size(x)\n x_mode1_stop = sugar_2[1]\n x_mode1_stop == A_lvl.shape || throw(DimensionMismatch(\"mismatched dimension limits ($(x_mode1_stop) != $(A_lvl.shape))\"))\n for j_4 = 1:x_mode1_stop\n sugar_3 = size(x)\n x_mode1_stop = sugar_3[1]\n val = x[j_4]\n A_lvl_q = (1 - 1) * A_lvl.shape + j_4\n A_lvl_2_q = A_lvl_ptr[A_lvl_q]\n A_lvl_2_q_stop = A_lvl_ptr[A_lvl_q + 1]\n if A_lvl_2_q < A_lvl_2_q_stop\n A_lvl_2_i1 = A_lvl_idx[A_lvl_2_q_stop - 1]\n else\n A_lvl_2_i1 = 0\n end\n phase_stop = min(A_lvl_2.shape, A_lvl_2_i1)\n if phase_stop >= 1\n if A_lvl_idx[A_lvl_2_q] < 1\n A_lvl_2_q = Finch.scansearch(A_lvl_idx, 1, A_lvl_2_q, A_lvl_2_q_stop - 1)\n end\n while true\n A_lvl_2_i = A_lvl_idx[A_lvl_2_q]\n if A_lvl_2_i < phase_stop\n A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]\n y[A_lvl_2_i] = val * A_lvl_3_val + y[A_lvl_2_i]\n A_lvl_2_q += 1\n else\n phase_stop_3 = min(A_lvl_2_i, phase_stop)\n if A_lvl_2_i == phase_stop_3\n A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]\n y[phase_stop_3] = val * A_lvl_3_val + y[phase_stop_3]\n A_lvl_2_q += 1\n end\n break\n end\n end\n end\n end\n (y = y,)\nend"
},
"metadata": {},
"execution_count": 1
Expand Down
9 changes: 9 additions & 0 deletions docs/src/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ quote
for j_4 = 1:B_mode2_stop
sugar_2 = size(B)
B_mode1_stop = sugar_2[1]
B_mode2_stop = sugar_2[2]
A_lvl_q = (1 - 1) * A_lvl.shape + j_4
A_lvl_2_q = A_lvl_ptr[A_lvl_q]
A_lvl_2_q_stop = A_lvl_ptr[A_lvl_q + 1]
Expand All @@ -271,6 +272,7 @@ quote
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
sugar_4 = size(B)
B_mode1_stop = sugar_4[1]
B_mode2_stop = sugar_4[2]
val = B[A_lvl_2_i, j_4]
C_val = C_val + f(A_lvl_3_val, val)
A_lvl_2_q += 1
Expand All @@ -285,11 +287,15 @@ quote
A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
sugar_6 = size(B)
B_mode1_stop = sugar_6[1]
B_mode2_stop = sugar_6[2]
val = B[phase_stop_3, j_4]
C_val = C_val + f(A_lvl_3_val, val)
A_lvl_2_q += 1
else
for i_10 = i:phase_stop_3
sugar_7 = size(B)
B_mode1_stop = sugar_7[1]
B_mode2_stop = sugar_7[2]
val = B[i_10, j_4]
C_val = C_val + f(0.0, val)
end
Expand All @@ -302,6 +308,9 @@ quote
phase_start_3 = max(1, 1 + A_lvl_2_i1)
if B_mode1_stop >= phase_start_3
for i_12 = phase_start_3:B_mode1_stop
sugar_8 = size(B)
B_mode1_stop = sugar_8[1]
B_mode2_stop = sugar_8[2]
val = B[i_12, j_4]
C_val = C_val + f(0.0, val)
end
Expand Down
8 changes: 5 additions & 3 deletions src/util/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,10 +441,12 @@ function (ctx::MarkDead)(ex, res)
body_2 = body
while true
ctx_2 = copy(ctx)
body_2 = ctx(body, false)
ctx_3 = branch(ctx)
body_2 = ctx_3(body, false)
meet!(ctx, ctx_3)
ext = ctx(ext, iseffectful(body_2))
ctx == ctx_2 && break
end
ext = ctx(ext, iseffectful(body_2))
return Expr(:for, Expr(:(=), i, ext), body_2)
elseif @capture ex :while(~cond, ~body)
body_2 = body
Expand Down Expand Up @@ -531,4 +533,4 @@ Base.@propagate_inbounds function scansearch(v, x, lo::T, hi::T)::T where T<:Int
end
end
return hi
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ begin
Finch.resize_if_smaller!(res_lvl_val, res_lvl_qos_stop)
Finch.fill_range!(res_lvl_val, false, res_lvl_qos, res_lvl_qos_stop)
end
res_lvl_val[res_lvl_qos] = tmp_lvl_2_val
res = (res_lvl_val[res_lvl_qos] = tmp_lvl_2_val)
res_lvl_idx[res_lvl_qos] = i_8
res_lvl_qos += 1
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ begin
Finch.fill_range!(res_lvl_val, false, res_lvl_qos, res_lvl_qos_stop)
end
tmp_lvl_2_val = tmp_lvl_val[tmp_lvl_q + -1 + i_5]
res_lvl_val[res_lvl_qos] = tmp_lvl_2_val
res = (res_lvl_val[res_lvl_qos] = tmp_lvl_2_val)
res_lvl_idx[res_lvl_qos] = i_5
res_lvl_qos += 1
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ begin
end
tmp_lvl_q = tmp_lvl_q_ofs + i_8
tmp_lvl_2_val = tmp_lvl_val[tmp_lvl_q]
res_lvl_val[res_lvl_qos] = tmp_lvl_2_val
res = (res_lvl_val[res_lvl_qos] = tmp_lvl_2_val)
res_lvl_idx[res_lvl_qos] = i_8
res_lvl_qos += 1
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ begin
Finch.resize_if_smaller!(res_lvl_2_val, res_lvl_2_qos_stop)
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ begin
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q + -1 + i_5]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_5
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ begin
end
tmp_lvl_2_q = tmp_lvl_2_q_ofs + i_8
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ begin
Finch.resize_if_smaller!(res_lvl_2_val, res_lvl_2_qos_stop)
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ begin
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q + -1 + i_5]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_5
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ begin
end
tmp_lvl_2_q = tmp_lvl_2_q_ofs + i_8
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ begin
Finch.resize_if_smaller!(res_lvl_2_val, res_lvl_2_qos_stop)
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ begin
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
tmp_lvl_2_val = tmp_lvl_val[tmp_lvl_s + -1 + i_5]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_2_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_2_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_5
res_lvl_2_qos += 1
Expand Down
10 changes: 10 additions & 0 deletions test/reference32/issue288_concordize_double_let.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,15 @@ begin
A_mode2_stop = sugar_8[2]
sugar_9 = size(C)
C_mode2_stop = sugar_9[2]
C_mode3_stop = sugar_9[3]
for j_5 = 1:C_mode2_stop
sugar_12 = size(C)
C_mode2_stop = sugar_12[2]
C_mode3_stop = sugar_12[3]
for i_9 = 1:A_mode1_stop
sugar_15 = size(C)
C_mode2_stop = sugar_15[2]
C_mode3_stop = sugar_15[3]
val = X[i_9, j_5]
for l_6 = 1:A_mode2_stop
val_2 = A[i_9, l_6, k_6]
Expand All @@ -41,6 +48,9 @@ begin
C[i_9, j_5, k_6] = val_2 * val + C[i_9, j_5, k_6]
end
end
sugar_19 = size(A)
A_mode1_stop = sugar_19[1]
A_mode2_stop = sugar_19[2]
val_3 = A[i_9, l_6, k_6]
phase_stop_3 = min((l_6 + 0) + -1, i_9)
if phase_stop_3 >= i_9
Expand Down
10 changes: 10 additions & 0 deletions test/reference32/issue288_concordize_let.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,20 @@ begin
A_mode2_stop = sugar_4[2]
sugar_5 = size(C)
C_mode2_stop = sugar_5[2]
C_mode3_stop = sugar_5[3]
for j_4 = 1:C_mode2_stop
sugar_7 = size(C)
C_mode2_stop = sugar_7[2]
C_mode3_stop = sugar_7[3]
for i_6 = 1:A_mode1_stop
sugar_9 = size(C)
C_mode2_stop = sugar_9[2]
C_mode3_stop = sugar_9[3]
val = X[i_6, j_4]
for l_4 = 1:A_mode2_stop
sugar_11 = size(A)
A_mode1_stop = sugar_11[1]
A_mode2_stop = sugar_11[2]
val_2 = A[i_6, l_4, k_4]
phase_stop = min(i_6, (l_4 + 0) + -1)
if phase_stop >= i_6
Expand Down
2 changes: 2 additions & 0 deletions test/reference32/typical_spmv_sparsematrixcsc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ quote
A.m == y_mode1_stop || throw(DimensionMismatch("mismatched dimension stop"))
fill!(y, 0)
for j_4 = 1:x_mode1_stop
sugar_5 = size(x)
x_mode1_stop = sugar_5[1]
val = x[j_4]
A_q = A.colptr[j_4]
A_q_stop = A.colptr[j_4 + 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ begin
Finch.resize_if_smaller!(res_lvl_val, res_lvl_qos_stop)
Finch.fill_range!(res_lvl_val, false, res_lvl_qos, res_lvl_qos_stop)
end
res_lvl_val[res_lvl_qos] = tmp_lvl_2_val
res = (res_lvl_val[res_lvl_qos] = tmp_lvl_2_val)
res_lvl_idx[res_lvl_qos] = i_8
res_lvl_qos += 1
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ begin
Finch.fill_range!(res_lvl_val, false, res_lvl_qos, res_lvl_qos_stop)
end
tmp_lvl_2_val = tmp_lvl_val[tmp_lvl_q + -1 + i_5]
res_lvl_val[res_lvl_qos] = tmp_lvl_2_val
res = (res_lvl_val[res_lvl_qos] = tmp_lvl_2_val)
res_lvl_idx[res_lvl_qos] = i_5
res_lvl_qos += 1
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ begin
end
tmp_lvl_q = tmp_lvl_q_ofs + i_8
tmp_lvl_2_val = tmp_lvl_val[tmp_lvl_q]
res_lvl_val[res_lvl_qos] = tmp_lvl_2_val
res = (res_lvl_val[res_lvl_qos] = tmp_lvl_2_val)
res_lvl_idx[res_lvl_qos] = i_8
res_lvl_qos += 1
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ begin
Finch.resize_if_smaller!(res_lvl_2_val, res_lvl_2_qos_stop)
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ begin
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q + -1 + i_5]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_5
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ begin
end
tmp_lvl_2_q = tmp_lvl_2_q_ofs + i_8
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ begin
Finch.resize_if_smaller!(res_lvl_2_val, res_lvl_2_qos_stop)
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ begin
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q + -1 + i_5]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_5
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ begin
end
tmp_lvl_2_q = tmp_lvl_2_q_ofs + i_8
tmp_lvl_3_val = tmp_lvl_2_val[tmp_lvl_2_q]
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ begin
Finch.resize_if_smaller!(res_lvl_2_val, res_lvl_2_qos_stop)
Finch.fill_range!(res_lvl_2_val, false, res_lvl_2_qos, res_lvl_2_qos_stop)
end
res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val
res = (res_lvl_2_val[res_lvl_2_qos] = tmp_lvl_3_val)
res_lvldirty = true
res_lvl_idx_2[res_lvl_2_qos] = i_8
res_lvl_2_qos += 1
Expand Down
Loading

0 comments on commit df553b8

Please sign in to comment.