Skip to content

Commit

Permalink
Fix #1026: Allow GPU arrays in FFTGrid (#1029)
Browse files Browse the repository at this point in the history
  • Loading branch information
Technici4n authored Nov 28, 2024
1 parent c52f525 commit 6cb2c19
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ We define the FFTGrid struct, containing all the data required to perform FFTs.
Note that the FFT plans are not normalized. Normalization takes place explicitely
when the fft()/ifft() functions are called
"""
struct FFTGrid{T, VT <: Real}
struct FFTGrid{T,
VT <: Real,
T_G_vectors <: AbstractArray{Vec3{Int}, 3},
T_r_vectors <: AbstractArray{Vec3{VT}, 3}}
fft_size::Tuple{Int, Int, Int}

opFFT
Expand All @@ -64,8 +67,8 @@ struct FFTGrid{T, VT <: Real}
fft_normalization::T
ifft_normalization::T

G_vectors::Array{Vec3{Int}, 3}
r_vectors::Array{Vec3{VT}, 3}
G_vectors::T_G_vectors
r_vectors::T_r_vectors

architecture::AbstractArchitecture
end
Expand All @@ -89,8 +92,9 @@ function FFTGrid(fft_size::Tuple{Int, Int, Int}, unit_cell_volume::T,
for idx in CartesianIndices(fft_size)]
r_vectors = to_device(arch, r_vectors)

FFTGrid{T, VT}(fft_size, opFFT, ipFFT, opBFFT, ipBFFT, fft_normalization,
ifft_normalization, Gs, r_vectors, arch)
FFTGrid{T, VT, typeof(Gs), typeof(r_vectors)}(fft_size, opFFT, ipFFT, opBFFT, ipBFFT,
fft_normalization, ifft_normalization,
Gs, r_vectors, arch)
end

G_vectors(fft_grid::FFTGrid) = fft_grid.G_vectors
Expand Down

0 comments on commit 6cb2c19

Please sign in to comment.