Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

for >3 loops in a group, don't permute all possible orders #398

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 28 additions & 57 deletions src/modeling/determinestrategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ function unroll_no_reductions(ls, order, vloopsym)
end
end
# latency not a concern, because no depchains
innerloop = last(order)
compute_l = 0.0
rpp = 0 # register pressure proportional to unrolling
rpc = 0 # register pressure independent of unroll factor
Expand Down Expand Up @@ -475,19 +474,6 @@ end
# X[1]*u₂factor*u₁factor + X[4] + X[2] * u₂factor + X[3] * u₁factor
X[1] + X[2] * u₂factor + X[3] * u₁factor + X[4] * u₁factor * u₂factor
end
# function itertilesize(X, u₁L, u₂L)
# cb = Inf
# u₁b = 1; u₂b = 1
# for u₁ ∈ 1:4, u₂ ∈ 1:4
# c = unroll_cost(X, u₁, u₂, u₁L, u₂L)
# if cb > c
# cb = c
# u₁b = u₁; u₂b = u₂
# end
# end
# u₁b, u₂b, cb
# end

function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range)
R₁, R₂, R₃, R₄ = R[1], R[2], R[3], R[4]
RR = R₄
Expand Down Expand Up @@ -858,38 +844,6 @@ function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
end
0, 0x00
end
# function maxnegativeoffset_old(ls::LoopSet, op::Operation, u::Symbol)
# opmref = op.ref
# opref = opmref.ref
# mno = typemin(Int)
# id = 0
# opoffs = opref.offsets
# for opp ∈ operations(ls)
# opp === op && continue
# oppmref = opp.ref
# oppref = oppmref.ref
# sameref(opref, oppref) || continue
# opinds = getindicesonly(op)
# oppinds = getindicesonly(opp)
# oppoffs = oppref.offsets
# # oploopi = opmref.loopedindex
# # opploopi = oppmref.loopedindex
# mnonew = typemin(Int)
# for i ∈ eachindex(opinds)
# if opinds[i] === u
# mnonew = (opoffs[i] - oppoffs[i])
# elseif opoffs[i] != oppoffs[i]
# mnonew = 1
# break
# end
# end
# if mno < mnonew < 0
# mno = mnonew
# id = identifier(opp)
# end
# end
# mno, id
# end
function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
mno::Int = typemin(Int)
id = 0
Expand Down Expand Up @@ -966,7 +920,7 @@ function load_elimination_cost_factor!(
if !iszero(first(isoptranslation(ls, op, unrollsyms)))
rt, lat, rp = cost(ls, op, (u₁loopsym, u₂loopsym), vloopsym, Wshift, size_T)
# rt = Core.ifelse(isvectorized(op), 0.5rt, rt)
rto = rt
# rto = rt
rt *= iters
# rt *= factor1; rp *= factor2;
choose_to_inline[] = true
Expand Down Expand Up @@ -1412,12 +1366,28 @@ end
# syms = collect(Base.OneTo(num_loops(ls)))
# LoopOrders(syms, similar(syms))
# end

"""
    memcost_sort!(loops::Vector{Symbol}, ls::LoopSet)

Reorder `loops` in place by a per-loop score, highest score first.

For every operation of `ls` that accesses memory, each loop symbol found among
that operation's indices contributes its position within those indices to the
loop's score. Loops that never appear in a memory-accessing operation keep a
score of zero.
"""
function memcost_sort!(loops::Vector{Symbol}, ls::LoopSet)
    weights = zeros(Int, length(loops))
    for op ∈ operations(ls)
        # Only operations that touch memory contribute to the score.
        accesses_memory(op) || continue
        for (slot, sym) in enumerate(loops)
            pos = findfirst(==(sym), getindices(op))
            pos === nothing && continue
            weights[slot] += pos
        end
    end
    # Descending order: the loop with the largest accumulated score comes first.
    order = sortperm(weights, rev = true)
    copyto!(loops, loops[order])
end

# Iterator over candidate loop orders. The constructor `LoopOrders(ls::LoopSet)`
# fills the fields, in order, with the non-reduction loop symbols, the reduction
# loop symbols, and an uninitialized buffer with one slot per loop symbol of `ls`.
struct LoopOrders
# loop symbols of the non-reduction group
syms_nr::Vector{Symbol}
# loop symbols of the reduction group
syms_r::Vector{Symbol}
# scratch buffer that `Base.iterate` hands back as the current order
buff::Vector{Symbol}
end

"""
    max_size_for_looporder_brute()

Return the size threshold for a loop group. `LoopOrders` compares the length of
the reduction and non-reduction groups against this value: a group longer than
the threshold is pre-sorted with `memcost_sort!` and iterated with a
length-one state instead of one state entry per loop.
"""
function max_size_for_looporder_brute()
    return 2
end
function outer_reduct_loopordersplit(ls::LoopSet)
ops = operations(ls)
nonouterreducts = Int[]
Expand Down Expand Up @@ -1461,6 +1431,9 @@ function LoopOrders(ls::LoopSet)
else
reductsyms, nonreductsyms = outer_reduct_loopordersplit(ls)
end
size_thresh = max_size_for_looporder_brute()
length(reductsyms) > size_thresh && memcost_sort!(reductsyms, ls)
length(nonreductsyms) > size_thresh && memcost_sort!(nonreductsyms, ls)
LoopOrders(nonreductsyms, reductsyms, Vector{Symbol}(undef, length(ls.loopsymbols)))
end

Expand All @@ -1472,7 +1445,10 @@ function Base.iterate(lo::LoopOrders)
nr = length(lo.syms_nr)
r = length(lo.syms_r)
state = zeros(Int, nr + r)
lo.buff, (view(state, 1:nr), view(state, 1+nr:nr+r))
size_thresh = max_size_for_looporder_brute()
nrstate = nr > size_thresh ? view(state, 1:1) : view(state, 1:nr)
rstate = r > size_thresh ? view(state, 1+nr:1+nr) : view(state, 1+nr:nr+r)
lo.buff, (nrstate, rstate)
end

function advance_state!(state)
Expand All @@ -1493,12 +1469,6 @@ function advance_state!(state)
end
true
end
"""
    advance_state!(state, Nr)

Advance a combined state vector whose first `Nr` entries form one group and
whose remaining entries form a second group.

The leading group is advanced first. If that succeeds, return `true`.
Otherwise the leading group is reset to zeros and the result of advancing the
trailing group is returned.
"""
function advance_state!(state, Nr)
    head = view(state, 1:Nr)
    if advance_state!(head)
        return true
    end
    # Leading group is exhausted: rewind it and carry into the trailing group.
    fill!(head, 0)
    tail = view(state, 1+Nr:length(state))
    return advance_state!(tail)
end
"""
    swap!(x::AbstractVector, i::Int, j::Int)

Exchange the elements of `x` at positions `i` and `j` in place and return the
tuple `(x[i], x[j])` as it was before the exchange.
"""
function swap!(x::AbstractVector, i::Int, j::Int)
    # Read both values before writing so neither is lost.
    before = (x[i], x[j])
    x[j] = before[1]
    x[i] = before[2]
    return before
end
function swap!(
dest::AbstractVector{Symbol},
Expand All @@ -1515,11 +1485,12 @@ end
"""
    Base.iterate(lo::LoopOrders, (state_nr, state_r))

Produce the next loop order, or `nothing` once neither state can be advanced.

The non-reduction state is advanced first and its swaps are applied to the
non-reduction part of the buffer. When it cannot advance, the reduction state
is advanced instead: the non-reduction state is reset to zeros, the
non-reduction part of the buffer is restored from `lo.syms_nr`, and the
reduction swaps are applied. The returned order is `lo.buff` itself, which is
overwritten on every step.
"""
function Base.iterate(lo::LoopOrders, (state_nr, state_r))
    if advance_state!(state_nr)
        swap!(nonreductview(lo), lo.syms_nr, state_nr)
    elseif advance_state!(state_r)
        # Non-reduction orders are exhausted for the current reduction order:
        # rewind them before moving to the next reduction order.
        fill!(state_nr, 0)
        copyto!(nonreductview(lo), lo.syms_nr)
        swap!(reductview(lo), lo.syms_r, state_r)
    else
        return nothing
    end
    return lo.buff, (state_nr, state_r)
end
Expand Down