diff --git a/README.md b/README.md index 8ce0617..e1adddb 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,12 @@ Inspired by the [LatticeQCD.jl](https://github.com/akio-tomiya/LatticeQCD.jl/tre - [x] Improved Topological charge definitions (clover, rectangle clover-improved) - [x] Wilson(-Clover) fermions - [x] Staggered fermions -- [ ] Even-odd preconditioner for Wilson(-Clover) +- [x] Even-odd preconditioner for Wilson(-Clover) - [x] Even-odd preconditioner for Staggered - [ ] Mass-splitting preconditioner / Hasenbusch trick - [x] RHMC to simulate odd number of flavours -- [ ] Support for CUDA and ROCm backends -- [ ] Multi-node parallelism using MPI +- [ ] Full support for CUDA and ROCm backends +- [x] Multi-node parallelism using MPI ## Installation: First make sure you have Julia version 1.9.x or 1.10.x installed. You can use [juliaup](https://github.com/JuliaLang/juliaup) for that or just install the release from the [Julia website](https://julialang.org/downloads/). diff --git a/src/diracoperators/DiracOperators.jl b/src/diracoperators/DiracOperators.jl index 312ba81..1df051e 100644 --- a/src/diracoperators/DiracOperators.jl +++ b/src/diracoperators/DiracOperators.jl @@ -25,8 +25,8 @@ import KernelAbstractions as KA import ..Fields: AbstractField, FieldTopology, Gaugefield, Paulifield, Spinorfield import ..Fields: SpinorfieldEO, Tensorfield import ..Fields: check_dims, clear!, clover_square, dims, even_odd, gaussian_pseudofermions! -import ..Fields: Clover, Checkerboard2, Sequential, @latmap, @latsum, set_source!, volume -import ..Fields: fieldstrength_eachsite!, num_colors, num_dirac +import ..Fields: @latmap, @latsum, Clover, Checkerboard2, Sequential, set_source!, volume +import ..Fields: @groupreduce, fieldstrength_eachsite!, num_colors, num_dirac import ..Fields: PeriodicBC, AntiPeriodicBC, apply_bc, create_bc, distributed_reduce abstract type AbstractDiracOperator end @@ -81,6 +81,7 @@ include("wilson_eo.jl") include("gpu_kernels/staggered.jl") include("gpu_kernels/staggered_eo.jl") include("gpu_kernels/wilson.jl") +include("gpu_kernels/wilson_eo.jl") include("arnoldi.jl") """ diff --git a/src/diracoperators/gpu_kernels/wilson_eo.jl b/src/diracoperators/gpu_kernels/wilson_eo.jl new file mode 100644 index 0000000..fabb4e3 --- /dev/null +++ b/src/diracoperators/gpu_kernels/wilson_eo.jl @@ -0,0 +1,159 @@ +function mul_oe!( + ψ_eo::TF, U::Gaugefield{B,T}, ϕ_eo::TF, bc, into_odd, ::Val{dagg}; fac=1 +) where {B,T,TF<:WilsonEOPreSpinorfield{B,T},dagg} + check_dims(ψ_eo, ϕ_eo, U) + ψ = ψ_eo.parent + ϕ = ϕ_eo.parent + fdims = dims(ψ) + NV = ψ.NV + + @latmap( + Sequential(), Val(1), wilson_eo_kernel!, ψ, U, ϕ, bc, fdims, NV, fac, T, Val(dagg) + ) + update_halo!(ψ_eo) + return nothing +end + +@kernel function wilson_eo_kernel!( + ψ, @Const(U), @Const(ϕ), bc, fdims, NV, fac, ::Type{T}, ::Val{dagg} +) where {T,dagg} + site = @index(Global, Cartesian) + _site = if into_odd + eo_site(site, fdims..., NV) + else + eo_site_switch(site, fdims..., NV) + end + + if iseven(site) + @inbounds ψ[_site] = fac * wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg)) + end +end + +function mul_eo!( + ψ_eo::TF, U::Gaugefield{B,T}, ϕ_eo::TF, bc, into_odd, ::Val{dagg}; fac=1 +) where {B,T,TF<:WilsonEOPreSpinorfield{B,T},dagg} + check_dims(ψ_eo, ϕ_eo, U) + ψ = ψ_eo.parent + ϕ = ϕ_eo.parent + fdims = dims(ψ) + NV = ψ.NV + + @latmap( + Sequential(), Val(1), wilson_oe_kernel!, ψ, U, ϕ, bc, fdims, NV, fac, T, Val(dagg) + ) + update_halo!(ψ_eo) + return nothing +end + +@kernel function wilson_oe_kernel!( + ψ, @Const(U), @Const(ϕ), bc, fdims, NV, fac, ::Type{T}, ::Val{dagg} +) where {T,dagg} + site = @index(Global, Cartesian) + _site = if into_odd + eo_site_switch(site, fdims..., NV) + else + eo_site(site, fdims..., NV) + end + + if iseven(site) + @inbounds ψ[_site] = fac * wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg)) + end +end + +function calc_diag!( + D_diag::TW, D_oo_inv::TW, ::Nothing, U::Gaugefield{B,T}, mass +) where {B,T,M,TW<:Paulifield{B,T,M,false}} + check_dims(D_diag, D_oo_inv, U) + mass_term = Complex{T}(4 + mass) + fdims = dims(U) + NV = U.NV + + @latmap(Sequential(), Val(1), calc_diag_kernel!, D_diag, D_oo_inv, mass_term, fdims, NV) + return nothing +end + +@kernel function calc_diag_kernel!( + D_diag, D_oo_inv, mass_term, fdims, NV, ::Type{T} +) where {T} + site = @index(Global, Cartesian) + calc_diag_kernel!(D_diag, D_oo_inv, mass_term, site, fdims, NV, T) +end + +function calc_diag!( + D_diag::TW, D_oo_inv::TW, Fμν::Tensorfield{B,T,M}, U::Gaugefield{B,T,M}, mass +) where {B,T,M,TW<:Paulifield{B,T,M,true}} + check_dims(D_diag, D_oo_inv, U) + mass_term = Complex{T}(4 + mass) + fdims = dims(U) + NV = U.NV + fac = Complex{T}(D_diag.csw / 2) + + @latmap( + Sequential(), Val(1), calc_diag_kernel!, D_diag, D_oo_inv, mass_term, fdims, NV, fac, T + ) + return nothing +end + +@kernel function calc_diag_kernel!( + D_diag, D_oo_inv, @Const(Fμν), mass_term, fdims, NV, fac, ::Type{T} +) where {T} + site = @index(Global, Cartesian) + calc_diag_kernel!(D_diag, D_oo_inv, Fμν, mass_term, site, fdims, NV, fac, T) +end + +function mul_oo_inv!( + ϕ_eo::WilsonEOPreSpinorfield{B,T,M}, D_oo_inv::Paulifield{B,T,M} +) where {B,T,M} + check_dims(ϕ_eo, D_oo_inv) + ϕ = ϕ_eo.parent + fdims = dims(U) + NV = U.NV + + @latmap(EvenSites(), Val(1), mul_oo_inv_kernel!, ϕ, D_oo_inv, fdims, NV) + return nothing +end + +@kernel function mul_oo_inv_kernel!(ϕ, D_oo_inv, fdims, NV) + _site = @index(Global, Cartesian) + o_site = switch_sides(_site, fdims..., NV) + ϕ[o_site] = cmvmul_block(D_oo_inv[_site], ϕ[o_site]) +end + +function axmy!( + D_diag::Paulifield{B,T,M}, ψ_eo::TF, ϕ_eo::TF +) where {B,T,M,TF<:WilsonEOPreSpinorfield{B,T,M}} # even on even is the default + check_dims(ϕ_eo, ψ_eo) + ϕ = ϕ_eo.parent + ψ = ψ_eo.parent + + @latmap(EvenSites(), Val(1), axmy_kernel!, D_diag, ψ, ϕ) + return nothing +end + +@kernel function axmy_kernel!(@Const(D_diag), ψ, ϕ) + _site = @index(Global, Cartesian) + ϕ[_site] = cmvmul_block(D_diag[_site], ψ[_site]) - ϕ[_site] +end + +function trlog(D_diag::Paulifield{B,T,M,true}, ::Any) where {B,T,M} # With clover term + fdims = dims(D_diag) + NV = D_diag.NV + return @latsum(EvenSites(), Val(1), Float64, trlog_kernel, D_diag, fdims, NV) +end + +@kernel function trlog_kernel(out, @Const(D_diag), fdims, NV) + bi = @index(Group, Linear) + _site = @index(Global, Cartesian) + o_site = switch_sides(_site, fdims..., NV) + + resₙ = 0.0 + p = D_diag[o_site] + resₙ += log(real(det(p.upper)) * real(det(p.lower))) + + out_group = @groupreduce(+, resₙ, 0.0) + + ti = @index(Local) + if ti == 1 + @inbounds out[bi] = out_group + end +end diff --git a/src/diracoperators/wilson_eo.jl b/src/diracoperators/wilson_eo.jl index b807db2..07d97e5 100644 --- a/src/diracoperators/wilson_eo.jl +++ b/src/diracoperators/wilson_eo.jl @@ -255,8 +255,8 @@ end # The Gaugefields module into CG.jl, which also allows us to use the solvers for # for arbitrary arrays, not just fermion fields and dirac operators (good for testing) function LinearAlgebra.mul!( - ψ_eo::TF, D::WilsonEOPreDiracOperator{CPU,T,C,TF,TG,TX,TO}, ϕ_eo::TF -) where {T,C,TF,TG,TX,TO} + ψ_eo::TF, D::WilsonEOPreDiracOperator{B,T,C,TF,TG,TX,TO}, ϕ_eo::TF +) where {B,T,C,TF,TG,TX,TO} @assert TG !== Nothing "Dirac operator has no gauge background, do `D(U)`" U = D.U check_dims(ψ_eo, ϕ_eo, U) @@ -272,8 +272,8 @@ function LinearAlgebra.mul!( end function LinearAlgebra.mul!( - ψ_eo::TF, D::Daggered{WilsonEOPreDiracOperator{CPU,T,C,TF,TG,TX,TO,BC}}, ϕ_eo::TF -) where {T,C,TF,TG,TX,TO,BC} + ψ_eo::TF, D::Daggered{WilsonEOPreDiracOperator{B,T,C,TF,TG,TX,TO,BC}}, ϕ_eo::TF +) where {B,T,C,TF,TG,TX,TO,BC} @assert TG !== Nothing "Dirac operator has no gauge background, do `D(U)`" U = D.parent.U check_dims(ψ_eo, ϕ_eo, U) @@ -289,8 +289,8 @@ function LinearAlgebra.mul!( end function LinearAlgebra.mul!( - ψ_eo::TF, D::DdaggerD{WilsonEOPreDiracOperator{CPU,T,C,TF,TG,TX,TO,BC}}, ϕ_eo::TF -) where {T,C,TF,TG,TX,TO,BC} + ψ_eo::TF, D::DdaggerD{WilsonEOPreDiracOperator{B,T,C,TF,TG,TX,TO,BC}}, ϕ_eo::TF +) where {B,T,C,TF,TG,TX,TO,BC} temp = D.parent.temp mul!(temp, D.parent, ϕ_eo) # temp = Dϕ mul!(ψ_eo, adjoint(D.parent), temp) # ψ = D†Dϕ @@ -306,13 +306,14 @@ function mul_oe!( fdims = dims(ψ) NV = ψ.NV - #= @batch =#for site in eachindex(ψ) + @batch for site in eachindex(ψ) isodd(site) || continue _site = if into_odd eo_site(site, fdims..., NV) else eo_site_switch(site, fdims..., NV) end + ψ[_site] = fac * wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg)) end @@ -329,13 +330,14 @@ function mul_eo!( fdims = dims(ψ) NV = ψ.NV - #= @batch =#for site in eachindex(ψ) + @batch for site in eachindex(ψ) iseven(site) || continue _site = if into_odd eo_site_switch(site, fdims..., NV) else eo_site(site, fdims..., NV) end + ψ[_site] = fac * wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg)) end @@ -382,7 +384,7 @@ function wilson_eo_kernel(U, ϕ, site, bc, ::Type{T}, ::Val{dagg}) where {T,dagg end function calc_diag!( - D_diag::TW, D_oo_inv::TW, ::Nothing, U::Gaugefield{CPU,T}, mass + D_diag::TW, D_oo_inv::TW, ::Nothing, U::Gaugefield{CPU,T,M}, mass ) where {T,M,TW<:Paulifield{CPU,T,M,false}} check_dims(D_diag, D_oo_inv, U) mass_term = Complex{T}(4 + mass) @@ -390,21 +392,29 @@ function calc_diag!( NV = U.NV @batch for site in eachindex(U) - _site = eo_site(site, fdims..., NV) - A = SMatrix{6,6,Complex{T},36}(mass_term * I) - D_diag[site] = PauliMatrix(A, A) - - if isodd(site) - o_site = switch_sides(_site, fdims..., NV) - A_inv = SMatrix{6,6,Complex{T},36}(1/mass_term * I) - D_oo_inv[o_site] = PauliMatrix(A_inv, A_inv) - end + calc_diag_kernel!(D_diag, D_oo_inv, mass_term, site, fdims, NV, T) end end +function calc_diag_kernel!( + D_diag, D_oo_inv, mass_term, site, fdims, NV, ::Type{T} +) where {T} + _site = eo_site(site, fdims..., NV) + A = SMatrix{6,6,Complex{T},36}(mass_term * I) + D_diag[site] = PauliMatrix(A, A) + + if isodd(site) + o_site = switch_sides(_site, fdims..., NV) + A_inv = SMatrix{6,6,Complex{T},36}(1/mass_term * I) + D_oo_inv[o_site] = PauliMatrix(A_inv, A_inv) + end + + return nothing +end + function calc_diag!( - D_diag::TW, D_oo_inv::TW, Fμν, U::Gaugefield{CPU,T}, mass -) where {T,MP,TW<:Paulifield{CPU,T,MP,true}} # With clover term + D_diag::TW, D_oo_inv::TW, Fμν::Tensorfield{B,T,M}, U::Gaugefield{CPU,T,M}, mass +) where {B,T,M,TW<:Paulifield{CPU,T,M,true}} # With clover term check_dims(D_diag, D_oo_inv, U) mass_term = Complex{T}(4 + mass) fdims = dims(U) @@ -414,49 +424,55 @@ function calc_diag!( fieldstrength_eachsite!(Clover(), Fμν, U) @batch for site in eachindex(U) - _site = eo_site(site, fdims..., NV) - M = SMatrix{6,6,Complex{T},36}(mass_term * I) - i = SVector((1, 2)) - j = SVector((3, 4)) - - F₁₂ = Fμν[1, 2, site] - σ = σ12(T) - A₊ = ckron(σ[i, i], F₁₂) - A₋ = ckron(σ[j, j], F₁₂) - - F₁₃ = Fμν[1, 3, site] - σ = σ13(T) - A₊ += ckron(σ[i, i], F₁₃) - A₋ += ckron(σ[j, j], F₁₃) - - F₁₄ = Fμν[1, 4, site] - σ = σ14(T) - A₊ += ckron(σ[i, i], F₁₄) - A₋ += ckron(σ[j, j], F₁₄) - - F₂₃ = Fμν[2, 3, site] - σ = σ23(T) - A₊ += ckron(σ[i, i], F₂₃) - A₋ += ckron(σ[j, j], F₂₃) - - F₂₄ = Fμν[2, 4, site] - σ = σ24(T) - A₊ += ckron(σ[i, i], F₂₄) - A₋ += ckron(σ[j, j], F₂₄) - - F₃₄ = Fμν[3, 4, site] - σ = σ34(T) - A₊ += ckron(σ[i, i], F₃₄) - A₋ += ckron(σ[j, j], F₃₄) - - A₊ = fac * A₊ + M - A₋ = fac * A₋ + M - D_diag[_site] = PauliMatrix(A₊, A₋) - - if isodd(site) - o_site = switch_sides(_site, fdims..., NV) - D_oo_inv[o_site] = PauliMatrix(cinv(A₊), cinv(A₋)) - end + calc_diag_kernel!(D_diag, D_oo_inv, Fμν, mass_term, site, fdims, NV, fac, T) + end +end + +function calc_diag_kernel!( + D_diag, D_oo_inv, Fμν, mass_term, site, fdims, NV, fac, ::Type{T} +) where {T} + _site = eo_site(site, fdims..., NV) + M = SMatrix{6,6,Complex{T},36}(mass_term * I) + i = SVector((1, 2)) + j = SVector((3, 4)) + + F₁₂ = Fμν[1, 2, site] + σ = σ12(T) + A₊ = ckron(σ[i, i], F₁₂) + A₋ = ckron(σ[j, j], F₁₂) + + F₁₃ = Fμν[1, 3, site] + σ = σ13(T) + A₊ += ckron(σ[i, i], F₁₃) + A₋ += ckron(σ[j, j], F₁₃) + + F₁₄ = Fμν[1, 4, site] + σ = σ14(T) + A₊ += ckron(σ[i, i], F₁₄) + A₋ += ckron(σ[j, j], F₁₄) + + F₂₃ = Fμν[2, 3, site] + σ = σ23(T) + A₊ += ckron(σ[i, i], F₂₃) + A₋ += ckron(σ[j, j], F₂₃) + + F₂₄ = Fμν[2, 4, site] + σ = σ24(T) + A₊ += ckron(σ[i, i], F₂₄) + A₋ += ckron(σ[j, j], F₂₄) + + F₃₄ = Fμν[3, 4, site] + σ = σ34(T) + A₊ += ckron(σ[i, i], F₃₄) + A₋ += ckron(σ[j, j], F₃₄) + + A₊ = fac * A₊ + M + A₋ = fac * A₋ + M + D_diag[_site] = PauliMatrix(A₊, A₋) + + if isodd(site) + o_site = switch_sides(_site, fdims..., NV) + D_oo_inv[o_site] = PauliMatrix(cinv(A₊), cinv(A₋)) end end @@ -491,7 +507,7 @@ function axmy!( return nothing end -function trlog(D_diag::Paulifield{CPU,T,M,false}, mass) where {T,M} # Without clover term +function trlog(D_diag::Paulifield{B,T,M,false}, mass) where {B,T,M} # Without clover term NC = num_colors(D_diag) mass_term = Float64(4 + mass) logd = 4NC * log(mass_term) diff --git a/src/fields/gpu_iterators.jl b/src/fields/gpu_iterators.jl index ba5d33e..e48c699 100644 --- a/src/fields/gpu_iterators.jl +++ b/src/fields/gpu_iterators.jl @@ -11,7 +11,7 @@ function __latmap( # KernelAbstractions requires an ndrange (indices we iterate over) and a # workgroupsize (number of threads in each workgroup / thread block on the GPU) ndrange = local_dims(U) - workgroupsize = (4, 4, 4, 4) # 4^4 = 256 threads per workgroup should be fine + workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(4)) # 4^4 = 256 threads per workgroup should be fine kernel! = f!(B(), workgroupsize) # I couldn't be bothered making GPUs work with array wrappers such as AbstractField # (see Adapt.jl), so we extract the actual array from all AbstractFields in the args @@ -34,7 +34,7 @@ function __latmap( NX, NY, NZ, NT = local_dims(ϕ_eo) @assert iseven(NT) "NT must be even for even-odd preconditioned fermions" ndrange = (NX, NY, NZ, div(NT, 2)) - workgroupsize = (4, 4, 4, 2) + workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(4)) kernel! = f!(B(), workgroupsize) raw_args = get_raws(args...) @@ -46,57 +46,57 @@ function __latmap( return nothing end -function __latmap( - ::Checkerboard2, ::Val{COUNT}, f!::F, U::AbstractField{B}, args... -) where {COUNT,F,B<:GPU} - COUNT == 0 && return nothing - NX, NY, NZ, NT = local_dims(U) - @assert( - mod.((NX, NY, NZ, NT), 2) == (0, 0, 0, 0), - "CB2 only works for side lengths that are multiples of 2" - ) - ndrange = (NY, NZ, NT) - workgroupsize = (4, 4, 4) # since we only have 64 threads per block, we can use heavier kernels - kernel! = f!(B(), workgroupsize) - raw_args = get_raws(args...) - - for _ in 1:COUNT - for μ in 1:4 - for pass in 1:2 - kernel!(U.U, μ, pass, raw_args...; ndrange=ndrange) - KA.synchronize(B()) - end - end - end - - return nothing -end - -function __latmap( - ::Checkerboard4, ::Val{COUNT}, f!::F, U::AbstractField{B}, args... -) where {COUNT,F,B<:GPU} - COUNT == 0 && return nothing - NX, NY, NZ, NT = local_dims(U) - @assert( - mod.((NX, NY, NZ, NT), 4) == (0, 0, 0, 0), - "CB4 only works for side lengths that are multiples of 4" - ) - ndrange = (NY, NZ, NT) - workgroupsize = (4, 4, 4) - kernel! = f!(B(), workgroupsize) - raw_args = get_raws(args...) - - for _ in 1:COUNT - for μ in 1:4 - for pass in 1:4 - kernel!(U.U, μ, pass, raw_args...; ndrange=ndrange) - KA.synchronize(B()) - end - end - end - - return nothing -end +# function __latmap( +# ::Checkerboard2, ::Val{COUNT}, f!::F, U::AbstractField{B}, args... +# ) where {COUNT,F,B<:GPU} +# COUNT == 0 && return nothing +# NX, NY, NZ, NT = local_dims(U) +# @assert( +# mod.((NX, NY, NZ, NT), 2) == (0, 0, 0, 0), +# "CB2 only works for side lengths that are multiples of 2" +# ) +# ndrange = (NY, NZ, NT) +# workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(3)) +# kernel! = f!(B(), workgroupsize) +# raw_args = get_raws(args...) +# +# for _ in 1:COUNT +# for μ in 1:4 +# for pass in 1:2 +# kernel!(U.U, μ, pass, raw_args...; ndrange=ndrange) +# KA.synchronize(B()) +# end +# end +# end +# +# return nothing +# end +# +# function __latmap( +# ::Checkerboard4, ::Val{COUNT}, f!::F, U::AbstractField{B}, args... +# ) where {COUNT,F,B<:GPU} +# COUNT == 0 && return nothing +# NX, NY, NZ, NT = local_dims(U) +# @assert( +# mod.((NX, NY, NZ, NT), 4) == (0, 0, 0, 0), +# "CB4 only works for side lengths that are multiples of 4" +# ) +# ndrange = (NY, NZ, NT) +# workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(3)) +# kernel! = f!(B(), workgroupsize) +# raw_args = get_raws(args...) +# +# for _ in 1:COUNT +# for μ in 1:4 +# for pass in 1:4 +# kernel!(U.U, μ, pass, raw_args...; ndrange=ndrange) +# KA.synchronize(B()) +# end +# end +# end +# +# return nothing +# end macro latsum(itr, C, f!, U, args...) quote @@ -109,7 +109,7 @@ function __latsum( ) where {COUNT,OutType,F,B<:GPU} COUNT == 0 && return 0.0 ndrange = local_dims(U) - workgroupsize = (4, 4, 4, 4) + workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(4)) numblocks = cld(U.NV, prod(workgroupsize)) out = KA.zeros(B(), OutType, numblocks) kernel! = f!(B(), workgroupsize) @@ -130,7 +130,7 @@ function __latsum( NX, NY, NZ, NT = local_dims(ϕ_eo) @assert iseven(NT) "NT must be even for even-odd preconditioned fermions" ndrange = (NX, NY, NZ, div(NT, 2)) - workgroupsize = (4, 4, 4, 2) + workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(4)) numblocks = cld(div(ϕ_eo.parent.NV, 2), prod(workgroupsize)) out = KA.zeros(B(), OutType, numblocks) kernel! = f!(B(), workgroupsize) @@ -144,58 +144,78 @@ function __latsum( return sum(out) end -function __latsum( - ::Checkerboard2, ::Val{COUNT}, f!::F, U::AbstractField{B}, args... -) where {COUNT,F,B<:GPU} - COUNT == 0 && return 0.0 - NX, NY, NZ, NT = local_dims(U) - @assert( - mod.((NX, NY, NZ, NT), 2) == (0, 0, 0, 0), - "CB2 only works for side lengths that are multiples of 2" - ) - ndrange = (NY, NZ, NT) - workgroupsize = (4, 4, 4) - numblocks = cld(U.NV, prod(workgroupsize)) - out = KA.zeros(B(), Float64, numblocks) - kernel! = f!(B(), workgroupsize) - raw_args = get_raws(args...) - numpasses = U isa SpinorfieldEO ? 2 : 1 - - for _ in 1:COUNT - for μ in 1:4 - for pass in 1:numpasses - kernel!(out, U.U, μ, pass, raw_args...; ndrange=ndrange) - KA.synchronize(B()) - end - end - end - - return sum(out) -end +# function __latsum( +# ::Checkerboard2, ::Val{COUNT}, ::Type{OutType}, f!::F, U::AbstractField{B}, args... +# ) where {COUNT,OutType,F,B<:GPU} +# COUNT == 0 && return 0.0 +# NX, NY, NZ, NT = local_dims(U) +# @assert( +# mod.((NX, NY, NZ, NT), 2) == (0, 0, 0, 0), +# "CB2 only works for side lengths that are multiples of 2" +# ) +# ndrange = (NY, NZ, NT) +# workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(3)) +# numblocks = cld(U.NV, prod(workgroupsize)) +# out = KA.zeros(B(), OutType, numblocks) +# kernel! = f!(B(), workgroupsize) +# raw_args = get_raws(args...) +# numpasses = U isa SpinorfieldEO ? 2 : 1 +# +# for _ in 1:COUNT +# for μ in 1:4 +# for pass in 1:numpasses +# kernel!(out, U.U, μ, pass, raw_args...; ndrange=ndrange) +# KA.synchronize(B()) +# end +# end +# end +# +# return sum(out) +# end +# +# function __latsum( +# ::Checkerboard4, ::Val{COUNT}, ::Type{OutType}, f!::F, U::AbstractField{B}, args... +# ) where {COUNT,OutType,F,B<:GPU} +# COUNT == 0 && return 0.0 +# NX, NY, NZ, NT = local_dims(U) +# @assert( +# mod.((NX, NY, NZ, NT), 4) == (0, 0, 0, 0), +# "CB4 only works for side lengths that are multiples of 4" +# ) +# ndrange = (NY, NZ, NT) +# workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(3)) +# numblocks = cld(U.NV, prod(workgroupsize)) +# out = KA.zeros(B(), OutType, numblocks) +# kernel! = f!(B(), workgroupsize) +# raw_args = get_raws(args...) +# +# for _ in 1:COUNT +# for μ in 1:4 +# for pass in 1:4 +# kernel!(out, U.U, μ, pass, raw_args...; ndrange=ndrange) +# KA.synchronize(B()) +# end +# end +# end +# +# return sum(out) +# end function __latsum( - ::Checkerboard4, ::Val{COUNT}, f!::F, U::AbstractField{B}, args... -) where {COUNT,F,B<:GPU} + ::EvenSites, ::Val{COUNT}, ::Type{OutType}, f!::F, U::AbstractField{B}, args... +) where {COUNT,OutType,F,B<:GPU} COUNT == 0 && return 0.0 - NX, NY, NZ, NT = local_dims(U) - @assert( - mod.((NX, NY, NZ, NT), 4) == (0, 0, 0, 0), - "CB4 only works for side lengths that are multiples of 4" - ) - ndrange = (NY, NZ, NT) - workgroupsize = (4, 4, 4) + fdims = local_dims(U) + ndrange = ntuple(i -> (i == 4 ? fdims[i]÷2 : fdims[i]), Val(4)) + workgroupsize = ntuple(i -> min(ndrange[i], 4), Val(4)) numblocks = cld(U.NV, prod(workgroupsize)) - out = KA.zeros(B(), Float64, numblocks) + out = KA.zeros(B(), OutType, numblocks) kernel! = f!(B(), workgroupsize) raw_args = get_raws(args...) for _ in 1:COUNT - for μ in 1:4 - for pass in 1:4 - kernel!(out, U.U, μ, pass, raw_args...; ndrange=ndrange) - KA.synchronize(B()) - end - end + kernel!(out, U.U, raw_args...; ndrange=ndrange) + KA.synchronize(B()) end return sum(out) diff --git a/src/utils/Utils.jl b/src/utils/Utils.jl index 986992d..872215b 100644 --- a/src/utils/Utils.jl +++ b/src/utils/Utils.jl @@ -23,7 +23,7 @@ export antihermitian, hermitian, traceless_antihermitian, traceless_hermitian, m export zero2, zero3, zerov3, eye2, eye3, onev3, gaussian_TA_mat, rand_SU3 export SiteCoords, eo_site, eo_site_switch, move, switch_sides export cartesian_to_linear -export Sequential, Checkerboard2, Checkerboard4 +export Sequential, Checkerboard2, Checkerboard4, EvenSites, OddSites export λ, expλ, γ1, γ2, γ3, γ4, γ5, σ12, σ13, σ14, σ23, σ24, σ34 export cmatmul_oo, cmatmul_dd, cmatmul_do, cmatmul_od export cmatmul_ooo, @@ -60,6 +60,8 @@ abstract type AbstractIterator end struct Sequential <: AbstractIterator end struct Checkerboard2 <: AbstractIterator end struct Checkerboard4 <: AbstractIterator end +struct EvenSites <: AbstractIterator end +struct OddSites <: AbstractIterator end _unwrap_val(::Val{B}) where {B} = B