diff --git a/src/fft.jl b/src/fft.jl index 1615c9df9..c8c868c6c 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -54,7 +54,10 @@ We define the FFTGrid struct, containing all the data required to perform FFTs. Note that the FFT plans are not normalized. Normalization takes place explicitely when the fft()/ifft() functions are called """ -struct FFTGrid{T, VT <: Real} +struct FFTGrid{T, + VT <: Real, + T_G_vectors <: AbstractArray{Vec3{Int}, 3}, + T_r_vectors <: AbstractArray{Vec3{VT}, 3}} fft_size::Tuple{Int, Int, Int} opFFT @@ -64,8 +67,8 @@ struct FFTGrid{T, VT <: Real} fft_normalization::T ifft_normalization::T - G_vectors::Array{Vec3{Int}, 3} - r_vectors::Array{Vec3{VT}, 3} + G_vectors::T_G_vectors + r_vectors::T_r_vectors architecture::AbstractArchitecture end @@ -89,8 +92,9 @@ function FFTGrid(fft_size::Tuple{Int, Int, Int}, unit_cell_volume::T, for idx in CartesianIndices(fft_size)] r_vectors = to_device(arch, r_vectors) - FFTGrid{T, VT}(fft_size, opFFT, ipFFT, opBFFT, ipBFFT, fft_normalization, - ifft_normalization, Gs, r_vectors, arch) + FFTGrid{T, VT, typeof(Gs), typeof(r_vectors)}(fft_size, opFFT, ipFFT, opBFFT, ipBFFT, + fft_normalization, ifft_normalization, + Gs, r_vectors, arch) end G_vectors(fft_grid::FFTGrid) = fft_grid.G_vectors