diff --git a/src/modeling/determinestrategy.jl b/src/modeling/determinestrategy.jl index ae4576234..ff475d249 100644 --- a/src/modeling/determinestrategy.jl +++ b/src/modeling/determinestrategy.jl @@ -278,7 +278,6 @@ function unroll_no_reductions(ls, order, vloopsym) end end # latency not a concern, because no depchains - innerloop = last(order) compute_l = 0.0 rpp = 0 # register pressure proportional to unrolling rpc = 0 # register pressure independent of unroll factor @@ -475,19 +474,6 @@ end # X[1]*u₂factor*u₁factor + X[4] + X[2] * u₂factor + X[3] * u₁factor X[1] + X[2] * u₂factor + X[3] * u₁factor + X[4] * u₁factor * u₂factor end -# function itertilesize(X, u₁L, u₂L) -# cb = Inf -# u₁b = 1; u₂b = 1 -# for u₁ ∈ 1:4, u₂ ∈ 1:4 -# c = unroll_cost(X, u₁, u₂, u₁L, u₂L) -# if cb > c -# cb = c -# u₁b = u₁; u₂b = u₂ -# end -# end -# u₁b, u₂b, cb -# end - function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range) R₁, R₂, R₃, R₄ = R[1], R[2], R[3], R[4] RR = R₄ @@ -858,38 +844,6 @@ function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols) end 0, 0x00 end -# function maxnegativeoffset_old(ls::LoopSet, op::Operation, u::Symbol) -# opmref = op.ref -# opref = opmref.ref -# mno = typemin(Int) -# id = 0 -# opoffs = opref.offsets -# for opp ∈ operations(ls) -# opp === op && continue -# oppmref = opp.ref -# oppref = oppmref.ref -# sameref(opref, oppref) || continue -# opinds = getindicesonly(op) -# oppinds = getindicesonly(opp) -# oppoffs = oppref.offsets -# # oploopi = opmref.loopedindex -# # opploopi = oppmref.loopedindex -# mnonew = typemin(Int) -# for i ∈ eachindex(opinds) -# if opinds[i] === u -# mnonew = (opoffs[i] - oppoffs[i]) -# elseif opoffs[i] != oppoffs[i] -# mnonew = 1 -# break -# end -# end -# if mno < mnonew < 0 -# mno = mnonew -# id = identifier(opp) -# end -# end -# mno, id -# end function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol) mno::Int = typemin(Int) id = 0 @@ -966,7 +920,7 @@ function load_elimination_cost_factor!( if !iszero(first(isoptranslation(ls, op, unrollsyms))) rt, lat, rp = cost(ls, op, (u₁loopsym, u₂loopsym), vloopsym, Wshift, size_T) # rt = Core.ifelse(isvectorized(op), 0.5rt, rt) - rto = rt + # rto = rt rt *= iters # rt *= factor1; rp *= factor2; choose_to_inline[] = true @@ -1412,12 +1366,28 @@ end # syms = collect(Base.OneTo(num_loops(ls))) # LoopOrders(syms, similar(syms)) # end + +function memcost_sort!(loops::Vector{Symbol}, ls::LoopSet) + costs = zeros(Int, length(loops)) + for op ∈ operations(ls) + if accesses_memory(op) + for (i,l) in enumerate(loops) + j = findfirst(==(l), getindices(op)) + if j !== nothing + costs[i] += j + end + end + end + end + copyto!(loops, loops[sortperm(costs,rev=true)]) +end + struct LoopOrders syms_nr::Vector{Symbol} syms_r::Vector{Symbol} buff::Vector{Symbol} end - +max_size_for_looporder_brute() = 2 function outer_reduct_loopordersplit(ls::LoopSet) ops = operations(ls) nonouterreducts = Int[] @@ -1461,6 +1431,9 @@ function LoopOrders(ls::LoopSet) else reductsyms, nonreductsyms = outer_reduct_loopordersplit(ls) end + size_thresh = max_size_for_looporder_brute() + length(reductsyms) > size_thresh && memcost_sort!(reductsyms, ls) + length(nonreductsyms) > size_thresh && memcost_sort!(nonreductsyms, ls) LoopOrders(nonreductsyms, reductsyms, Vector{Symbol}(undef, length(ls.loopsymbols))) end @@ -1472,7 +1445,10 @@ function Base.iterate(lo::LoopOrders) nr = length(lo.syms_nr) r = length(lo.syms_r) state = zeros(Int, nr + r) - lo.buff, (view(state, 1:nr), view(state, 1+nr:nr+r)) + size_thresh = max_size_for_looporder_brute() + nrstate = nr > size_thresh ? view(state, 1:1) : view(state, 1:nr) + rstate = r > size_thresh ? view(state, 1+nr:1+nr) : view(state, 1+nr:nr+r) + lo.buff, (nrstate, rstate) end function advance_state!(state) @@ -1493,12 +1469,6 @@ function advance_state!(state) end true end -function advance_state!(state, Nr) - state_nr = view(state, 1:Nr) - advance_state!(state_nr) && return true - fill!(state_nr, 0) - advance_state!(view(state, 1+Nr:length(state))) -end swap!(x::AbstractVector, i::Int, j::Int) = (x[j], x[i]) = (x[i], x[j]) function swap!( dest::AbstractVector{Symbol}, @@ -1515,11 +1485,12 @@ end function Base.iterate(lo::LoopOrders, (state_nr, state_r)) if advance_state!(state_nr) swap!(nonreductview(lo), lo.syms_nr, state_nr) - else - advance_state!(state_r) || return nothing + elseif advance_state!(state_r) fill!(state_nr, 0) copyto!(nonreductview(lo), lo.syms_nr) swap!(reductview(lo), lo.syms_r, state_r) + else + return nothing end lo.buff, (state_nr, state_r) end