From 77a9b0bee7d6513043aee1e1677fc776c1b6b883 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Fri, 12 Apr 2024 14:54:22 -0400 Subject: [PATCH] Tests pass locally --- src/TriangularSolve.jl | 68 +++++++++++++++++++++++++++--------------- test/runtests.jl | 41 +++++++++++++++++++++---- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/src/TriangularSolve.jl b/src/TriangularSolve.jl index b3072fc..c25e5c9 100644 --- a/src/TriangularSolve.jl +++ b/src/TriangularSolve.jl @@ -267,7 +267,7 @@ end # Each iter: # A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)} for nk ∈ SafeCloseOpen(n) # nmuladd - U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})(nk, n)) + U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})((nk, n))) Base.Cartesian.@nexprs $W c -> A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c) end @@ -320,7 +320,7 @@ end end @generated function ldiv_solve_W!( spc, - s, + spa, spu, n, ::StaticInt{W}, @@ -354,16 +354,18 @@ end # Each iter: # A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)} for nk ∈ SafeCloseOpen(n) # nmuladd - U_ki = vload(spu, (nk, $(MM{W}(z)))) + U_ki = vload(spu, (nk, $(MM{W})(n))) Base.Cartesian.@nexprs $W c -> A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c) end + # Base.Cartesian.@nexprs $W c -> @show A11_c # solve AU wants us to transpose # We then have column-major multiplies # take A[(u-1)*W,u*W), [0,W)] X = VectorizationBase.transpose_vecunroll( VecUnroll(Base.Cartesian.@ntuple $W A11) ) + # @show X C_u = solve_AU(X, spu, n, $(Val(UNIT))) vstore!(spc, C_u, $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n))) end @@ -402,13 +404,13 @@ end A11 = getfield(vload(spa, $(Unroll{1,1,R,2,W,zero(UInt),1})(($z, n))), :data) # The `W` rows - Base.Cartesian.@nexprs $W r -> A11_r = getfield(A11, r) + Base.Cartesian.@nexprs $R r -> A11_r = getfield(A11, r) # compute # A_{j,i} - \sum_{k=1}^{i-1}U_{k,i}C_{j,k}) # Each iter: # A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)} for nk ∈ SafeCloseOpen(n) # nmuladd - U_ki = vload(spu, (nk, $(MM{W}(z)))) + U_ki = vload(spu, (nk, $(MM{W})(n))) Base.Cartesian.@nexprs $R r -> A11_r = vfnmadd_fast(U_ki, vload(spc, (static(r - 1), nk)), A11_r) end @@ -432,13 +434,13 @@ end push!(q.args, q2) q3 = if R == Wpad quote - i = $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n)) + i = $(Unroll{2,1,W,1,Wpad,zero(UInt),1})(($z, n)) vstore!(spc, C_u, i) end else quote mask = VectorizationBase.mask($WS, $(static(R))) - i = $(Unroll{2,1,W,1,W,(-1 % UInt),1})(($z, n)) + i = $(Unroll{2,1,W,1,Wpad,(-1 % UInt),1})(($z, n)) vstore!(spc, C_u, i, mask) end end @@ -890,7 +892,8 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})( Core.ifelse(block == Nblock - 1, Mrem, mtb), N, Val{UNIT}(), - static(XC)static(XA) + static(XC), + static(XA) ) end end @@ -1000,29 +1003,35 @@ end N, m, Nr, - ::Val{W}, + ::StaticInt{W}, ::Val{UNIT}, - ::Val{r} + ::StaticInt{r} ) where {W,UNIT,r} r <= 0 && throw("Remainder of `<= 0` shouldn't be called, but had $r.") r >= W && throw("Reaminderof `>= $W` shouldn't be called, but had $r.") if r == 1 - vlxj = :(vload(spc, (M - 1, j))) - if !UNIT - vlxj = :($vlxj / vload(spu, (j, j))) + z = static(0) + vlxj = :(vload(spc, ($z, j))) + if UNIT + vlxj = :(xj = $vlxj) + else + vlxj = quote + xj = $vlxj / vload(spu, (j, j)) + vstore!(spc, xj, ($z, j)) + end end quote if pointer(spc) != pointer(spa) for n = 0:N-1 - vstore!(spc, vload(spa, (M - 1, n)), (M - 1, n)) + vstore!(spc, vload(spa, ($z, n)), ($z, n)) end end for j = 0:N-1 - xj = $vlxj + $vlxj for i = (j+1):N-1 - xi = vload(spc, (M - 1, i)) + xi = vload(spc, ($z, i)) Uji = vload(spu, (j, i)) - vstore!(spc, xi - xj * Uji, (M - 1, i)) + vstore!(spc, xi - xj * Uji, ($z, i)) end end end @@ -1033,14 +1042,14 @@ end n = Nr # non factor of W remainder if n > 0 mask = $(VectorizationBase.mask(WS, r)) - BdivU_small_kern!(spc, nothing, spa, spu, n, mask, Val(UNIT)) + BdivU_small_kern!(spc, nothing, spa, spu, n, mask, $(Val(UNIT))) end # while n < N - $(W * U - 1) - # ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(w)) + # ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(r)) # n += $(W * U) # end while n != N - ldiv_solve_W!(spc, spa, spu, n, $WS, Val(UNIT), Val(w)) + ldiv_solve_W!(spc, spa, spu, n, $WS, $(Val(UNIT)), $(StaticInt(r))) n += $W end end @@ -1054,7 +1063,7 @@ end N, m, Nr, - ::Val{W}, + ::StaticInt{W}, # ::Val{U}, ::Val{UNIT} ) where {W,UNIT} @@ -1062,8 +1071,18 @@ end # US = static(U) quote $(Expr(:meta, :inline)) - Base.Cartesian.@nif $W w -> m == M - w w -> - ldiv_remainder!(spc, spa, spu, M, N, m, Nr, $WS, $(Val(UNIT)), Val(w)) + Base.Cartesian.@nif $W w -> m == M - w w -> ldiv_remainder!( + spc, + spa, + spu, + M, + N, + m, + Nr, + $WS, + $(Val(UNIT)), + StaticInt(w) + ) end end @@ -1087,11 +1106,12 @@ function rdiv_U!( MU = UF > 1 ? M : 0 Nd, Nr = VectorizationBase.vdivrem(N, WS) m = 0 + # @show M,N # m, no remainder while m < M - WS + 1 n = Nr # non factor of W remainder if n > 0 - BdivU_small_kern_u!(spc, nothing, spa, spu, n, Val(1), Val(UNIT)) + BdivU_small_kern_u!(spc, nothing, spa, spu, n, StaticInt(1), Val(UNIT)) end while n < N - (WU - 1) ldiv_solve_W_u!(spc, spa, spu, n, WS, UF, Val(UNIT)) diff --git a/test/runtests.jl b/test/runtests.jl index e57a803..f8e1b7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,31 @@ using TriangularSolve, LinearAlgebra using Test +function check_box_for_nans(A, M, N) + # blocks start at 17, and are MxN + @test all(isnan, @view(A[1:16, :])) + @test all(isnan, @view(A[17+M:end, :])) + @test all(isnan, @view(A[17:16+M, 1:16])) + @test all(isnan, @view(A[17:16+M, 17+N:end])) +end + function test_solve(::Type{T}) where {T} - for n ∈ 1:(T === Float32 ? 100 : 200) + maxN = (T === Float32 ? 100 : 200) + maxM = maxN + 10 + AA = fill(T(NaN), maxM + 32, maxM + 32) + RR = fill(T(NaN), maxM + 32, maxM + 32) + BB = fill(T(NaN), maxN + 32, maxN + 32) + for n ∈ 1:maxN @show n for m ∈ max(1, n - 10):n+10 - A = rand(T, m, n) - res = similar(A) - B = rand(T, n, n) + I + A = @view AA[17:16+m, 17:16+n] + res = @view RR[17:16+m, 17:16+n] + B = @view BB[17:16+n, 17:16+n] + + A .= rand.(T) + B .= rand.(T) + @view(B[diagind(B)]) .+= one(T) + @test TriangularSolve.rdiv!(res, A, UpperTriangular(B)) * UpperTriangular(B) ≈ A @test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B)) * @@ -16,8 +34,15 @@ function test_solve(::Type{T}) where {T} UpperTriangular(B) ≈ A @test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B), Val(false)) * UnitUpperTriangular(B) ≈ A - A = rand(T, n, m) - res = similar(A) + + check_box_for_nans(RR, m, n) + res .= NaN + A .= NaN + + A = @view AA[17:16+n, 17:16+m] + res = @view RR[17:16+n, 17:16+m] + A .= rand.(T) + @test LowerTriangular(B) * TriangularSolve.ldiv!(res, LowerTriangular(B), A) ≈ A @test UnitLowerTriangular(B) * @@ -27,6 +52,10 @@ function test_solve(::Type{T}) where {T} @test UnitLowerTriangular(B) * TriangularSolve.ldiv!(res, UnitLowerTriangular(B), A, Val(false)) ≈ A + check_box_for_nans(RR, n, m) + res .= NaN + A .= NaN + B .= NaN end end end