Skip to content

Commit

Permalink
Tests pass locally
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 12, 2024
1 parent 2c81058 commit 77a9b0b
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 30 deletions.
68 changes: 44 additions & 24 deletions src/TriangularSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ end
# Each iter:
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
for nk SafeCloseOpen(n) # nmuladd
U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})(nk, n))
U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})((nk, n)))
Base.Cartesian.@nexprs $W c ->
A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c)
end
Expand Down Expand Up @@ -320,7 +320,7 @@ end
end
@generated function ldiv_solve_W!(
spc,
s,
spa,
spu,
n,
::StaticInt{W},
Expand Down Expand Up @@ -354,16 +354,18 @@ end
# Each iter:
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
for nk SafeCloseOpen(n) # nmuladd
U_ki = vload(spu, (nk, $(MM{W}(z))))
U_ki = vload(spu, (nk, $(MM{W})(n)))
Base.Cartesian.@nexprs $W c ->
A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c)
end
# Base.Cartesian.@nexprs $W c -> @show A11_c
# solve AU wants us to transpose
# We then have column-major multiplies
# take A[(u-1)*W,u*W), [0,W)]
X = VectorizationBase.transpose_vecunroll(
VecUnroll(Base.Cartesian.@ntuple $W A11)
)
# @show X
C_u = solve_AU(X, spu, n, $(Val(UNIT)))
vstore!(spc, C_u, $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n)))
end
Expand Down Expand Up @@ -402,13 +404,13 @@ end
A11 =
getfield(vload(spa, $(Unroll{1,1,R,2,W,zero(UInt),1})(($z, n))), :data)
# The `W` rows
Base.Cartesian.@nexprs $W r -> A11_r = getfield(A11, r)
Base.Cartesian.@nexprs $R r -> A11_r = getfield(A11, r)
# compute
# A_{j,i} - \sum_{k=1}^{i-1}U_{k,i}C_{j,k})
# Each iter:
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
for nk SafeCloseOpen(n) # nmuladd
U_ki = vload(spu, (nk, $(MM{W}(z))))
U_ki = vload(spu, (nk, $(MM{W})(n)))
Base.Cartesian.@nexprs $R r ->
A11_r = vfnmadd_fast(U_ki, vload(spc, (static(r - 1), nk)), A11_r)
end
Expand All @@ -432,13 +434,13 @@ end
push!(q.args, q2)
q3 = if R == Wpad
quote
i = $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n))
i = $(Unroll{2,1,W,1,Wpad,zero(UInt),1})(($z, n))
vstore!(spc, C_u, i)
end
else
quote
mask = VectorizationBase.mask($WS, $(static(R)))
i = $(Unroll{2,1,W,1,W,(-1 % UInt),1})(($z, n))
i = $(Unroll{2,1,W,1,Wpad,(-1 % UInt),1})(($z, n))
vstore!(spc, C_u, i, mask)
end
end
Expand Down Expand Up @@ -890,7 +892,8 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})(
Core.ifelse(block == Nblock - 1, Mrem, mtb),
N,
Val{UNIT}(),
static(XC)static(XA)
static(XC),
static(XA)
)
end
end
Expand Down Expand Up @@ -1000,29 +1003,35 @@ end
N,
m,
Nr,
::Val{W},
::StaticInt{W},
::Val{UNIT},
::Val{r}
::StaticInt{r}
) where {W,UNIT,r}
r <= 0 && throw("Remainder of `<= 0` shouldn't be called, but had $r.")
r >= W && throw("Reaminderof `>= $W` shouldn't be called, but had $r.")
if r == 1
vlxj = :(vload(spc, (M - 1, j)))
if !UNIT
vlxj = :($vlxj / vload(spu, (j, j)))
z = static(0)
vlxj = :(vload(spc, ($z, j)))
if UNIT
vlxj = :(xj = $vlxj)
else
vlxj = quote
xj = $vlxj / vload(spu, (j, j))
vstore!(spc, xj, ($z, j))
end
end
quote
if pointer(spc) != pointer(spa)
for n = 0:N-1
vstore!(spc, vload(spa, (M - 1, n)), (M - 1, n))
vstore!(spc, vload(spa, ($z, n)), ($z, n))
end
end
for j = 0:N-1
xj = $vlxj
$vlxj
for i = (j+1):N-1
xi = vload(spc, (M - 1, i))
xi = vload(spc, ($z, i))
Uji = vload(spu, (j, i))
vstore!(spc, xi - xj * Uji, (M - 1, i))
vstore!(spc, xi - xj * Uji, ($z, i))
end
end
end
Expand All @@ -1033,14 +1042,14 @@ end
n = Nr # non factor of W remainder
if n > 0
mask = $(VectorizationBase.mask(WS, r))
BdivU_small_kern!(spc, nothing, spa, spu, n, mask, Val(UNIT))
BdivU_small_kern!(spc, nothing, spa, spu, n, mask, $(Val(UNIT)))
end
# while n < N - $(W * U - 1)
# ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(w))
# ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(r))
# n += $(W * U)
# end
while n != N
ldiv_solve_W!(spc, spa, spu, n, $WS, Val(UNIT), Val(w))
ldiv_solve_W!(spc, spa, spu, n, $WS, $(Val(UNIT)), $(StaticInt(r)))
n += $W
end
end
Expand All @@ -1054,16 +1063,26 @@ end
N,
m,
Nr,
::Val{W},
::StaticInt{W},
# ::Val{U},
::Val{UNIT}
) where {W,UNIT}
WS = static(W)
# US = static(U)
quote
$(Expr(:meta, :inline))
Base.Cartesian.@nif $W w -> m == M - w w ->
ldiv_remainder!(spc, spa, spu, M, N, m, Nr, $WS, $(Val(UNIT)), Val(w))
Base.Cartesian.@nif $W w -> m == M - w w -> ldiv_remainder!(
spc,
spa,
spu,
M,
N,
m,
Nr,
$WS,
$(Val(UNIT)),
StaticInt(w)
)
end
end

Expand All @@ -1087,11 +1106,12 @@ function rdiv_U!(
MU = UF > 1 ? M : 0
Nd, Nr = VectorizationBase.vdivrem(N, WS)
m = 0
# @show M,N
# m, no remainder
while m < M - WS + 1
n = Nr # non factor of W remainder
if n > 0
BdivU_small_kern_u!(spc, nothing, spa, spu, n, Val(1), Val(UNIT))
BdivU_small_kern_u!(spc, nothing, spa, spu, n, StaticInt(1), Val(UNIT))
end
while n < N - (WU - 1)
ldiv_solve_W_u!(spc, spa, spu, n, WS, UF, Val(UNIT))
Expand Down
41 changes: 35 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
using TriangularSolve, LinearAlgebra
using Test

function check_box_for_nans(A, M, N)
# blocks start at 17, and are MxN
@test all(isnan, @view(A[1:16, :]))
@test all(isnan, @view(A[17+M:end, :]))
@test all(isnan, @view(A[17:16+M, 1:16]))
@test all(isnan, @view(A[17:16+M, 17+N:end]))
end

function test_solve(::Type{T}) where {T}
for n 1:(T === Float32 ? 100 : 200)
maxN = (T === Float32 ? 100 : 200)
maxM = maxN + 10
AA = fill(T(NaN), maxM + 32, maxM + 32)
RR = fill(T(NaN), maxM + 32, maxM + 32)
BB = fill(T(NaN), maxN + 32, maxN + 32)
for n 1:maxN
@show n
for m max(1, n - 10):n+10
A = rand(T, m, n)
res = similar(A)
B = rand(T, n, n) + I
A = @view AA[17:16+m, 17:16+n]
res = @view RR[17:16+m, 17:16+n]
B = @view BB[17:16+n, 17:16+n]

A .= rand.(T)
B .= rand.(T)
@view(B[diagind(B)]) .+= one(T)

@test TriangularSolve.rdiv!(res, A, UpperTriangular(B)) *
UpperTriangular(B) A
@test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B)) *
Expand All @@ -16,8 +34,15 @@ function test_solve(::Type{T}) where {T}
UpperTriangular(B) A
@test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B), Val(false)) *
UnitUpperTriangular(B) A
A = rand(T, n, m)
res = similar(A)

check_box_for_nans(RR, m, n)
res .= NaN
A .= NaN

A = @view AA[17:16+n, 17:16+m]
res = @view RR[17:16+n, 17:16+m]
A .= rand.(T)

@test LowerTriangular(B) *
TriangularSolve.ldiv!(res, LowerTriangular(B), A) A
@test UnitLowerTriangular(B) *
Expand All @@ -27,6 +52,10 @@ function test_solve(::Type{T}) where {T}
@test UnitLowerTriangular(B) *
TriangularSolve.ldiv!(res, UnitLowerTriangular(B), A, Val(false))
A
check_box_for_nans(RR, n, m)
res .= NaN
A .= NaN
B .= NaN
end
end
end
Expand Down

0 comments on commit 77a9b0b

Please sign in to comment.