Skip to content

Commit

Permalink
pseudo-fix for plans
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed Sep 12, 2023
1 parent c2c8fb6 commit 79538c9
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 12 deletions.
13 changes: 11 additions & 2 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,17 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
# Setup FFT plans
Gs = to_device(architecture, G_vectors(fft_size))
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans!(similar(Gs, Complex{T}, fft_size))
(ipFFT_mc, opFFT_mc, ipBFFT_mc, opBFFT_mc) = build_fft_plans!(similar(Gs, Complex{T}, model.n_components, fft_size...),
[2,3,4])
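# TODO: fix this edge case: multi-component FFT plans are only built for
# Float32/Float64; for any other element type the plans are set to `nothing`
# and a single component (n_components == 1) is required.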
if T <: Union{Float32, Float64}
(ipFFT_mc, opFFT_mc, ipBFFT_mc, opBFFT_mc) = build_fft_plans!(similar(Gs, Complex{T}, model.n_components, fft_size...),
[2,3,4])
else
@assert model.n_components == 1
ipFFT_mc = nothing
opFFT_mc = nothing
ipBFFT_mc = nothing
opBFFT_mc = nothing
end

# Normalization constants
# fft = fft_normalization * FFT
Expand Down
6 changes: 3 additions & 3 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ end
Plan an FFT of type `T` and size `fft_size`, spending some time on finding an
optimal algorithm. (In-place, out-of-place) x (forward, backward) FFT plans are returned.
"""
function build_fft_plans!(tmp::Array{Complex{Float64}}, dims=1:ndims(tmp))
function build_fft_plans!(tmp::Array{Complex{Float64}}, dims::AbstractVector{Int}=1:ndims(tmp))
ipFFT = FFTW.plan_fft!(tmp, dims; flags=FFTW.MEASURE)
opFFT = FFTW.plan_fft(tmp, dims; flags=FFTW.MEASURE)
# backwards-FFT by inverting and stripping off normalizations
ipFFT, opFFT, inv(ipFFT).p, inv(opFFT).p
end
function build_fft_plans!(tmp::Array{Complex{Float32}}, dims=1:ndims(tmp))
function build_fft_plans!(tmp::Array{Complex{Float32}}, dims::AbstractVector{Int}=1:ndims(tmp))
# For Float32 there are issues with aligned FFTW plans, so we
# fall back to unaligned FFTW plans (which are generally discouraged).
ipFFT = FFTW.plan_fft!(tmp, dims; flags=FFTW.MEASURE | FFTW.UNALIGNED)
Expand All @@ -294,7 +294,7 @@ function build_fft_plans!(tmp::Array{Complex{Float32}}, dims=1:ndims(tmp))
ipFFT, opFFT, inv(ipFFT).p, inv(opFFT).p
end
function build_fft_plans!(tmp::AbstractArray{Complex{T}},
dims=1:ndims(tmp)) where {T<:Union{Float32,Float64}}
dims::AbstractVector{Int}=1:ndims(tmp)) where {T<:Union{Float32,Float64}}
ipFFT = AbstractFFTs.plan_fft!(tmp, dims)
opFFT = AbstractFFTs.plan_fft(tmp, dims)
# backwards-FFT by inverting and stripping off normalizations
Expand Down
8 changes: 5 additions & 3 deletions src/terms/Hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ Base.:*(H::Hamiltonian, ψ::AbstractArray) = mul!(deepcopy(ψ), H, ψ)

# Loop through bands, IFFT to get ψ in real space, loop through terms, FFT and accumulate into Hψ
# For the common DftHamiltonianBlock there is an optimized version below
# TODO: account for components in this generic method (ψ and Hψ are indexed here
# without a component dimension)
@views @timing "Hamiltonian multiplication" function LinearAlgebra.mul!(Hψ::AbstractArray,
H::GenericHamiltonianBlock,
ψ::AbstractArray)
T = eltype(H.basis)
n_bands = size(ψ, 3)
n_bands = size(ψ, 2)
Hψ_fourier = similar(Hψ[:, 1])
ψ_real = similar(ψ, complex(T), H.basis.fft_size...)
Hψ_real = similar(Hψ, complex(T), H.basis.fft_size...)
Expand Down Expand Up @@ -146,10 +147,11 @@ end
end

if have_divAgrad
@assert H.basis.model.n_components == 1
@timeit to "divAgrad" begin
apply!((fourier=Hψ[:, iband], real=nothing),
apply!((fourier=Hψ[1, :, iband], real=nothing),
H.divAgrad_op,
(fourier=ψ[:, iband], real=nothing),
(fourier=ψ[1, :, iband], real=nothing),
ψ_real) # ψ_real used as scratch
end
end
Expand Down
13 changes: 11 additions & 2 deletions src/terms/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,19 @@ struct NonlocalOperator{T <: Real, PT, DT} <: RealFourierOperator
D::DT
end
function apply!(Hψ, op::NonlocalOperator, ψ)
fix_that_apply!(Hψ.fourier, op, ψ.fourier)
.fourier
end
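# TODO: presumably a temporary helper; it dispatches on the rank of ψ
# (vector/matrix vs. 3-index array). Replace with a proper implementation.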
function fix_that_apply!(Hψ, op::NonlocalOperator, ψ::AbstractVecOrMat)
.+= op.P * (op.D * (op.P' * ψ))
end
function fix_that_apply!(Hψ, op::NonlocalOperator, ψ::AbstractArray3)
for σ in 1:op.basis.model.n_components
.fourier[σ, :, :] .+= op.P * (op.D * (op.P' * ψ.fourier[σ, :, :]))
Hψ[σ, :, :] .+= op.P * (op.D * (op.P' * ψ[σ, :, :]))
end
.fourier
end
function tensor(op::NonlocalOperator)
n_Gk = length(G_vectors(op.basis, op.kpoint))
Expand Down
2 changes: 1 addition & 1 deletion test/multicomponents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using DFTK

include("testcases.jl")

@testset "Consistency" begin
@testset "Multicomponents consistency" begin
kgrid = (4,4,4)
Ecut = 5.0
tol = 1e-5
Expand Down
2 changes: 1 addition & 1 deletion test/silicon_redHF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function run_silicon_redHF(T; Ecut=5, grid_size=15, spin_polarization=:none, kwa
ref_etot = -5.440593269861395

fft_size = fill(grid_size, 3)
fft_size = DFTK.next_working_fft_size(T, fft_size) # ad-hoc fix for buggy generic FFTs
fft_size = DFTK.next_working_fft_size(T, fft_size) # ad-hoc fix for buggy generic FFTs
Si = ElementPsp(silicon.atnum, psp=load_psp("hgh/lda/si-q4"))
atoms = [Si, Si]

Expand Down

0 comments on commit 79538c9

Please sign in to comment.