From 1dec4d72457b3968ed14b4a4adacfeb3e7d62f97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Polack?= Date: Tue, 12 Sep 2023 15:05:12 +0200 Subject: [PATCH] pseudo-fix for plans --- src/PlaneWaveBasis.jl | 13 +++++++++++-- src/fft.jl | 6 +++--- src/terms/Hamiltonian.jl | 7 ++++--- test/silicon_redHF.jl | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/PlaneWaveBasis.jl b/src/PlaneWaveBasis.jl index 45400dc1dc..dd839a09d2 100644 --- a/src/PlaneWaveBasis.jl +++ b/src/PlaneWaveBasis.jl @@ -211,8 +211,17 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational, # Setup FFT plans Gs = to_device(architecture, G_vectors(fft_size)) (ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans!(similar(Gs, Complex{T}, fft_size)) - (ipFFT_mc, opFFT_mc, ipBFFT_mc, opBFFT_mc) = build_fft_plans!(similar(Gs, Complex{T}, model.n_components, fft_size...), - [2,3,4]) + # TODO: fix this edge case + if T <: Union{Float32, Float64} + @assert model.n_components == 1 + (ipFFT_mc, opFFT_mc, ipBFFT_mc, opBFFT_mc) = build_fft_plans!(similar(Gs, Complex{T}, model.n_components, fft_size...), + [2,3,4]) + else + ipFFT_mc = nothing + opFFT_mc = nothing + ipBFFT_mc = nothing + opBFFT_mc = nothing + end # Normalization constants # fft = fft_normalization * FFT diff --git a/src/fft.jl b/src/fft.jl index 64d9abe34e..aaca333bf5 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -279,13 +279,13 @@ end Plan a FFT of type `T` and size `fft_size`, spending some time on finding an optimal algorithm. (Inplace, out-of-place) x (forward, backward) FFT plans are returned. """ -function build_fft_plans!(tmp::Array{Complex{Float64}}, dims=1:ndims(tmp)) +function build_fft_plans!(tmp::Array{Complex{Float64}}, dims::AbstractVector{Int}=1:ndims(tmp)) ipFFT = FFTW.plan_fft!(tmp, dims; flags=FFTW.MEASURE) opFFT = FFTW.plan_fft(tmp, dims; flags=FFTW.MEASURE) # backwards-FFT by inverting and stripping off normalizations ipFFT, opFFT, inv(ipFFT).p, inv(opFFT).p end -function build_fft_plans!(tmp::Array{Complex{Float32}}, dims=1:ndims(tmp)) +function build_fft_plans!(tmp::Array{Complex{Float32}}, dims::AbstractVector{Int}=1:ndims(tmp)) # For Float32 there are issues with aligned FFTW plans, so we # fall back to unaligned FFTW plans (which are generally discouraged). ipFFT = FFTW.plan_fft!(tmp, dims; flags=FFTW.MEASURE | FFTW.UNALIGNED) @@ -294,7 +294,7 @@ function build_fft_plans!(tmp::Array{Complex{Float32}}, dims=1:ndims(tmp)) ipFFT, opFFT, inv(ipFFT).p, inv(opFFT).p end function build_fft_plans!(tmp::AbstractArray{Complex{T}}, - dims=1:ndims(tmp)) where {T<:Union{Float32,Float64}} + dims::AbstractVector{Int}=1:ndims(tmp)) where {T<:Union{Float32,Float64}} ipFFT = AbstractFFTs.plan_fft!(tmp, dims) opFFT = AbstractFFTs.plan_fft(tmp, dims) # backwards-FFT by inverting and stripping off normalizations diff --git a/src/terms/Hamiltonian.jl b/src/terms/Hamiltonian.jl index b24fcf36b7..ecaf862cd8 100644 --- a/src/terms/Hamiltonian.jl +++ b/src/terms/Hamiltonian.jl @@ -93,7 +93,7 @@ Base.:*(H::Hamiltonian, ψ::AbstractArray) = mul!(deepcopy(ψ), H, ψ) H::GenericHamiltonianBlock, ψ::AbstractArray) T = eltype(H.basis) - n_bands = size(ψ, 3) + n_bands = size(ψ, 2) Hψ_fourier = similar(Hψ[:, 1]) ψ_real = similar(ψ, complex(T), H.basis.fft_size...) Hψ_real = similar(Hψ, complex(T), H.basis.fft_size...) @@ -146,10 +146,11 @@ end end if have_divAgrad + @assert H.basis.model.n_components == 1 @timeit to "divAgrad" begin - apply!((fourier=Hψ[:, iband], real=nothing), + apply!((fourier=Hψ[1, :, iband], real=nothing), H.divAgrad_op, - (fourier=ψ[:, iband], real=nothing), + (fourier=ψ[1, :, iband], real=nothing), ψ_real) # ψ_real used as scratch end end diff --git a/test/silicon_redHF.jl b/test/silicon_redHF.jl index 06a4d70219..1df0bc4ffa 100644 --- a/test/silicon_redHF.jl +++ b/test/silicon_redHF.jl @@ -25,7 +25,7 @@ function run_silicon_redHF(T; Ecut=5, grid_size=15, spin_polarization=:none, kwa ref_etot = -5.440593269861395 fft_size = fill(grid_size, 3) - fft_size = DFTK.next_working_fft_size(T, fft_size) # ad-hoc fix for buggy generic FFTs + fft_size = DFTK.next_working_fft_size(T, fft_size) # ad-hoc fix for buggy generic FFTs Si = ElementPsp(silicon.atnum, psp=load_psp("hgh/lda/si-q4")) atoms = [Si, Si]