Skip to content

Commit

Permalink
tests pass locally
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Aug 27, 2024
1 parent af56460 commit 8cc9737
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 66 deletions.
99 changes: 40 additions & 59 deletions src/rdivl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,34 +273,18 @@ end
end
@generated function ldivu_solve_W_u!(
spa,
spu,
spl,
n,
::StaticInt{W},
::StaticInt{U},
::Val{UNIT}
) where {W,U,UNIT}
z = static(0)
# B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
# Actually a row-major rdivl
# B_{m,n} = (A_{m,n} - \sum_{i=n+1}^N B_{m,i}L_{i,n})/L_{n,n}
Aind = Unroll{1,1,W,2,W,zero(UInt),1}(Unroll{2,W,U,2,W,zero(UInt),1}((z, z)))
q = quote
# $(Expr(:meta, :inline))
# C = U \ A; U * C = A
# A_{i,j} = U_{i,i}*C_{i,j} + \sum_{k=i+1}^{N}U_{i,k}C_{k,j}
# C_{i,j} = U_{i,i} \ (A_{i,j} - \sum_{k=i+1}^{N}U_{i,k}C_{k,j})
# The inputs here are transposed, as the library was formulated in terms of `rdiv!`,
# so we have
# C_{j,i} = (A_{j,i} - \sum_{k=i+1}^{N}C_{j,k}U_{k,i}) / L_{i,i}
# This solves for the block: C_{j+[0,W],i+[0,W*U)}
# This can be viewed as `U` blocks that are each `W`x`W`
# E.g. U=3, rough alg:
# r=[0,W); c=[0,WU)
# X = A_{j+r,i+c} - \sum_{k=1}^{i-1}C_{j+r,k}*U_{k,i+c}
# C_{j+r,i+r} = X[:, r] / U_{i+r,i+r}
# C_{j+r,i+W+r} = (X[:, W+r] - C_{j+r,i+r}*U_{i+r,i+W+r}) / U_{i+W+r,i+W+r}
# C_{j+r,i+2W+r} = (X[:, 2W+r] - C_{j+r,i+r}*U_{i+r,i+2W+r} - C_{j+r,i+W+r}*U_{i+W+r,i+2W+r}) / U_{i+2W+r,i+2W+r}
#
# outer unroll are `W` rows
# Inner unroll are `W*U` columns (U simd vecs)
#
A11 = getfield(vload(spa, $Aind), :data)
# The `W` rows
Expand All @@ -310,8 +294,8 @@ end
# Each iter:
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
for nk SafeCloseOpen(n) # nmuladd
nkw = nk + $W
U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})((nkw, $z)))
nkw = nk + $(W * U)
U_ki = vload(spl, $(Unroll{2,W,U,2,W,zero(UInt),1})((nkw, $z)))
Base.Cartesian.@nexprs $W c ->
A11_c = vfnmadd_fast(U_ki, vload(spa, (static(c - 1), nkw)), A11_c)
end
Expand All @@ -326,6 +310,7 @@ end
Xu = Vector{Symbol}(undef, W)
Csym = Vector{Symbol}(undef, U)
for u = 1:U
# X_u are future
X_u = Symbol(:X_, u)
push!(
q.args,
Expand All @@ -341,18 +326,19 @@ end
)
)
)
# push!(q.args, :(println($X_u)))
for c = 1:W
X_u_c = Xu[c] = Symbol(:X_, u, :_, c)
push!(q.args, Expr(:(=), X_u_c, Expr(:call, getfield, X_u, c)))
end
# take A[(U-u+1)*W,u*W), [0,W)]
for j = 1:u-1
for k = 1:W
for c = 1:W
urow = ((W - k) + ((j - 1) * W))
for j = 1:u-1 # iter over all blocks ordered after
for k = 1:W # reduction dimension, reverse order
for c = 1:W # columns of C
urow = ((W - k) + ((U - j) * W))
ucol = ((c - 1) + ((U - u) * W))
push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
Uexpr = :(vload(spu, ($urow, $ucol)))
# push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
Uexpr = :(vload(spl, ($urow, $ucol)))
X_u_c = Xu[c]
C_j_k = Symbol(:C_, j, :_, W + 1 - k)
Xucexpr = Expr(:call, vfnmadd_fast, C_j_k, Uexpr, X_u_c)
Expand All @@ -361,7 +347,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
end
end
o = (U - u) * W
sp = Expr(:call, gesp, :spu, (o, o))
sp = Expr(:call, gesp, :spl, (o, o))
Xut = Expr(:tuple)
for c = 1:W
push!(Xut.args, Xu[c])
Expand All @@ -379,6 +365,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
end
end
for u = 1:U
# u = 1 is last, first processed (reverse order)
ui = Unroll{2,1,W,1,W,zero(UInt),1}((z, (U - u) * W))
C_u = Csym[u]
push!(q.args, :(vstore!(spa, $C_u, $ui)))
Expand All @@ -387,7 +374,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
end
@generated function ldivu_solve_W!(
spa,
spu,
spl,
n,
::StaticInt{W},
::Val{UNIT},
Expand Down Expand Up @@ -425,7 +412,7 @@ end
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
for nk SafeCloseOpen(n) # nmuladd
nkw = nk + $W
U_ki = vload(spu, (nkw, $(MM{W}(z))))
U_ki = vload(spl, (nkw, $(MM{W}(z))))
Base.Cartesian.@nexprs $R r ->
A11_r = vfnmadd_fast(U_ki, vload(spa, (static(r - 1), nkw)), A11_r)
end
Expand All @@ -444,7 +431,7 @@ end
# We then have column-major multiplies
# take A[(u-1)*W,u*W), [0,W)]
X = VectorizationBase.transpose_vecunroll(VecUnroll($t))
C_u = solve_AL(X, spu, $(Val(UNIT)))
C_u = solve_AL(X, spl, $(Val(UNIT)))
end
push!(q.args, q2)
q3 = if R == Wpad
Expand All @@ -465,7 +452,7 @@ end

@generated function _ldivu_remainder!(
spa,
spu,
spl,
N,
Nr,
::StaticInt{W},
Expand All @@ -486,7 +473,7 @@ end
vlxj = :(xj = $vlxj)
else
vlxj = quote
xj = $div($vlxj, vload(spu, (j, j)))
xj = $div($vlxj, vload(spl, (j, j)))
vstore!(spa, xj, ($z, j))
end
end
Expand All @@ -500,7 +487,7 @@ end
while i > 0
i -= 1
xi = vload(spa, ($z, i))
Uji = vload(spu, (j, i))
Uji = vload(spl, (j, i))
vstore!(spa, $sub(xi, $mul(xj, Uji)), ($z, i))
end
j == 0 && break
Expand All @@ -514,19 +501,17 @@ end
mask = $(getfield(_mask(WS, r), :u) % UInt32)
n = N - Nr
if Nr > 0
@show pointer(spa), pointer(spu), n, Nr
let t = (gesp(spa, ($z, n)), gesp(spu, (n, n))), ft = flatten_to_tup(t)
let t = (gesp(spa, ($z, n)), gesp(spl, (n, n))), ft = flatten_to_tup(t)
BdivL_small_kern!(Nr, mask, $WS, $(Val(UNIT)), typeof(t), ft...)
end
end
# non-U, order first as matmul kern is smaller than optimal
while n != 0
k = N - n
n -= W
@show pointer(spa), pointer(spu), k, n
ldivu_solve_W!(
gesp(spa, ($z, n)),
gesp(spu, (n, n)),
gesp(spl, (n, n)),
k,
$WS,
Val(UNIT),
Expand All @@ -551,25 +536,25 @@ end
if W == 2
quote
$(Expr(:meta, :inline))
spa, spu = reassemble_tup(Args, args)
_ldivu_remainder!(spa, spu, N, Nrr, Nru, $WS, $(Val(UNIT)), $(static(1)))
spa, spl = reassemble_tup(Args, args)
_ldivu_remainder!(spa, spl, N, Nrr, Nru, $WS, $(Val(UNIT)), $(static(1)))
nothing
end
elseif W == 8
s8 = StaticInt(8)
quote
# $(Expr(:meta, :inline))
spa, spu = reassemble_tup(Args, args)
spa, spl = reassemble_tup(Args, args)
if m == M - 1
_ldivu_remainder!(spa, spu, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(1)))
_ldivu_remainder!(spa, spl, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(1)))
else
if m == M - 2
_ldivu_remainder!(spa, spu, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(2)))
_ldivu_remainder!(spa, spl, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(2)))
else
if m == M - 3
_ldivu_remainder!(
spa,
spu,
spl,
N,
Nr,
$s8,
Expand All @@ -580,7 +565,7 @@ end
if m == M - 4
_ldivu_remainder!(
spa,
spu,
spl,
N,
Nr,
$s8,
Expand All @@ -591,7 +576,7 @@ end
if m == M - 5
_ldivu_remainder!(
spa,
spu,
spl,
N,
Nr,
$s8,
Expand All @@ -602,7 +587,7 @@ end
if m == M - 6
_ldivu_remainder!(
spa,
spu,
spl,
N,
Nr,
$s8,
Expand All @@ -612,7 +597,7 @@ end
else
_ldivu_remainder!(
spa,
spu,
spl,
N,
Nr,
$s8,
Expand All @@ -630,9 +615,9 @@ end
else
quote
# $(Expr(:meta, :inline))
spa, spu = reassemble_tup(Args, args)
spa, spl = reassemble_tup(Args, args)
Base.Cartesian.@nif $(W - 1) w -> m == M - w w ->
_ldivu_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), static(w))
_ldivu_remainder!(spa, spl, N, Nr, $WS, $(Val(UNIT)), static(w))
nothing
end
end
Expand All @@ -646,7 +631,7 @@ function _ldivu_L!(
args::Vararg{Any,K}
) where {UNIT,Args,K}
# B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
spa, spu = reassemble_tup(Args, args)
spa, spl = reassemble_tup(Args, args)
T = eltype(spa)
WS = pick_vector_width(T)
W = Int(WS)
Expand All @@ -664,8 +649,7 @@ function _ldivu_L!(
while m < M - WS + 1
n::Int = nstart
if Nrr > 0
let t = (gesp(spa, (z, n)), gesp(spu, (n, n))), ft = flatten_to_tup(t)
@show 0, n
let t = (gesp(spa, (z, n)), gesp(spl, (n, n))), ft = flatten_to_tup(t)
compute && BdivL_small_kern_u!(
Nrr,
StaticInt(1),
Expand All @@ -680,10 +664,9 @@ function _ldivu_L!(
for _ 1:Ndr
k = N - n
n -= W
@show 1, n, k
compute && ldivu_solve_W!(
gesp(spa, (z, n)),
gesp(spu, (n, n)),
gesp(spl, (n, n)),
k,
WS,
Val(UNIT),
Expand All @@ -694,10 +677,9 @@ function _ldivu_L!(
while n != 0
k = N - n
n -= Int(WU)
@show 2, n, k
compute && ldivu_solve_W_u!(
gesp(spa, (z, n)),
gesp(spu, (n, n)),
gesp(spl, (n, n)),
k,
WS,
UF,
Expand All @@ -709,8 +691,7 @@ function _ldivu_L!(
end
# remainder on `m`
if m < M
let tup = (spa, spu), ftup = flatten_to_tup(tup)
@show m, Nrr, M
let tup = (spa, spl), ftup = flatten_to_tup(tup)
compute &&
ldivu_remainder!(M, N, m, Nrr, WS, Val(UNIT), typeof(tup), ftup...)
end
Expand Down
14 changes: 7 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function test_solve(::Type{T}) where {T}
for n 1:maxN
@show n
for m max(1, n - 10):n+10
@show m
# @show m
A = @view AA[17:16+m, 17:16+n]
res = @view RR[17:16+m, 17:16+n]
B = @view BB[17:16+n, 17:16+n]
Expand All @@ -30,10 +30,10 @@ function test_solve(::Type{T}) where {T}
for C in (
UpperTriangular(B),
UnitUpperTriangular(B),
# LowerTriangular(B),
# UnitLowerTriangular(B)
LowerTriangular(B),
UnitLowerTriangular(B)
)
@show typeof(C)
# @show typeof(C)
@test TriangularSolve.rdiv!(res, A, C) * C A
check_box_for_nans(RR, m, n)
@test TriangularSolve.rdiv!(res, A, C, Val(false)) * C A
Expand All @@ -48,12 +48,12 @@ function test_solve(::Type{T}) where {T}
A .= rand.(T)

for C in (
# UpperTriangular(B),
# UnitUpperTriangular(B),
UpperTriangular(B),
UnitUpperTriangular(B),
LowerTriangular(B),
UnitLowerTriangular(B)
)
@show typeof(C)
# @show typeof(C)
@test C * TriangularSolve.ldiv!(res, C, A) A
check_box_for_nans(RR, n, m)
@test C * TriangularSolve.ldiv!(res, C, A, Val(false)) A
Expand Down

0 comments on commit 8cc9737

Please sign in to comment.