Skip to content

Commit

Permalink
Experimental AMDGPU implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst committed May 18, 2023
1 parent 904a6e2 commit cf52a94
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 16 deletions.
14 changes: 14 additions & 0 deletions examples/amdgpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using DFTK
using AMDGPU

a = 10.26 # Silicon lattice constant in Bohr
lattice = a / 2 * [[0 1 1.];
[1 0 1.];
[1 1 0.]]
Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
atoms = [Si, Si]
positions = [ones(3)/8, -ones(3)/8]
model = model_PBE(lattice, atoms, positions)

basis = PlaneWaveBasis(model; Ecut=30, kgrid=(5, 5, 5), architecture=DFTK.GPU(AMDGPU.ROCArray)
scfres = self_consistent_field(basis; tol=1e-2, solver=scf_damping_solver())
File renamed without changes.
9 changes: 4 additions & 5 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,14 @@ function __init__()
@require DoubleFloats="497a8b3b-efae-58df-a0af-a86822472b78" begin
!isdefined(DFTK, :GENERIC_FFT_LOADED) && include("workarounds/fft_generic.jl")
end
@require Plots="91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plotting.jl")
@require JLD2="033835bb-8acc-5ee8-8aae-3f567f8a3819" include("external/jld2io.jl")
@require Plots="91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plotting.jl")
@require JLD2="033835bb-8acc-5ee8-8aae-3f567f8a3819" include("external/jld2io.jl")
@require WriteVTK="64499a7a-5c06-52f2-abe2-ccb03c286192" include("external/vtkio.jl")
@require wannier90_jll="c5400fa0-8d08-52c2-913f-1e3f656c1ce9" begin
include("external/wannier90.jl")
end
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
include("workarounds/cuda_arrays.jl")
end
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("workarounds/cuda_arrays.jl")
@require AMDGPU="21141c5a-9bdb-4563-92ae-f87d6854732e" include("workarounds/roc_arrays.jl")
end

# Precompilation block with a basic workflow
Expand Down
22 changes: 11 additions & 11 deletions src/terms/xc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio
max_ρ_derivs = maximum(max_required_derivative, term.functionals)
density = LibxcDensities(basis, max_ρ_derivs, ρ, τ)

# Evaluate terms and energy contribution (zk == energy per unit particle)
# It may happen that a functional does only provide a potenital and not an energy term
# Therefore skip_unsupported_derivatives=true to avoid an error.
# Evaluate terms and energy contribution
# If the XC functional is not supported for an architecture, terms is on the CPU
terms = potential_terms(term.functionals, density)
@assert haskey(terms, :Vρ) && haskey(terms, :e)
E = term.scaling_factor * sum(terms.e) * basis.dvol
Expand All @@ -94,14 +93,14 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio
# Potential contributions Vρ -2 ∇⋅(Vσ ∇ρ) + ΔVl
potential = zero(ρ)
@views for s in 1:n_spin
= reshape(terms.Vρ, n_spin, basis.fft_size...)
= to_device(basis.architecture, reshape(terms.Vρ, n_spin, basis.fft_size...))

potential[:, :, :, s] .+= Vρ[s, :, :, :]
if haskey(terms, :Vσ) && any(x -> abs(x) > potential_threshold, terms.Vσ)
# Need gradient correction
# TODO Drop do-block syntax here?
potential[:, :, :, s] .+= -2divergence_real(basis) do α
= reshape(terms.Vσ, :, basis.fft_size...)
= to_device(basis.architecture, reshape(terms.Vσ, :, basis.fft_size...))

# Extra factor (1/2) for s != t is needed because libxc only keeps σ_{αβ}
# in the energy expression. See comment block below on spin-polarised XC.
Expand All @@ -113,7 +112,7 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio
if haskey(terms, :Vl) && any(x -> abs(x) > potential_threshold, terms.Vl)
@warn "Meta-GGAs with a Δρ term have not yet been thoroughly tested." maxlog=1
mG² = .-norm2.(G_vectors_cart(basis))
Vl = reshape(terms.Vl, n_spin, basis.fft_size...)
Vl = to_device(basis.architecture, reshape(terms.Vl, n_spin, basis.fft_size...))
Vl_fourier = fft(basis, Vl[s, :, :, :])
potential[:, :, :, s] .+= irfft(basis, mG² .* Vl_fourier) # ΔVl
end
Expand All @@ -123,7 +122,7 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio
= nothing
if haskey(terms, :Vτ) && any(x -> abs(x) > potential_threshold, terms.Vτ)
# Need meta-GGA non-local operator (Note: -½ part of the definition of DivAgrid)
= reshape(terms.Vτ, n_spin, basis.fft_size...)
= to_device(basis.architecture, reshape(terms.Vτ, n_spin, basis.fft_size...))
= term.scaling_factor * permutedims(Vτ, (2, 3, 4, 1))
end

Expand Down Expand Up @@ -379,10 +378,11 @@ function apply_kernel(term::TermXc, basis::PlaneWaveBasis{T}, δρ; ρ, kwargs..
]
end

# If the XC functional is not supported for an architecture, terms is on the CPU
terms = kernel_terms(term.functionals, density)
δV = zero(ρ) # [ix, iy, iz, iσ]

Vρρ = reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...)
Vρρ = to_device(basis.architecture, reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...))
@views for s in 1:n_spin, t in 1:n_spin # LDA term
δV[:, :, :, s] .+= Vρρ[s, t, :, :, :] .* δρ[t, :, :, :]
end
Expand Down Expand Up @@ -422,9 +422,9 @@ function add_kernel_gradient_correction!(δV, terms, density, perturbation, cros
δρ = perturbation.ρ_real
∇δρ = perturbation.∇ρ_real
δσ = cross_derivatives[:δσ]
Vρσ = reshape(terms.Vρσ, n_spin, spin_σ, basis.fft_size...)
Vσσ = reshape(terms.Vσσ, spin_σ, spin_σ, basis.fft_size...)
= reshape(terms.Vσ, spin_σ, basis.fft_size...)
Vρσ = to_device(basis.architecture, reshape(terms.Vρσ, n_spin, spin_σ, basis.fft_size...))
Vσσ = to_device(basis.architecture, reshape(terms.Vσσ, spin_σ, spin_σ, basis.fft_size...))
= to_device(basis.architecture, reshape(terms.Vσ, spin_σ, basis.fft_size...))

T = eltype(ρ)
= DftFunctionals.spinindex_σ
Expand Down
10 changes: 10 additions & 0 deletions src/workarounds/gpu_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ function lowpass_for_symmetry!(ρ::AbstractGPUArray, basis; symmetries=basis.sym
ρ_CPU = lowpass_for_symmetry!(to_cpu(ρ), basis; symmetries)
ρ .= to_device(basis.architecture, ρ_CPU)
end

for fun in (:potential_terms, :kernel_terms)
@eval function DftFunctionals.$fun(fun::DispatchFunctional, ρ::AT,
args...) where {AT <: AbstractGPUArray{Float64}}
# Fallback implementation for the GPU: Transfer to the CPU and run computation there
cpuify(::Nothing) = nothing
cpuify(x::AbstractArray) = Array(x)
$fun(fun, Array(ρ), cpuify.(args)...)
end
end
1 change: 1 addition & 0 deletions src/workarounds/roc_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
synchronize_device(::GPU{<:AMDGPU.ROCArray}) = AMDGPU.Device.sync_workgroup()

0 comments on commit cf52a94

Please sign in to comment.