Skip to content

Commit

Permalink
fft stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed Sep 12, 2023
1 parent 79538c9 commit 15f2dee
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
29 changes: 26 additions & 3 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ function ifft!(f_real::AbstractArray3, basis::PlaneWaveBasis,
f_real
end
# For multicomponents
function ifft!(f_real::AbstractArray4, basis::PlaneWaveBasis, f_fourier::AbstractArray4)
function ifft!(f_real::AbstractArray4, basis::PlaneWaveBasis, f_fourier::AbstractMatrix)
mul!(f_real, basis.opBFFT_mc, f_fourier)
f_real .*= basis.ifft_normalization
end
function ifft!(f_real::AbstractArray4, basis::PlaneWaveBasis,
kpt::Kpoint, f_fourier::AbstractVector; normalize=true)
kpt::Kpoint, f_fourier::AbstractMatrix; normalize=true)
n_comps = basis.model.n_components
@assert length(f_fourier) == n_comps*length(kpt.mapping)
@assert size(f_real) == basis.fft_size
@assert size(f_real)[2:end] == basis.fft_size

# Pad the input data
fill!(f_real, 0)
Expand Down Expand Up @@ -115,6 +115,29 @@ function fft!(f_fourier::AbstractVector, basis::PlaneWaveBasis,
normalize && (f_fourier .*= basis.fft_normalization)
f_fourier
end
# For multicomponents
function fft!(f_fourier::AbstractArray4, basis::PlaneWaveBasis, f_real::AbstractArray4)
if eltype(f_real) <: Real
f_real = complex.(f_real)
end
mul!(f_fourier, basis.opFFT_mc, f_real)
f_fourier .*= basis.fft_normalization
end
function fft!(f_fourier::AbstractMatrix, basis::PlaneWaveBasis,
kpt::Kpoint, f_real::AbstractArray4; normalize=true)
@assert size(f_real)[2:end] == basis.fft_size
@assert length(f_fourier) == length(kpt.mapping)

# FFT
mul!(f_real, basis.ipFFT_mc, f_real)

# Truncate
for σ in 1:basis.model.n_components
f_fourier[σ, :] .= @view(f_real[σ, :, :, :][kpt.mapping])
end
normalize && (f_fourier .*= basis.fft_normalization)
f_fourier
end

"""
fft(basis::PlaneWaveBasis, [kpt::Kpoint, ] f_real)
Expand Down
7 changes: 4 additions & 3 deletions src/terms/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ function LinearAlgebra.mul!(Hψ::AbstractVector, op::RealFourierOperator, ψ::Ab
Hψ .= Hψ_fourier .+ fft(op.basis, op.kpoint, Hψ_real)
end
# TODO: remove this hack
function LinearAlgebra.mul!(Hψ::AbstractArray3, op::RealFourierOperator, ψ::AbstractArray3)
@views for i = 1:size(ψ, 3)
@views for nc in 1:size(ψ, 1)
mul!(Hψ[nc, :, i], op, ψ[nc, :, i])
@views for σ in 1:size(ψ, 1)
mul!(Hψ[σ, :, i], op, ψ[σ, :, i])
end
end
Expand Down Expand Up @@ -137,7 +138,7 @@ function apply!(Hψ, op::NonlocalOperator, ψ)
fix_that_apply!(Hψ.fourier, op, ψ.fourier)
.fourier
end
# TODO
# TODO self-explainatory
function fix_that_apply!(Hψ, op::NonlocalOperator, ψ::AbstractVecOrMat)
.+= op.P * (op.D * (op.P' * ψ))
Expand Down

0 comments on commit 15f2dee

Please sign in to comment.