Skip to content

Commit

Permalink
optimize bareiss_update_virtual_colswap_mtk!
Browse files: browse the repository at this point in the history
  • Loading branch information
chriselrod committed Dec 22, 2023
1 parent e089c9e commit 04962e5
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 43 deletions.
30 changes: 0 additions & 30 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,36 +382,6 @@ end

"""
    swap!(v, i, j)

Exchange the elements of `v` stored at indices `i` and `j`, in place.

Returns the tuple `(v[j], v[i])` as read *before* the exchange, i.e. the values
that end up at positions `i` and `j` respectively.
"""
function swap!(v, i, j)
    # Read both slots first (`v[j]`, then `v[i]`) so the writes below cannot
    # clobber a value that is still needed; this also makes `i == j` a no-op.
    exchanged = (v[j], v[i])
    v[i], v[j] = exchanged
    return exchanged
end

"""
    getcoeff(vars, coeffs, var)

Look up the coefficient associated with `var`.

`vars` and `coeffs` are treated as parallel collections: the result is
`coeffs[k]` for the smallest `k` such that `vars[k] == var`. If `var` does not
occur in `vars`, the additive identity `zero(eltype(coeffs))` is returned, so the
result has the same type on the "found" and "not found" paths.

The search compares `chunk_size` (8) entries of `vars` at a time using tuple
operations and finishes with a scalar loop over the leftover tail.

Indexing is done under `@inbounds` with indices `1:length(vars)`; callers must
pass 1-based collections with `length(coeffs) >= length(vars)`.
"""
function getcoeff(vars, coeffs, var)
    Nvars = length(vars)
    chunk_size = 8
    # Loop-invariant lane data: offsets 0:(chunk_size - 1) within a chunk, and a
    # sentinel (== chunk_size) that marks "this lane did not match".
    lane_offsets = ntuple(Base.Fix2(-, 1), Val(chunk_size))
    no_match = ntuple(Returns(chunk_size), Val(chunk_size))
    i = 0
    @inbounds while i + chunk_size <= Nvars
        # `let` rebinds the captured variables so the closure does not capture
        # the reassigned loop counter `i` directly.
        hits = let vars = vars, var = var, i = i
            ntuple(Val(chunk_size)) do j
                @inbounds vars[i + j] == var
            end
        end
        # Matching lanes keep their offset, all others become the sentinel.
        inds = map(ifelse, hits, lane_offsets, no_match)
        # Pairwise min-reduction over the 8 lanes (8 -> 4 -> 2 -> 1). This tree
        # is written out for exactly 8 lanes; adjust it if `chunk_size` changes.
        inds4 = (min(inds[1], inds[5]),
            min(inds[2], inds[6]),
            min(inds[3], inds[7]),
            min(inds[4], inds[8]))
        inds2 = (min(inds4[1], inds4[3]), min(inds4[2], inds4[4]))
        ind = min(inds2[1], inds2[2])
        if ind != chunk_size
            # `ind` is the 0-based offset of the first match inside this chunk.
            return coeffs[i + ind + 1]
        end
        i += chunk_size
    end
    # Scalar scan over the remaining (< chunk_size) entries.
    @inbounds for vj in (i + 1):Nvars
        vars[vj] == var && return coeffs[vj]
    end
    # Not present: return a zero of the coefficient type rather than an `Int`
    # literal, keeping the return type stable for e.g. `BigInt` coefficients.
    return zero(eltype(coeffs))
end

"""
$(SIGNATURES)
Expand Down
120 changes: 107 additions & 13 deletions src/systems/sparsematrixclil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
# case for MTK (where most pivots are `1` or `-1`).
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)

for ei in (k + 1):size(M, 1)
@inbounds for ei in (k + 1):size(M, 1)
# eliminate `v`
coeff = 0
ivars = eadj[ei]
Expand All @@ -193,18 +193,112 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
tmp_coeffs = similar(old_cadj[ei], 0)
# TODO: We know both ivars and kvars are sorted, we could just write
# a quick iterator here that does this without allocation/faster.
vars = sort(union(ivars, kvars))

for v in vars
v == vpivot && continue
ck = getcoeff(kvars, kcoeffs, v)
ci = getcoeff(ivars, icoeffs, v)
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
numkvars = length(kvars)
numivars = length(ivars)
kvind = ivind = 0
if _debug_mode
# in debug mode, we at least check to confirm we're iterating over
# `v`s in the correct order
vars = sort(union(ivars, kvars))
vi = 0
end
if numivars > 0 && numkvars > 0
kvv = kvars[kvind += 1]
ivv = ivars[ivind += 1]
dobreak = false
while true
if kvv == ivv
v = kvv
ck = kcoeffs[kvind]
ci = icoeffs[ivind]
kvind += 1
ivind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
elseif kvv < ivv
v = kvv
ck = kcoeffs[kvind]
ci = zero(eltype(icoeffs))
kvind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
else # kvv > ivv
v = ivv
ck = zero(eltype(kcoeffs))
ci = icoeffs[ivind]
ivind += 1
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
end
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
end
end
dobreak && break
end
elseif numivars == 0
ivind = 1
kvv = kvars[kvind += 1]

Check warning on line 263 in src/systems/sparsematrixclil.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/sparsematrixclil.jl#L261-L263

Added lines #L261 - L263 were not covered by tests
else # numkvars == 0
kvind = 1
ivv = ivars[ivind += 1]

Check warning on line 266 in src/systems/sparsematrixclil.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/sparsematrixclil.jl#L265-L266

Added lines #L265 - L266 were not covered by tests
end
if kvind <= numkvars
v = kvv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
ck = kcoeffs[kvind]
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(0, p2), last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
end
end
(kvind == numkvars) && break
v = kvars[kvind += 1]
end
elseif ivind <= numivars
v = ivv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
ci = exactdiv(p1, last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
end
end
(ivind == numivars) && break
v = ivars[ivind += 1]
end
end
eadj[ei] = tmp_incidence
Expand Down

0 comments on commit 04962e5

Please sign in to comment.