Skip to content

Commit

Permalink
Add GPU kernels for even-odd Wilson-Clover
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFuwa committed Oct 14, 2024
1 parent 29b5e9f commit 9885e49
Show file tree
Hide file tree
Showing 6 changed files with 367 additions and 169 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ Inspired by the [LatticeQCD.jl](https://github.com/akio-tomiya/LatticeQCD.jl/tre
- [x] Improved Topological charge definitions (clover, rectangle clover-improved)
- [x] Wilson(-Clover) fermions
- [x] Staggered fermions
- [ ] Even-odd preconditioner for Wilson(-Clover)
- [x] Even-odd preconditioner for Wilson(-Clover)
- [x] Even-odd preconditioner for Staggered
- [ ] Mass-splitting preconditioner / Hasenbusch trick
- [x] RHMC to simulate odd number of flavours
- [ ] Support for CUDA and ROCm backends
- [ ] Multi-node parallelism using MPI
- [ ] Full support for CUDA and ROCm backends
- [x] Multi-node parallelism using MPI

## Installation:
First make sure you have Julia version 1.9.x or 1.10.x installed. You can use [juliaup](https://github.com/JuliaLang/juliaup) for that or just install the release from the [Julia website](https://julialang.org/downloads/).
Expand Down
5 changes: 3 additions & 2 deletions src/diracoperators/DiracOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import KernelAbstractions as KA
import ..Fields: AbstractField, FieldTopology, Gaugefield, Paulifield, Spinorfield
import ..Fields: SpinorfieldEO, Tensorfield
import ..Fields: check_dims, clear!, clover_square, dims, even_odd, gaussian_pseudofermions!
import ..Fields: Clover, Checkerboard2, Sequential, @latmap, @latsum, set_source!, volume
import ..Fields: fieldstrength_eachsite!, num_colors, num_dirac
import ..Fields: @latmap, @latsum, Clover, Checkerboard2, Sequential, set_source!, volume
import ..Fields: @groupreduce, fieldstrength_eachsite!, num_colors, num_dirac
import ..Fields: PeriodicBC, AntiPeriodicBC, apply_bc, create_bc, distributed_reduce

abstract type AbstractDiracOperator end
Expand Down Expand Up @@ -81,6 +81,7 @@ include("wilson_eo.jl")
include("gpu_kernels/staggered.jl")
include("gpu_kernels/staggered_eo.jl")
include("gpu_kernels/wilson.jl")
include("gpu_kernels/wilson_eo.jl")
include("arnoldi.jl")

"""
Expand Down
159 changes: 159 additions & 0 deletions src/diracoperators/gpu_kernels/wilson_eo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""
    mul_oe!(ψ_eo, U, ϕ_eo, bc, into_odd, Val(dagg); fac=1)

Launch `wilson_eo_kernel!` over the lattice, storing `fac` times the value returned by
`wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg))` in the parent field of `ψ_eo`, and
refresh the halo of `ψ_eo` afterwards.

# Arguments
- `ψ_eo`: destination field; its `parent` is written.
- `U`: gauge field handed to the per-site kernel.
- `ϕ_eo`: source field; its `parent` is read.
- `bc`: boundary-condition object forwarded unchanged to the per-site kernel.
- `into_odd::Bool`: selects the storage index of the result, `eo_site` when `true` and
  `eo_site_switch` otherwise.
- `Val(dagg)`: forwarded unchanged to the per-site kernel.

# Keywords
- `fac`: scalar multiplied onto every stored value.

Returns `nothing`.
"""
function mul_oe!(
    ψ_eo::TF, U::Gaugefield{B,T}, ϕ_eo::TF, bc, into_odd, ::Val{dagg}; fac=1
) where {B,T,TF<:WilsonEOPreSpinorfield{B,T},dagg}
    check_dims(ψ_eo, ϕ_eo, U)
    ψ = ψ_eo.parent
    ϕ = ϕ_eo.parent
    fdims = dims(ψ)
    NV = ψ.NV

    # `into_odd` has to be handed to the kernel explicitly: device code cannot see the
    # host function's local variables.
    @latmap(
        Sequential(),
        Val(1),
        wilson_eo_kernel!,
        ψ,
        U,
        ϕ,
        bc,
        fdims,
        NV,
        into_odd,
        fac,
        T,
        Val(dagg)
    )
    update_halo!(ψ_eo)
    return nothing
end

# Device kernel behind `mul_oe!`. For every global Cartesian index `site` that passes the
# `iseven` check, the scaled hopping result is stored at `_site`, which is obtained from
# `eo_site` (when `into_odd` is true) or `eo_site_switch` (otherwise).
# NOTE(review): `wilson_oe_kernel!` guards on `iseven(site)` as well; confirm that one of
# the two kernels is not supposed to act on the opposite parity.
@kernel function wilson_eo_kernel!(
    ψ, @Const(U), @Const(ϕ), bc, fdims, NV, into_odd, fac, ::Type{T}, ::Val{dagg}
) where {T,dagg}
    site = @index(Global, Cartesian)
    _site = if into_odd
        eo_site(site, fdims..., NV)
    else
        eo_site_switch(site, fdims..., NV)
    end

    if iseven(site)
        @inbounds ψ[_site] = fac * wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg))
    end
end

"""
    mul_eo!(ψ_eo, U, ϕ_eo, bc, into_odd, Val(dagg); fac=1)

Launch `wilson_oe_kernel!` over the lattice, storing `fac` times the value returned by
`wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg))` in the parent field of `ψ_eo`, and
refresh the halo of `ψ_eo` afterwards.

# Arguments
- `ψ_eo`: destination field; its `parent` is written.
- `U`: gauge field handed to the per-site kernel.
- `ϕ_eo`: source field; its `parent` is read.
- `bc`: boundary-condition object forwarded unchanged to the per-site kernel.
- `into_odd::Bool`: selects the storage index of the result, `eo_site_switch` when
  `true` and `eo_site` otherwise (the reverse of `mul_oe!`).
- `Val(dagg)`: forwarded unchanged to the per-site kernel.

# Keywords
- `fac`: scalar multiplied onto every stored value.

Returns `nothing`.
"""
function mul_eo!(
    ψ_eo::TF, U::Gaugefield{B,T}, ϕ_eo::TF, bc, into_odd, ::Val{dagg}; fac=1
) where {B,T,TF<:WilsonEOPreSpinorfield{B,T},dagg}
    check_dims(ψ_eo, ϕ_eo, U)
    ψ = ψ_eo.parent
    ϕ = ϕ_eo.parent
    fdims = dims(ψ)
    NV = ψ.NV

    # `into_odd` has to be handed to the kernel explicitly: device code cannot see the
    # host function's local variables.
    @latmap(
        Sequential(),
        Val(1),
        wilson_oe_kernel!,
        ψ,
        U,
        ϕ,
        bc,
        fdims,
        NV,
        into_odd,
        fac,
        T,
        Val(dagg)
    )
    update_halo!(ψ_eo)
    return nothing
end

# Device kernel behind `mul_eo!`. For every global Cartesian index `site` that passes the
# `iseven` check, the scaled hopping result is stored at `_site`, which is obtained from
# `eo_site_switch` (when `into_odd` is true) or `eo_site` (otherwise).
# NOTE(review): `wilson_eo_kernel!` guards on `iseven(site)` as well; confirm that one of
# the two kernels is not supposed to act on the opposite parity.
@kernel function wilson_oe_kernel!(
    ψ, @Const(U), @Const(ϕ), bc, fdims, NV, into_odd, fac, ::Type{T}, ::Val{dagg}
) where {T,dagg}
    site = @index(Global, Cartesian)
    _site = if into_odd
        eo_site_switch(site, fdims..., NV)
    else
        eo_site(site, fdims..., NV)
    end

    if iseven(site)
        @inbounds ψ[_site] = fac * wilson_eo_kernel(U, ϕ, site, bc, T, Val(dagg))
    end
end

"""
    calc_diag!(D_diag, D_oo_inv, ::Nothing, U, mass)

Launch the `calc_diag_kernel!` variant that takes no field-strength argument over the
lattice. This method is selected for `Paulifield`s whose last type parameter is `false`
and when `nothing` is passed in place of a field-strength tensor.

# Arguments
- `D_diag`, `D_oo_inv`: fields handed to the kernel.
- `U`: gauge field; only its dimensions and `NV` are read here.
- `mass`: enters the kernel as `Complex{T}(4 + mass)`.

Returns `nothing`.
"""
function calc_diag!(
    D_diag::TW, D_oo_inv::TW, ::Nothing, U::Gaugefield{B,T}, mass
) where {B,T,M,TW<:Paulifield{B,T,M,false}}
    check_dims(D_diag, D_oo_inv, U)
    mass_term = Complex{T}(4 + mass)
    fdims = dims(U)
    NV = U.NV

    # The kernel signature ends in `::Type{T}`, so `T` must be part of the launch
    # arguments; without it no kernel method matches.
    @latmap(
        Sequential(), Val(1), calc_diag_kernel!, D_diag, D_oo_inv, mass_term, fdims, NV, T
    )
    return nothing
end

# Device entry point of the diagonal-term computation that takes no field-strength
# argument: looks up the global Cartesian index of the current work-item and forwards it,
# together with all other arguments, to the per-site `calc_diag_kernel!` method.
@kernel function calc_diag_kernel!(
    D_diag, D_oo_inv, mass_term, fdims, NV, ::Type{T}
) where {T}
    idx = @index(Global, Cartesian)
    calc_diag_kernel!(D_diag, D_oo_inv, mass_term, idx, fdims, NV, T)
end

"""
    calc_diag!(D_diag, D_oo_inv, Fμν::Tensorfield, U, mass)

Launch the `calc_diag_kernel!` variant that takes a field-strength tensor over the
lattice. This method is selected for `Paulifield`s whose last type parameter is `true`.

# Arguments
- `D_diag`, `D_oo_inv`: fields handed to the kernel; `D_diag.csw / 2` is passed along as
  the factor `fac` (converted to `Complex{T}`).
- `Fμν`: field-strength tensor field, read by the kernel.
- `U`: gauge field; only its dimensions and `NV` are read here.
- `mass`: enters the kernel as `Complex{T}(4 + mass)`.

Returns `nothing`.
"""
function calc_diag!(
    D_diag::TW, D_oo_inv::TW, Fμν::Tensorfield{B,T,M}, U::Gaugefield{B,T,M}, mass
) where {B,T,M,TW<:Paulifield{B,T,M,true}}
    check_dims(D_diag, D_oo_inv, U)
    mass_term = Complex{T}(4 + mass)
    fdims = dims(U)
    NV = U.NV
    fac = Complex{T}(D_diag.csw / 2)

    # NOTE(review): `Fμν` is used as passed in; nothing here fills it from `U`. Confirm
    # whether `fieldstrength_eachsite!` has to be called before the launch.
    # `Fμν` is the kernel's third parameter, so it must appear in the launch arguments;
    # leaving it out shifts every following argument and no kernel method matches.
    @latmap(
        Sequential(),
        Val(1),
        calc_diag_kernel!,
        D_diag,
        D_oo_inv,
        Fμν,
        mass_term,
        fdims,
        NV,
        fac,
        T
    )
    return nothing
end

# Device entry point of the diagonal-term computation that takes a field-strength tensor:
# looks up the global Cartesian index of the current work-item and forwards it, together
# with all other arguments, to the per-site `calc_diag_kernel!` method.
@kernel function calc_diag_kernel!(
    D_diag, D_oo_inv, @Const(Fμν), mass_term, fdims, NV, fac, ::Type{T}
) where {T}
    idx = @index(Global, Cartesian)
    calc_diag_kernel!(D_diag, D_oo_inv, Fμν, mass_term, idx, fdims, NV, fac, T)
end

"""
    mul_oo_inv!(ϕ_eo, D_oo_inv)

Launch `mul_oo_inv_kernel!` over `EvenSites()`, which replaces entries of the parent
field of `ϕ_eo` in place by `cmvmul_block` of the matching `D_oo_inv` entry and the old
value.

# Arguments
- `ϕ_eo`: field whose `parent` is updated in place.
- `D_oo_inv`: field of blocks applied to `ϕ_eo`.

Returns `nothing`.
"""
function mul_oo_inv!(
    ϕ_eo::WilsonEOPreSpinorfield{B,T,M}, D_oo_inv::Paulifield{B,T,M}
) where {B,T,M}
    check_dims(ϕ_eo, D_oo_inv)
    ϕ = ϕ_eo.parent
    # No gauge field is passed to this function, so the lattice geometry is read from
    # the spinor field that the kernel indexes into.
    fdims = dims(ϕ)
    NV = ϕ.NV

    @latmap(EvenSites(), Val(1), mul_oo_inv_kernel!, ϕ, D_oo_inv, fdims, NV)
    return nothing
end

# Per-work-item body of `mul_oo_inv!`: the block of `D_oo_inv` stored at the work-item's
# own index is applied (via `cmvmul_block`) to the entry of `ϕ` at the index returned by
# `switch_sides`, and the product is written back to that same entry.
@kernel function mul_oo_inv_kernel!(ϕ, D_oo_inv, fdims, NV)
    idx = @index(Global, Cartesian)
    target = switch_sides(idx, fdims..., NV)
    block = D_oo_inv[idx]
    old_value = ϕ[target]
    ϕ[target] = cmvmul_block(block, old_value)
end

"""
    axmy!(D_diag, ψ_eo, ϕ_eo)

Launch `axmy_kernel!` over `EvenSites()`. The kernel overwrites each visited entry of
the parent of `ϕ_eo` with `cmvmul_block(D_diag[site], ψ[site]) - ϕ[site]`, where `ψ` is
the parent of `ψ_eo`.

Returns `nothing`.
"""
function axmy!(
    D_diag::Paulifield{B,T,M}, ψ_eo::TF, ϕ_eo::TF
) where {B,T,M,TF<:WilsonEOPreSpinorfield{B,T,M}} # even on even is the default
    check_dims(ϕ_eo, ψ_eo)
    src = ψ_eo.parent
    dst = ϕ_eo.parent

    @latmap(EvenSites(), Val(1), axmy_kernel!, D_diag, src, dst)
    return nothing
end

# Per-work-item body of `axmy!`: at the work-item's global Cartesian index, replace the
# entry of `ϕ` by the `cmvmul_block` product of `D_diag` and `ψ` minus the old entry.
@kernel function axmy_kernel!(@Const(D_diag), ψ, ϕ)
    idx = @index(Global, Cartesian)
    product = cmvmul_block(D_diag[idx], ψ[idx])
    ϕ[idx] = product - ϕ[idx]
end

"""
    trlog(D_diag::Paulifield{B,T,M,true}, ::Any)

Reduce `trlog_kernel` over `EvenSites()` with `@latsum`, requesting a `Float64` result,
and return that value. This method applies to `Paulifield`s with a clover term (last
type parameter `true`); the second argument is ignored.
"""
function trlog(D_diag::Paulifield{B,T,M,true}, ::Any) where {B,T,M} # With clover term
    NV = D_diag.NV
    fdims = dims(D_diag)
    total = @latsum(EvenSites(), Val(1), Float64, trlog_kernel, D_diag, fdims, NV)
    return total
end

# Reduction kernel used by `trlog`: every work-item computes
# log(real(det(upper)) * real(det(lower))) for one entry of `D_diag`, the values are
# combined with `+` inside the workgroup, and one work-item per group stores the group
# result in `out`.
#   out    - output array indexed by the linear group index
#   D_diag - field whose entries expose `upper` and `lower` blocks (read only)
#   fdims  - lattice dimensions, splatted into `switch_sides`
#   NV     - forwarded to `switch_sides`
@kernel function trlog_kernel(out, @Const(D_diag), fdims, NV)
# linear index of this work-item's group; selects the slot of `out` written below
bi = @index(Group, Linear)
_site = @index(Global, Cartesian)
# the entry that is read lives at the index returned by `switch_sides`, not at `_site`
o_site = switch_sides(_site, fdims..., NV)

# the accumulator starts from the Float64 literal 0.0
resₙ = 0.0
p = D_diag[o_site]
resₙ += log(real(det(p.upper)) * real(det(p.lower)))

# presumably sums `resₙ` over all work-items of the group, with 0.0 as the neutral
# element — confirm against the definition of `@groupreduce`
out_group = @groupreduce(+, resₙ, 0.0)

# only the work-item with local index 1 writes the group's result
ti = @index(Local)
if ti == 1
@inbounds out[bi] = out_group
end
end
Loading

0 comments on commit 9885e49

Please sign in to comment.