From 04962e526e93009acf226d3454e34a96e3480749 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Fri, 22 Dec 2023 09:07:58 -0500 Subject: [PATCH] optimize `bareiss_update_virtual_colswap_mtk!` --- src/systems/alias_elimination.jl | 30 -------- src/systems/sparsematrixclil.jl | 120 +++++++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 43 deletions(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 38de4fbb31..6af95b5902 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -382,36 +382,6 @@ end swap!(v, i, j) = v[i], v[j] = v[j], v[i] -function getcoeff(vars, coeffs, var) - Nvars = length(vars) - i = 0 - chunk_size = 8 - @inbounds while i < Nvars - chunk_size + 1 - btup = let vars = vars, var = var, i = i - ntuple(Val(chunk_size)) do j - @inbounds vars[i + j] == var - end - end - inds = ntuple(Base.Fix2(-, 1), Val(8)) - eights = ntuple(Returns(8), Val(8)) - inds = map(ifelse, btup, inds, eights) - inds4 = (min(inds[1], inds[5]), - min(inds[2], inds[6]), - min(inds[3], inds[7]), - min(inds[4], inds[8])) - inds2 = (min(inds4[1], inds4[3]), min(inds4[2], inds4[4])) - ind = min(inds2[1], inds2[2]) - if ind != 8 - return coeffs[i + ind + 1] - end - i += chunk_size - end - @inbounds for vj in (i + 1):Nvars - vars[vj] == var && return coeffs[vj] - end - return 0 -end - """ $(SIGNATURES) diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index dca48973c4..92f12ca926 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -169,7 +169,7 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap # case for MTK (where most pivots are `1` or `-1`). pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot) - for ei in (k + 1):size(M, 1) + @inbounds for ei in (k + 1):size(M, 1) # eliminate `v` coeff = 0 ivars = eadj[ei] @@ -193,18 +193,112 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap tmp_coeffs = similar(old_cadj[ei], 0) # TODO: We know both ivars and kvars are sorted, we could just write # a quick iterator here that does this without allocation/faster. - vars = sort(union(ivars, kvars)) - - for v in vars - v == vpivot && continue - ck = getcoeff(kvars, kcoeffs, v) - ci = getcoeff(ivars, icoeffs, v) - p1 = Base.Checked.checked_mul(pivot, ci) - p2 = Base.Checked.checked_mul(coeff, ck) - ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) - if !iszero(ci) - push!(tmp_incidence, v) - push!(tmp_coeffs, ci) + numkvars = length(kvars) + numivars = length(ivars) + kvind = ivind = 0 + if _debug_mode + # in debug mode, we at least check to confirm we're iterating over + # `v`s in the correct order + vars = sort(union(ivars, kvars)) + vi = 0 + end + if numivars > 0 && numkvars > 0 + kvv = kvars[kvind += 1] + ivv = ivars[ivind += 1] + dobreak = false + while true + if kvv == ivv + v = kvv + ck = kcoeffs[kvind] + ci = icoeffs[ivind] + kvind += 1 + ivind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + elseif kvv < ivv + v = kvv + ck = kcoeffs[kvind] + ci = zero(eltype(icoeffs)) + kvind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + else # kvv > ivv + v = ivv + ck = zero(eltype(kcoeffs)) + ci = icoeffs[ivind] + ivind += 1 + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + end + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + p1 = Base.Checked.checked_mul(pivot, ci) + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) + if !iszero(ci) + push!(tmp_incidence, v) + push!(tmp_coeffs, ci) + end + end + dobreak && break + end + elseif numivars == 0 + ivind = 1 + kvv = kvars[kvind += 1] + else # numkvars == 0 + kvind = 1 + ivv = ivars[ivind += 1] + end + if kvind <= numkvars + v = kvv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + ck = kcoeffs[kvind] + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(0, p2), last_pivot) + if !iszero(ci) + push!(tmp_incidence, v) + push!(tmp_coeffs, ci) + end + end + (kvind == numkvars) && break + v = kvars[kvind += 1] + end + elseif ivind <= numivars + v = ivv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind]) + ci = exactdiv(p1, last_pivot) + if !iszero(ci) + push!(tmp_incidence, v) + push!(tmp_coeffs, ci) + end + end + (ivind == numivars) && break + v = ivars[ivind += 1] end end eadj[ei] = tmp_incidence