diff --git a/examples/amdgpu.jl b/examples/amdgpu.jl new file mode 100644 index 0000000000..20161682e9 --- /dev/null +++ b/examples/amdgpu.jl @@ -0,0 +1,14 @@ +using DFTK +using AMDGPU + +a = 10.26 # Silicon lattice constant in Bohr +lattice = a / 2 * [[0 1 1.]; + [1 0 1.]; + [1 1 0.]] +Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4")) +atoms = [Si, Si] +positions = [ones(3)/8, -ones(3)/8] +model = model_PBE(lattice, atoms, positions) + +basis = PlaneWaveBasis(model; Ecut=30, kgrid=(5, 5, 5), architecture=DFTK.GPU(AMDGPU.ROCArray) +scfres = self_consistent_field(basis; tol=1e-2, solver=scf_damping_solver()) diff --git a/examples/gpu.jl b/examples/cuda.jl similarity index 100% rename from examples/gpu.jl rename to examples/cuda.jl diff --git a/src/DFTK.jl b/src/DFTK.jl index d1c4c048d4..211534c70b 100644 --- a/src/DFTK.jl +++ b/src/DFTK.jl @@ -230,15 +230,14 @@ function __init__() @require DoubleFloats="497a8b3b-efae-58df-a0af-a86822472b78" begin !isdefined(DFTK, :GENERIC_FFT_LOADED) && include("workarounds/fft_generic.jl") end - @require Plots="91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plotting.jl") - @require JLD2="033835bb-8acc-5ee8-8aae-3f567f8a3819" include("external/jld2io.jl") + @require Plots="91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plotting.jl") + @require JLD2="033835bb-8acc-5ee8-8aae-3f567f8a3819" include("external/jld2io.jl") @require WriteVTK="64499a7a-5c06-52f2-abe2-ccb03c286192" include("external/vtkio.jl") @require wannier90_jll="c5400fa0-8d08-52c2-913f-1e3f656c1ce9" begin include("external/wannier90.jl") end - @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin - include("workarounds/cuda_arrays.jl") - end + @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("workarounds/cuda_arrays.jl") + @require AMDGPU="21141c5a-9bdb-4563-92ae-f87d6854732e" include("workarounds/roc_arrays.jl") end # Precompilation block with a basic workflow diff --git a/src/terms/xc.jl b/src/terms/xc.jl index 06a9b2b92c..249e3c8fc0 100644 --- a/src/terms/xc.jl +++ b/src/terms/xc.jl @@ -79,9 +79,8 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio max_ρ_derivs = maximum(max_required_derivative, term.functionals) density = LibxcDensities(basis, max_ρ_derivs, ρ, τ) - # Evaluate terms and energy contribution (zk == energy per unit particle) - # It may happen that a functional does only provide a potenital and not an energy term - # Therefore skip_unsupported_derivatives=true to avoid an error. + # Evaluate terms and energy contribution + # If the XC functional is not supported for an architecture, terms is on the CPU terms = potential_terms(term.functionals, density) @assert haskey(terms, :Vρ) && haskey(terms, :e) E = term.scaling_factor * sum(terms.e) * basis.dvol @@ -94,14 +93,14 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio # Potential contributions Vρ -2 ∇⋅(Vσ ∇ρ) + ΔVl potential = zero(ρ) @views for s in 1:n_spin - Vρ = reshape(terms.Vρ, n_spin, basis.fft_size...) + Vρ = to_device(basis.architecture, reshape(terms.Vρ, n_spin, basis.fft_size...)) potential[:, :, :, s] .+= Vρ[s, :, :, :] if haskey(terms, :Vσ) && any(x -> abs(x) > potential_threshold, terms.Vσ) # Need gradient correction # TODO Drop do-block syntax here? potential[:, :, :, s] .+= -2divergence_real(basis) do α - Vσ = reshape(terms.Vσ, :, basis.fft_size...) + Vσ = to_device(basis.architecture, reshape(terms.Vσ, :, basis.fft_size...)) # Extra factor (1/2) for s != t is needed because libxc only keeps σ_{αβ} # in the energy expression. See comment block below on spin-polarised XC. @@ -113,7 +112,7 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio if haskey(terms, :Vl) && any(x -> abs(x) > potential_threshold, terms.Vl) @warn "Meta-GGAs with a Δρ term have not yet been thoroughly tested." maxlog=1 mG² = .-norm2.(G_vectors_cart(basis)) - Vl = reshape(terms.Vl, n_spin, basis.fft_size...) + Vl = to_device(basis.architecture, reshape(terms.Vl, n_spin, basis.fft_size...)) Vl_fourier = fft(basis, Vl[s, :, :, :]) potential[:, :, :, s] .+= irfft(basis, mG² .* Vl_fourier) # ΔVl end @@ -123,7 +122,7 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio Vτ = nothing if haskey(terms, :Vτ) && any(x -> abs(x) > potential_threshold, terms.Vτ) # Need meta-GGA non-local operator (Note: -½ part of the definition of DivAgrid) - Vτ = reshape(terms.Vτ, n_spin, basis.fft_size...) + Vτ = to_device(basis.architecture, reshape(terms.Vτ, n_spin, basis.fft_size...)) Vτ = term.scaling_factor * permutedims(Vτ, (2, 3, 4, 1)) end @@ -379,10 +378,11 @@ function apply_kernel(term::TermXc, basis::PlaneWaveBasis{T}, δρ; ρ, kwargs.. ] end + # If the XC functional is not supported for an architecture, terms is on the CPU terms = kernel_terms(term.functionals, density) δV = zero(ρ) # [ix, iy, iz, iσ] - Vρρ = reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...) + Vρρ = to_device(basis.architecture, reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...)) @views for s in 1:n_spin, t in 1:n_spin # LDA term δV[:, :, :, s] .+= Vρρ[s, t, :, :, :] .* δρ[t, :, :, :] end @@ -422,9 +422,9 @@ function add_kernel_gradient_correction!(δV, terms, density, perturbation, cros δρ = perturbation.ρ_real ∇δρ = perturbation.∇ρ_real δσ = cross_derivatives[:δσ] - Vρσ = reshape(terms.Vρσ, n_spin, spin_σ, basis.fft_size...) - Vσσ = reshape(terms.Vσσ, spin_σ, spin_σ, basis.fft_size...) - Vσ = reshape(terms.Vσ, spin_σ, basis.fft_size...) + Vρσ = to_device(basis.architecture, reshape(terms.Vρσ, n_spin, spin_σ, basis.fft_size...)) + Vσσ = to_device(basis.architecture, reshape(terms.Vσσ, spin_σ, spin_σ, basis.fft_size...)) + Vσ = to_device(basis.architecture, reshape(terms.Vσ, spin_σ, basis.fft_size...)) T = eltype(ρ) tσ = DftFunctionals.spinindex_σ diff --git a/src/workarounds/gpu_arrays.jl b/src/workarounds/gpu_arrays.jl index 2bffa1f46a..ba32c8e2fa 100644 --- a/src/workarounds/gpu_arrays.jl +++ b/src/workarounds/gpu_arrays.jl @@ -11,3 +11,13 @@ function lowpass_for_symmetry!(ρ::AbstractGPUArray, basis; symmetries=basis.sym ρ_CPU = lowpass_for_symmetry!(to_cpu(ρ), basis; symmetries) ρ .= to_device(basis.architecture, ρ_CPU) end + +for fun in (:potential_terms, :kernel_terms) + @eval function DftFunctionals.$fun(fun::DispatchFunctional, ρ::AT, + args...) where {AT <: AbstractGPUArray{Float64}} + # Fallback implementation for the GPU: Transfer to the CPU and run computation there + cpuify(::Nothing) = nothing + cpuify(x::AbstractArray) = Array(x) + $fun(fun, Array(ρ), cpuify.(args)...) + end +end diff --git a/src/workarounds/roc_arrays.jl b/src/workarounds/roc_arrays.jl new file mode 100644 index 0000000000..4c34875c7d --- /dev/null +++ b/src/workarounds/roc_arrays.jl @@ -0,0 +1 @@ +synchronize_device(::GPU{<:AMDGPU.ROCArray}) = AMDGPU.Device.sync_workgroup()