Skip to content

Commit

Permalink
Some more progress on even-odd Wilson
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFuwa committed Jul 3, 2024
1 parent 8079406 commit cb27448
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 87 deletions.
6 changes: 3 additions & 3 deletions src/diracoperators/wilson_eo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct WilsonEOPreDiracOperator{B,T,C,TF,TG,TX,TO} <: AbstractDiracOperator
κ = 1 / (2mass + 8)
U = nothing
C = csw == 0 ? false : true
Xμν = C ? Tensorfield(U) : nothing
Xμν = C ? Tensorfield(f) : nothing
temp = even_odd(Fermionfield{B,T,4}(dims(f)...))
D_diag = WilsonEODiagonal(temp, mass, csw)
D_oo_inv = WilsonEODiagonal(temp, mass, csw; inverse=true)
Expand Down Expand Up @@ -320,9 +320,9 @@ function LinearAlgebra.mul!(
D_oo_inv = D.parent.D_oo_inv
D_diag = D.parent.D_diag

mul_oe!(ψ_eo, U, ϕ_eo, anti, true, Val(-1)) # ψₒ = Dₒₑϕₑ
mul_oe!(ψ_eo, U, ϕ_eo, anti, true, Val(-1)) # ψₒ = Dₑₒ†ϕₑ
mul_oo_inv!(ψ_eo, D_oo_inv) # ψₒ = Dₒₒ⁻¹Dₒₑϕₑ
mul_eo!(ψ_eo, U, ψ_eo, anti, false, Val(-1)) # ψₑ = DₑₒDₒₒ⁻¹Dₒₑϕₑ
mul_eo!(ψ_eo, U, ψ_eo, anti, false, Val(-1)) # ψₑ = Dₒₑ†Dₒₒ⁻¹Dₑₒ†ϕₑ
axmy!(D_diag, ϕ_eo, ψ_eo) # ψₑ = Dₑₑϕₑ - DₑₒDₒₒ⁻¹Dₒₑϕₑ
return nothing
end
Expand Down
179 changes: 95 additions & 84 deletions src/forces/wilson_eo_force.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ function calc_dSfdU!(
clear!(X_eo)
solve_dirac!(X_eo, DdagD, ϕ_eo, Y_eo, temp1, temp2, cg_tol, cg_maxiters) # Y is used here merely as a temp LinearAlgebra.mul!(Y, D, X) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!

# set odd components of X and Y
LinearAlgebra.mul!(Y_eo, D, X_eo)
mul_oe!(X_eo, U, X_eo, anti, true, Val(1)) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!
mul_oe!(Y_eo, U, Y_eo, anti, true, Val(1)) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!
mul_oe!(Y_eo, U, Y_eo, anti, true, Val(-1)) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!
mul_oo_inv!(X_eo, D.D_oo_inv)
mul_oo_inv!(Y_eo, D.D_oo_inv)
add_wilson_eo_derivative!(dU, U, X_eo, Y_eo, anti)

# TODO: Clover derivatives
if C
Xμν = fermion_action.Xμν
Xμν = D.Xμν
calc_Xμν_eachsite!(Xμν, X_eo, Y_eo)
add_clover_eo_derivative!(dU, U, Xμν, D.csw)
end
Expand Down Expand Up @@ -55,9 +55,14 @@ function calc_dSfdU!(

for i in 1:n
LinearAlgebra.mul!(Ys[i+1], D, Xs[i+1]) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!
mul_oe!(Xs[i+1], U, Xs[i+1], anti, true, Val(1)) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!
mul_oe!(Ys[i+1], U, Ys[i+1], anti, true, Val(-1)) # Need to prefix with LinearAlgebra to avoid ambiguity with Gaugefields.mul!
mul_oo_inv!(Xs[i+1], D.D_oo_inv)
mul_oo_inv!(Ys[i+1], D.D_oo_inv)
add_wilson_derivative!(dU, U, Xs[i+1], Ys[i+1], anti; coeff=coeffs[i])
# TODO: Clover derivatives
if C
Xμν = fermion_action.Xμν
Xμν = D.Xμν
calc_Xμν_eachsite!(Xμν, Xs[i+1], Ys[i+1])
add_clover_derivative!(dU, U, Xμν, D.csw; coeff=coeffs[i])
end
Expand All @@ -72,46 +77,43 @@ function add_wilson_eo_derivative!(
NT = dims(U)[4]
fac = T(0.5coeff)

# If we write out the kernel and use @batch, the program crashes for some reason
# INFO: If we write out the kernel and use @batch, the program crashes for some reason
# Stems from "pload" from StrideArraysCore.jl but ONLY if we write it out AND overload
# "object_and_preserve" (cant reproduce in MWE yet)
# is fine, because writing it like this makes the GPU port easier

#= @batch =#for site in eachindex(dU)
bc⁺ = boundary_factor(anti, site[4], 1, NT)
if isodd(site)
add_wilson_eo_derivative_kernel!(dU, U, X_eo, Y_eo, site, bc⁺, fac, Val(-1))
else
add_wilson_eo_derivative_kernel!(dU, U, Y_eo, X_eo, site, bc⁺, fac, Val(1))
end
add_wilson_eo_derivative_kernel!(dU, U, X_eo, Y_eo, site, bc⁺, fac)
end
end

function add_wilson_eo_derivative_kernel!(
dU, U, X_eo, Y_eo, site, bc⁺, fac, ::Val{DIR}
) where {DIR}
function add_wilson_eo_derivative_kernel!(dU, U, X_eo, Y_eo, site, bc⁺, fac)
# sites that begin with a "_" are meant for indexing into the even-odd preconn'ed
# fermion field
NX, NY, NZ, NT = dims(U)
NV = NX * NY * NZ * NT
_site = eo_site(site, NX, NY, NZ, NT, NV)

_siteμ⁺ = eo_site(move(site, 1, 1, NX), NX, NY, NZ, NT, NV)
# B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-1)), Y_eo[_site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(1DIR)), X_eo[_site])
dU[1i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[1, site], C))
B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-1)), Y_eo[_site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(1)), X_eo[_site])
dU[1i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[1, site], B + C))

_siteμ⁺ = eo_site(move(site, 2, 1, NY), NX, NY, NZ, NT, NV)
# B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-2)), Y_eo[site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(2DIR)), X_eo[site])
dU[2i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[2, site], C))
B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-2)), Y_eo[_site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(2)), X_eo[_site])
dU[2i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[2, site], B + C))

_siteμ⁺ = eo_site(move(site, 3, 1, NZ), NX, NY, NZ, NT, NV)
# B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-3)), Y_eo[site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(3DIR)), X_eo[site])
dU[3i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[3, site], C))
B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-3)), Y_eo[_site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(3)), X_eo[_site])
dU[3i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[3, site], B + C))

_siteμ⁺ = eo_site(move(site, 4, 1, NT), NX, NY, NZ, NT, NV)
# B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-4)), Y_eo[site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(4DIR)), X_eo[site])
dU[4i32, site] += bc⁺ * fac * traceless_antihermitian(cmatmul_oo(U[4, site], C))
B = spintrace(spin_proj(X_eo[_siteμ⁺], Val(-4)), Y_eo[_site])
C = spintrace(spin_proj(Y_eo[_siteμ⁺], Val(4)), X_eo[_site])
dU[4i32, site] += bc⁺ * fac * traceless_antihermitian(cmatmul_oo(U[4, site], B + C))
return nothing
end

Expand All @@ -122,79 +124,88 @@ function add_clover_eo_derivative!(
fac = T(csw * coeff / 2)

#= @batch =#for site in eachindex(dU)
add_clover_eo_derivative_kernel!(dU, U, Xμν, site, fac, T)
add_clover_derivative_kernel!(dU, U, Xμν, site, fac, T)
end

return nothing
end

function add_clover_eo_derivative_kernel!(dU, U, Xμν, site, fac, ::Type{T}) where {T}
tmp =
Xμν∇Fμν(Xμν, U, 1, 2, site, T) +
Xμν∇Fμν(Xμν, U, 1, 3, site, T) +
Xμν∇Fμν(Xμν, U, 1, 4, site, T)
dU[1i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[1i32, site], tmp))

tmp =
Xμν∇Fμν(Xμν, U, 2, 1, site, T) +
Xμν∇Fμν(Xμν, U, 2, 3, site, T) +
Xμν∇Fμν(Xμν, U, 2, 4, site, T)
dU[2i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[2i32, site], tmp))

tmp =
Xμν∇Fμν(Xμν, U, 3, 1, site, T) +
Xμν∇Fμν(Xμν, U, 3, 2, site, T) +
Xμν∇Fμν(Xμν, U, 3, 4, site, T)
dU[3i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[3i32, site], tmp))

tmp =
Xμν∇Fμν(Xμν, U, 4, 1, site, T) +
Xμν∇Fμν(Xμν, U, 4, 2, site, T) +
Xμν∇Fμν(Xμν, U, 4, 3, site, T)
dU[4i32, site] += fac * traceless_antihermitian(cmatmul_oo(U[4i32, site], tmp))
return nothing
end

function calc_Xμν_eachsite!(
Xμν::Tensorfield{CPU,T}, X_eo::TF, Y_eo::TF
) where {T,TF<:WilsonEOPreFermionfield}
check_dims(Xμν, X_eo, Y_eo)
even = true

#= @batch =#for site in eachindex(even, Xμν)
calc_Xμν_kernel!(Xμν, X_eo, Y_eo, site)
#= @batch =#for site in eachindex(Xμν)
if isodd(site)
calc_Xμν_eo_kernel!(Xμν, X_eo, Y_eo, site)
else
clear_Xμν_eo_kernel!(Xμν, site, T)
end
end

return nothing
end

function Xμν∇Fμν(Xμν, U, μ, ν, site, ::Type{T}) where {T}
= dims(U)[μ]
= dims(U)[ν]
siteμ⁺ = move(site, μ, 1i32, Nμ)
siteν⁺ = move(site, ν, 1i32, Nν)
siteν⁻ = move(site, ν, -1i32, Nν)
siteμ⁺ν⁺ = move(siteμ⁺, ν, 1i32, Nν)
siteμ⁺ν⁻ = move(siteμ⁺, ν, -1i32, Nν)

# get reused matrices up to cache (can precalculate some products too)
# Uνsiteμ⁺ = U[ν,siteμ⁺]
# Uμsiteν⁺ = U[μ,siteν⁺]
# Uνsite = U[ν,site]
# Uνsiteμ⁺ν⁻ = U[ν,siteμ⁺ν⁻]
# Uμsiteν⁻ = U[μ,siteν⁻]
# Uνsiteν⁻ = U[ν,siteν⁻]

component =
cmatmul_oddo(U[ν, siteμ⁺], U[μ, siteν⁺], U[ν, site], Xμν[μ, ν, site]) +
cmatmul_odod(U[ν, siteμ⁺], U[μ, siteν⁺], Xμν[μ, ν, siteν⁺], U[ν, site]) +
cmatmul_oodd(U[ν, siteμ⁺], Xμν[μ, ν, siteμ⁺ν⁺], U[μ, siteν⁺], U[ν, site]) +
cmatmul_oodd(Xμν[μ, ν, siteμ⁺], U[ν, siteμ⁺], U[μ, siteν⁺], U[ν, site]) -
cmatmul_ddoo(U[ν, siteμ⁺ν⁻], U[μ, siteν⁻], U[ν, siteν⁻], Xμν[μ, ν, site]) -
cmatmul_ddoo(U[ν, siteμ⁺ν⁻], U[μ, siteν⁻], Xμν[μ, ν, siteν⁻], U[ν, siteν⁻]) -
cmatmul_dodo(U[ν, siteμ⁺ν⁻], Xμν[μ, ν, siteμ⁺ν⁻], U[μ, siteν⁻], U[ν, siteν⁻]) -
cmatmul_oddo(Xμν[μ, ν, siteμ⁺], U[ν, siteμ⁺ν⁻], U[μ, siteν⁻], U[ν, siteν⁻])

return im * T(1/8) * component
function calc_Xμν_eo_kernel!(Xμν, X_eo, Y_eo, site)
NX, NY, NZ, NT = dims(X_eo)
NV = NX * NY * NZ * NT
_site = eo_site(site, NX, NY, NZ, NT, NV)

X₁₂ =
spintrace(σμν_spin_mul(X_eo[_site], Val(1), Val(2)), Y_eo[_site]) +
spintrace(σμν_spin_mul(Y_eo[_site], Val(1), Val(2)), X_eo[_site])
Xμν[1i32, 2i32, site] = X₁₂
Xμν[2i32, 1i32, site] = -X₁₂

X₁₃ =
spintrace(σμν_spin_mul(X_eo[_site], Val(1), Val(3)), Y_eo[_site]) +
spintrace(σμν_spin_mul(Y_eo[_site], Val(1), Val(3)), X_eo[_site])
Xμν[1i32, 3i32, site] = X₁₃
Xμν[3i32, 1i32, site] = -X₁₃

X₁₄ =
spintrace(σμν_spin_mul(X_eo[_site], Val(1), Val(4)), Y_eo[_site]) +
spintrace(σμν_spin_mul(Y_eo[_site], Val(1), Val(4)), X_eo[_site])
Xμν[1i32, 4i32, site] = X₁₄
Xμν[4i32, 1i32, site] = -X₁₄

X₂₃ =
spintrace(σμν_spin_mul(X_eo[_site], Val(2), Val(3)), Y_eo[_site]) +
spintrace(σμν_spin_mul(Y_eo[_site], Val(2), Val(3)), X_eo[_site])
Xμν[2i32, 3i32, site] = X₂₃
Xμν[3i32, 2i32, site] = -X₂₃

X₂₄ =
spintrace(σμν_spin_mul(X_eo[_site], Val(2), Val(4)), Y_eo[_site]) +
spintrace(σμν_spin_mul(Y_eo[_site], Val(2), Val(4)), X_eo[_site])
Xμν[2i32, 4i32, site] = X₂₄
Xμν[4i32, 2i32, site] = -X₂₄

X₃₄ =
spintrace(σμν_spin_mul(X_eo[_site], Val(3), Val(4)), Y_eo[_site]) +
spintrace(σμν_spin_mul(Y_eo[_site], Val(3), Val(4)), X_eo[_site])
Xμν[3i32, 4i32, site] = X₃₄
Xμν[4i32, 3i32, site] = -X₃₄
return nothing
end

function clear_Xμν_eo_kernel!(Xμν, site, ::Type{T}) where {T}
Xμν[1i32, 2i32, site] = zero3(T)
Xμν[2i32, 1i32, site] = zero3(T)

Xμν[1i32, 3i32, site] = zero3(T)
Xμν[3i32, 1i32, site] = zero3(T)

Xμν[1i32, 4i32, site] = zero3(T)
Xμν[4i32, 1i32, site] = zero3(T)

Xμν[2i32, 3i32, site] = zero3(T)
Xμν[3i32, 2i32, site] = zero3(T)

Xμν[2i32, 4i32, site] = zero3(T)
Xμν[4i32, 2i32, site] = zero3(T)

Xμν[3i32, 4i32, site] = zero3(T)
Xμν[4i32, 3i32, site] = zero3(T)
return nothing
end

0 comments on commit cb27448

Please sign in to comment.