Skip to content

Commit

Permalink
Fix CUDA cache issue
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Nov 10, 2024
1 parent ec494bf commit f9e4dd1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
29 changes: 22 additions & 7 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ struct Plan
VendorPlan plan2{};
void* pf = nullptr;
void* pb = nullptr;
#ifdef AMREX_USE_CUDA
std::size_t work_size = 0;
#endif

#ifdef AMREX_USE_GPU
void set_ptrs (void* p0, void* p1) {
Expand Down Expand Up @@ -206,6 +203,7 @@ struct Plan

AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
std::size_t work_size;
if constexpr (D == Direction::forward) {
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
AMREX_CUFFT_SAFE_CALL
Expand All @@ -215,7 +213,6 @@ struct Plan
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
}
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

#elif defined(AMREX_USE_HIP)

Expand Down Expand Up @@ -314,9 +311,9 @@ struct Plan
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));

cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
std::size_t work_size;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

#elif defined(AMREX_USE_HIP)

Expand Down Expand Up @@ -475,9 +472,9 @@ struct Plan
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
std::size_t work_size;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

#elif defined(AMREX_USE_HIP)

Expand Down Expand Up @@ -585,8 +582,14 @@ struct Plan
auto* po = (TO*)((D == Direction::forward) ? pb : pf);

#if defined(AMREX_USE_CUDA)
// Bind the plan to the stream returned by Gpu::gpuStream() at execution time.
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

// Ask cuFFT how many bytes of work space this plan needs.
std::size_t work_size = 0;
AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

// Obtain the work space from The_Arena() and hand it to the plan before
// any cufftExec* call is issued.
// NOTE(review): the matching release of work_area is not visible in this
// view -- confirm it is freed once the transform has been launched.
auto* work_area = The_Arena()->alloc(work_size);
AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

if constexpr (D == Direction::forward) {
if constexpr (std::is_same_v<float,T>) {
AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po));
Expand Down Expand Up @@ -625,8 +628,14 @@ struct Plan
auto* p = (VendorComplex*)pf;

#if defined(AMREX_USE_CUDA)
// Bind the plan to the stream returned by Gpu::gpuStream() at execution time.
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

// Ask cuFFT how many bytes of work space this plan needs.
std::size_t work_size = 0;
AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

// Obtain the work space from The_Arena() and hand it to the plan before
// any cufftExec* call is issued.
// NOTE(review): the matching release of work_area is not visible in this
// view -- confirm it is freed once the transform has been launched.
auto* work_area = The_Arena()->alloc(work_size);
AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
if constexpr (std::is_same_v<float,T>) {
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir));
Expand Down Expand Up @@ -1061,8 +1070,14 @@ struct Plan

#if defined(AMREX_USE_CUDA)

// Bind the plan to the stream returned by Gpu::gpuStream() at execution time.
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

// Ask cuFFT how many bytes of work space this plan needs.
std::size_t work_size = 0;
AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size));

// Obtain the work space from The_Arena() and hand it to the plan before
// any cufftExec* call is issued.
// NOTE(review): the matching release of work_area is not visible in this
// view -- confirm it is freed once the transform has been launched.
auto* work_area = The_Arena()->alloc(work_size);
AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area));

if constexpr (std::is_same_v<float,T>) {
AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch));
} else {
Expand Down Expand Up @@ -1165,6 +1180,7 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* buffer, bool cache)
} else {
type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
}
std::size_t work_size;
if constexpr (M == 1) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
Expand All @@ -1175,7 +1191,6 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* buffer, bool cache)
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
}
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));

#elif defined(AMREX_USE_HIP)

Expand Down
46 changes: 36 additions & 10 deletions Src/FFT/AMReX_FFT_LocalR2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private:

void* m_buffer = nullptr; // for fftw

#if defined(AMREX_USE_GPU)
#if defined(AMREX_USE_SYCL)
gpuStream_t m_gpu_stream;
#endif

Expand All @@ -106,16 +106,20 @@ LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, bool cache_plan)
BL_PROFILE("FFT::LocalR2C");
m_spectral_size[0] = m_real_size[0]/2 + 1;

#if defined(AMREX_USE_GPU)
#if defined(AMREX_USE_SYCL)

auto current_stream = Gpu::gpuStream();
Gpu::Device::resetStreamIndex();
m_gpu_stream = Gpu::gpuStream();
#else

#elif !defined(AMREX_USE_GPU)

Long total_size = 1;
for (auto s : m_spectral_size) {
total_size *= s;
}
m_buffer = The_Arena()->alloc(sizeof(GpuComplex<T>)*total_size);

#endif

#ifdef AMREX_USE_SYCL
Expand All @@ -130,7 +134,7 @@ LocalR2C<T,D,M>::LocalR2C (IntVectND<M> const& fft_size, bool cache_plan)
}
#endif

#if defined(AMREX_USE_GPU)
#if defined(AMREX_USE_SYCL)
Gpu::Device::setStream(current_stream);
#endif
}
Expand Down Expand Up @@ -158,13 +162,19 @@ void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
BL_PROFILE("FFT::LocalR2C::forward");

#if defined(AMREX_USE_GPU)

m_fft_fwd.set_ptrs((void*)indata, (void*)outdata);

#if defined(AMREX_USE_SYCL)
auto current_stream = Gpu::gpuStream();
if (current_stream != m_gpu_stream) {
Gpu::streamSynchronize();
Gpu::Device::setStream(m_gpu_stream);
}
#else
#endif

#else /* FFTW */

int ny = 1;
for (int idim = 1; idim < M; ++idim) {
ny *= m_real_size[idim];
Expand All @@ -176,18 +186,23 @@ void LocalR2C<T,D,M>::forward (T const* indata, GpuComplex<T>* outdata)
pdst += sizeof(GpuComplex<T>) * m_spectral_size[0];
psrc += sizeof( T ) * m_real_size[0];
}

#endif

m_fft_fwd.template compute_r2c<Direction::forward>();

#if defined(AMREX_USE_GPU)
#if defined(AMREX_USE_SYCL)

if (current_stream != m_gpu_stream) {
Gpu::Device::setStream(current_stream);
}
#else

#elif !defined(AMREX_USE_GPU)

std::size_t nbytes = sizeof(GpuComplex<T>);
for (auto s : m_spectral_size) { nbytes *= s; }
std::memcpy(outdata, m_buffer, nbytes);

#endif
}

Expand All @@ -199,25 +214,35 @@ void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
BL_PROFILE("FFT::LocalR2C::backward");

#if defined(AMREX_USE_GPU)

m_fft_bwd.set_ptrs((void*)outdata, (void*)indata);

#if defined(AMREX_USE_SYCL)
auto current_stream = Gpu::gpuStream();
if (current_stream != m_gpu_stream) {
Gpu::streamSynchronize();
Gpu::Device::setStream(m_gpu_stream);
}
#else
#endif

#else /* FFTW */

std::size_t nbytes = sizeof(GpuComplex<T>);
for (auto s : m_spectral_size) { nbytes *= s; }
std::memcpy(m_buffer, indata, nbytes);

#endif

m_fft_bwd.template compute_r2c<Direction::backward>();

#if defined(AMREX_USE_GPU)
#if defined(AMREX_USE_SYCL)

if (current_stream != m_gpu_stream) {
Gpu::Device::setStream(current_stream);
}
#else

#elif !defined(AMREX_USE_GPU)

int ny = 1;
for (int idim = 1; idim < M; ++idim) {
ny *= m_real_size[idim];
Expand All @@ -229,6 +254,7 @@ void LocalR2C<T,D,M>::backward (GpuComplex<T> const* indata, T* outdata)
pdst += sizeof( T ) * m_real_size[0];
psrc += sizeof(GpuComplex<T>) * m_spectral_size[0];
}

#endif
}

Expand Down

0 comments on commit f9e4dd1

Please sign in to comment.