Skip to content

Commit

Permalink
pseudo-fix for plans
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed Sep 12, 2023
1 parent c2c8fb6 commit 1dec4d7
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
13 changes: 11 additions & 2 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,17 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
# Setup FFT plans
Gs = to_device(architecture, G_vectors(fft_size))
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans!(similar(Gs, Complex{T}, fft_size))
(ipFFT_mc, opFFT_mc, ipBFFT_mc, opBFFT_mc) = build_fft_plans!(similar(Gs, Complex{T}, model.n_components, fft_size...),
[2,3,4])
# TODO: fix this edge case
if T <: Union{Float32, Float64}
@assert model.n_components == 1
(ipFFT_mc, opFFT_mc, ipBFFT_mc, opBFFT_mc) = build_fft_plans!(similar(Gs, Complex{T}, model.n_components, fft_size...),
[2,3,4])
else
ipFFT_mc = nothing
opFFT_mc = nothing
ipBFFT_mc = nothing
opBFFT_mc = nothing
end

# Normalization constants
# fft = fft_normalization * FFT
Expand Down
6 changes: 3 additions & 3 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ end
Plan a FFT of type `T` and size `fft_size`, spending some time on finding an
optimal algorithm. (Inplace, out-of-place) x (forward, backward) FFT plans are returned.
"""
function build_fft_plans!(tmp::Array{Complex{Float64}}, dims=1:ndims(tmp))
function build_fft_plans!(tmp::Array{Complex{Float64}}, dims::AbstractVector{Int}=1:ndims(tmp))
ipFFT = FFTW.plan_fft!(tmp, dims; flags=FFTW.MEASURE)
opFFT = FFTW.plan_fft(tmp, dims; flags=FFTW.MEASURE)
# backwards-FFT by inverting and stripping off normalizations
ipFFT, opFFT, inv(ipFFT).p, inv(opFFT).p
end
function build_fft_plans!(tmp::Array{Complex{Float32}}, dims=1:ndims(tmp))
function build_fft_plans!(tmp::Array{Complex{Float32}}, dims::AbstractVector{Int}=1:ndims(tmp))
# For Float32 there are issues with aligned FFTW plans, so we
# fall back to unaligned FFTW plans (which are generally discouraged).
ipFFT = FFTW.plan_fft!(tmp, dims; flags=FFTW.MEASURE | FFTW.UNALIGNED)
Expand All @@ -294,7 +294,7 @@ function build_fft_plans!(tmp::Array{Complex{Float32}}, dims=1:ndims(tmp))
ipFFT, opFFT, inv(ipFFT).p, inv(opFFT).p
end
function build_fft_plans!(tmp::AbstractArray{Complex{T}},
dims=1:ndims(tmp)) where {T<:Union{Float32,Float64}}
dims::AbstractVector{Int}=1:ndims(tmp)) where {T<:Union{Float32,Float64}}
ipFFT = AbstractFFTs.plan_fft!(tmp, dims)
opFFT = AbstractFFTs.plan_fft(tmp, dims)
# backwards-FFT by inverting and stripping off normalizations
Expand Down
7 changes: 4 additions & 3 deletions src/terms/Hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Base.:*(H::Hamiltonian, ψ::AbstractArray) = mul!(deepcopy(ψ), H, ψ)
H::GenericHamiltonianBlock,
ψ::AbstractArray)
T = eltype(H.basis)
n_bands = size(ψ, 3)
n_bands = size(ψ, 2)
Hψ_fourier = similar(Hψ[:, 1])
ψ_real = similar(ψ, complex(T), H.basis.fft_size...)
Hψ_real = similar(Hψ, complex(T), H.basis.fft_size...)
Expand Down Expand Up @@ -146,10 +146,11 @@ end
end

if have_divAgrad
@assert H.basis.model.n_components == 1
@timeit to "divAgrad" begin
apply!((fourier=Hψ[:, iband], real=nothing),
apply!((fourier=Hψ[1, :, iband], real=nothing),
H.divAgrad_op,
(fourier=ψ[:, iband], real=nothing),
(fourier=ψ[1, :, iband], real=nothing),
ψ_real) # ψ_real used as scratch
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/silicon_redHF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function run_silicon_redHF(T; Ecut=5, grid_size=15, spin_polarization=:none, kwa
ref_etot = -5.440593269861395

fft_size = fill(grid_size, 3)
fft_size = DFTK.next_working_fft_size(T, fft_size) # ad-hoc fix for buggy generic FFTs
fft_size = DFTK.next_working_fft_size(T, fft_size) # ad-hoc fix for buggy generic FFTs
Si = ElementPsp(silicon.atnum, psp=load_psp("hgh/lda/si-q4"))
atoms = [Si, Si]

Expand Down

0 comments on commit 1dec4d7

Please sign in to comment.