Skip to content

Commit

Permalink
Merge pull request #25 from JuliaSIMD/liballocsbutiamgivingup
Browse files Browse the repository at this point in the history
trying to cut down heisen allocations to no avail
  • Loading branch information
chriselrod authored Oct 14, 2022
2 parents 7fa5823 + f2ce9ef commit 8e2d7e4
Showing 1 changed file with 23 additions and 159 deletions.
182 changes: 23 additions & 159 deletions src/TriangularSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,89 +31,13 @@ using Polyester
end
end

# @generated function nmuladd(A::VecUnroll{Nm1},B::AbstractStridedPointer,C::VecUnroll{Nm1}) where {Nm1}
# N = Nm1 + 1
# quote
# $(Expr(:meta,:inline))
# Ad = VectorizationBase.data(A);
# Cd = VectorizationBase.data(C);
# bp = stridedpointer(B)
# Base.Cartesian.@nexprs $N n -> C_n = Cd[n]
# Base.Cartesian.@nexprs $N k -> begin
# A_k = Ad[k]
# Base.Cartesian.@nexprs $N n -> begin
# C_n = Base.FastMath.sub_fast(C_n, Base.FastMath.mul_fast(A_k, vload(B, (k-1,n-1))))
# end
# end
# VecUnroll(Base.Cartesian.@ntuple $N C)
# end
# end

# @inline function solve_Wx3W(A11::V, A12::V, A13::V, U::AbstractMatrix, ::StaticInt{W}) where {V<:VecUnroll,W}
# WS = StaticInt{W}()

# U11 = view(U,StaticInt(1):WS,StaticInt(1):WS)
# A11 = solve_AU(A11, U11)

# U12 = view(U,StaticInt(1):WS, StaticInt(1)+WS:WS*StaticInt(2))
# A12 = nmuladd(A11, U12, A12)
# U22 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS:WS*StaticInt(2))
# A12 = solve_AU(A12, U22)

# U13 = view(U,StaticInt(1):WS, StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
# A13 = nmuladd(A11, U13, A13)
# U23 = view(U,StaticInt(1)+WS:WS*StaticInt(2),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
# A13 = nmuladd(A12, U23, A13)
# U33 = view(U,StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3),StaticInt(1)+WS*StaticInt(2):WS*StaticInt(3))
# A13 = solve_AU(A13, U33)

# return A11, A12, A13
# end

# @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset) where {T}
# WS = VectorizationBase.pick_vector_width(T)
# W = Int(WS)
# A11 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
# A12 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
# A13 = vload(bp, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))

# A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)

# vstore!(ap, A11, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset)))
# vstore!(ap, A12, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS)))
# vstore!(ap, A13, Unroll{2,1,W,1,W,zero(UInt),1}((rowoffset,coloffset+WS+WS)))
# end
# @inline function solve_Wx3W!(ap::AbstractStridedPointer{T}, bp::AbstractStridedPointer{T}, U, rowoffset, coloffset, m::VectorizationBase.AbstractMask) where {T}
# WS = VectorizationBase.pick_vector_width(T)
# W = Int(WS)
# A11 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
# A12 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
# A13 = vload(bp, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)

# A11, A12, A13 = solve_Wx3W(A11, A12, A13, U, WS)

# vstore!(ap, A11, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset)), m)
# vstore!(ap, A12, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS)), m)
# vstore!(ap, A13, Unroll{2,1,W,1,W,(-1%UInt),1}((rowoffset,coloffset+WS+WS)), m)
# end

# solve_3Wx3W!(A,B,U::UpperTriangular) = solve_3Wx3W!(A,B,parent(U))
# function solve_3Wx3W!(A::AbstractMatrix{T},B,U) where {T}
# W = VectorizationBase.pick_vector_width(T)
# ap = stridedpointer(A);
# bp = stridedpointer(B);
# solve_Wx3W!(ap, bp, U, StaticInt(1), StaticInt(1))
# solve_Wx3W!(ap, bp, U, StaticInt(1) + W, StaticInt(1))
# solve_Wx3W!(ap, bp, U, StaticInt(1) + W + W, StaticInt(1))
# end

@inline maybestore!(p, v, i) = vstore!(p, v, i)
@inline maybestore!(::Nothing, v, i) = nothing

@inline maybestore!(p, v, i, m) = vstore!(p, v, i, m)
@inline maybestore!(::Nothing, v, i, m) = nothing

@inline function store_small_kern!(spa, sp, v, spu, i, n, mask, ::Val{true})
@inline function store_small_kern!(spa, sp, v, _, i, n, mask, ::Val{true})
vstore!(spa, v, i, mask)
vstore!(sp, v, i, mask)
end
Expand Down Expand Up @@ -160,46 +84,6 @@ end
store_small_kern!(spa, sp, Amn, spu, Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0),n)), n, Val{UNIT}())
end
end
# function BdivU_small!(A::AbstractMatrix{T}, B::AbstractMatrix{T}, U::AbstractMatrix{T}) where {T}
# W = VectorizationBase.pick_vector_width(T)
# M, N = size(A)
# m = 0
# spa = stridedpointer(A)
# spb = stridedpointer(B)
# spu = stridedpointer(U)
# while m < M
# ml = m+1
# mu = m+W
# maskiter = mu > M
# mask = maskiter ? VectorizationBase.mask(W, M) : VectorizationBase.max_mask(W)
# for n ∈ 1:N
# Amn = vload(spb, (MM(W, ml),n), mask)
# for k ∈ 1:n-1
# Amn = vfnmadd_fast(vload(spa, (MM(W, ml),k), mask), vload(spu, (k,n)), Amn)
# end
# vstore!(spa, Amn / vload(spu, (n,n)), (MM(W, ml),n), mask)
# end
# m = mu
# end
# # @inbounds @fastmath for m ∈ 1:M
# # for n ∈ 1:N
# # Amn = B[m,n]
# # for k ∈ 1:n-1
# # Amn -= A[m,k]*U[k,n]
# # end
# # A[m,n] = Amn / U[n,n]
# # end
# # end
# end
# function nmuladd!(C,A,B,D)
# @turbo for n ∈ axes(C,2), m ∈ axes(C,1)
# Cmn = D[m,n]
# for k ∈ axes(B,1)
# Cmn -= A[m,k]*B[k,n]
# end
# C[m,n] = Cmn
# end
# end

@generated function rdiv_solve_W_u!(spc, spb, spa, spu, n, ::StaticInt{W}, ::StaticInt{U}, ::Val{UNIT}) where {W, U, UNIT}
quote
Expand Down Expand Up @@ -286,7 +170,7 @@ const LDIVBUFFERS = Vector{UInt8}[]
buff = LDIVBUFFERS[Threads.threadid()]
RSUF = StaticInt{UF}()*VectorizationBase.register_size()
L = RSUF*N
L > length(buff) && resize!(buff, L)
L > length(buff) && resize!(buff, L%UInt)
ptr = Base.unsafe_convert(Ptr{T}, buff)
si = StrideIndex{2,(1,2),1}((VectorizationBase.static_sizeof(T), RSUF), (StaticInt(0),StaticInt(0)))
stridedpointer(ptr, si, StaticInt{0}())
Expand Down Expand Up @@ -412,24 +296,14 @@ function rdiv_block_N!(
N_temp = Core.ifelse(repeat, B_normalized, N)
while true
# println("Solve with N_temp = $N_temp and n = $n")
rdiv_U!(spc, spa_rdiv, gesp(spu, (n,StaticInt{0}())), M, N_temp, StaticInt{X}(), Val(UNIT))
rdiv_U!(spc, spa_rdiv, gesp(spu, (n,StaticInt{0}())), M, N_temp, StaticInt{X}(), Val{UNIT}())
repeat || break
spa = gesp(spa, (StaticInt(0), B_normalized))
spc = gesp(spc, (StaticInt(0), B_normalized))
spu = gesp(spu, (StaticInt(0), B_normalized))
nnext = n + B_normalized
# N_temp =
n += B_normalized
repeat = n + B_normalized < N
N_temp = repeat ? N_temp : N - n
# N_temp = min(n + B_normalized, N) - n
# println("nmuladd with N_temp = $N_temp and n = $n")
# mul!(
# copyto!(view(C, :, n+1:n+N_temp), view(A, :, n+1:n+N_temp)),
# view(C, :, 1:n),
# view(U, 1:n, n+1:n+N_temp),
# -1.0, 1.0
# )
nmuladd!(spc_base, spa, spu, M, n, N_temp)
spa_rdiv = spc
end
Expand All @@ -439,15 +313,14 @@ function rdiv_block_MandN!(
) where {T,UNIT,X}
B = block_size(Val(T))
W = VectorizationBase.pick_vector_width(T)
B_normalized = VectorizationBase.vcld(N, VectorizationBase.vcld(N, B)*W)*W
WUF = W*unroll_factor(W)
B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B)*WUF)*WUF
m = 0
while m < M
mu = m + B_m
Mtemp = min(M, mu) - m
rdiv_block_N!(
spc, spa, spu, Mtemp, N, Val(UNIT), StaticInt{X}(),
spc, spa, spu, Mtemp, N, Val{UNIT}(), StaticInt{X}(),
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B)*W)*W
)
spa = gesp(spa, (B_m, StaticInt{0}()))
Expand All @@ -458,42 +331,33 @@ function rdiv_block_MandN!(
end
function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
W = VectorizationBase.pick_vector_width(T)
WUF = W * unroll_factor(W)
nb = clamp(VectorizationBase.vdiv(M * N, StaticInt{256}() * W), 1, nthreads)
min(M, VectorizationBase.vcld(M, nb*W)*W)
end

struct RDivBlockMandNv2{UNIT,X} end
function (f::RDivBlockMandNv2{UNIT,X})(allargs, blockstart, blockstop) where {UNIT,X}
spc, spa, spu, N, Mrem, Nblock, mtb = allargs
for block = blockstart-1:blockstop-1
rdiv_block_MandN!(
gesp(spc, (mtb*block, StaticInt{0}())),
gesp(spa, (mtb*block, StaticInt{0}())),
spu, Core.ifelse(block == Nblock-1, Mrem, mtb), N, Val{UNIT}(), static(X)
)
end
end


function multithread_rdiv!(
spc::AbstractStridedPointer{T}, spa, spu, M, N, mtb, ::Val{UNIT}, ::StaticInt{X}
) where {X,T,UNIT}
mtb = 8
spc::AbstractStridedPointer{TC}, spa::AbstractStridedPointer{TA}, spu::AbstractStridedPointer{TU}, M::Int, N::Int, mtb::Int, ::Val{UNIT}, ::StaticInt{X}
) where {X,UNIT,TC,TA,TU}
# Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X));
(Md, Mr) = VectorizationBase.vdivrem(M, mtb)
Nblock = Md + (Mr 0)
Mrem = Core.ifelse(Mr 0, Mr, mtb)
# @show mtb, Nblock, Mrem, Md, Mr
# return
let Md = Md, Mr = Mr, Nblock = Md + (Mr 0), Mrem = Core.ifelse(Mr 0, Mr, mtb), VUNIT = Val{UNIT}(), StaticX = StaticInt{X}()
@batch for block in CloseOpen(Nblock)
# for block in CloseOpen(Nblock)
# let block = 0
rdiv_block_MandN!(
# rdiv_block_N!(
gesp(spc, (mtb*block, StaticInt{0}())),
gesp(spa, (mtb*block, StaticInt{0}())),
spu, Core.ifelse(block == Nblock-1, Mrem, mtb), N, VUNIT, StaticX
# spu, M, N, Val{UNIT}(), StaticInt{X}()
)
end
end
f = RDivBlockMandNv2{UNIT,X}()
batch(f, (Nblock,min(Nblock,Threads.nthreads())), spc, spa, spu, N, Mrem, Nblock, mtb)
nothing
# nlaunch = Md - (Mr == 0)
# threads, torelease = Polyester.request_threads(Base.Threads.threadid(), nlaunch)
# nthread = length(threads)
# if (nthread % Int32) ≤ zero(Int32)
# return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT), StaticInt{X}())
# end
# nbatch = nthread + one(nthread)

end

# We're using `W x W` blocks, consuming `W` registers
Expand Down Expand Up @@ -521,7 +385,7 @@ function rdiv_U!(spc::AbstractStridedPointer{T}, spa::AbstractStridedPointer, sp
if n > 0
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT))
end
for i 1:Nd
for _ 1:Nd
rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT))
n += W
end
Expand Down

0 comments on commit 8e2d7e4

Please sign in to comment.