Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Commit

Permalink
Merge pull request #44 from mhaseeb123/fft-gpu-new
Browse files Browse the repository at this point in the history
GPU-based FFT + CMake cleanup + linalg::matrix_product
  • Loading branch information
Muhammad Haseeb authored Oct 20, 2023
2 parents 4c7afb0 + 99c4ce5 commit b2f02ec
Show file tree
Hide file tree
Showing 18 changed files with 436 additions and 449 deletions.
10 changes: 8 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ endif()

# need to add appropriate flags for stdexec
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -stdpar=${STDPAR} -mp=${OMP} --gcc-toolchain=/opt/cray/pe/gcc/12.2.0/bin/ -pthread"
)
"${CMAKE_CXX_FLAGS} -stdpar=${STDPAR} -mp=${OMP}")

# add -cudalib=cublas if -stdpar=gpu
if (STDPAR STREQUAL "gpu")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_GPU")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -UUSE_GPU")
endif()

# ##############################################################################
# Add sub-directories
Expand Down
2 changes: 1 addition & 1 deletion apps/1d_stencil/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ target_include_directories(
PRIVATE ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_LIST_DIR}/../../include
${ARGPARSE_INCLUDE_DIR} ${MDSPAN_INCLUDE_DIR})

if("${STDPAR}" STREQUAL "gpu")
# TODO, fix cmake
add_executable(stencil_stdexec stencil_stdexec.cpp)
target_link_libraries(stencil_stdexec stdexec)
Expand All @@ -22,6 +21,7 @@ if("${STDPAR}" STREQUAL "gpu")
PRIVATE ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_LIST_DIR}/../../include
${ARGPARSE_INCLUDE_DIR} ${MDSPAN_INCLUDE_DIR})

if("${STDPAR}" STREQUAL "gpu")
add_executable(stencil_cuda stencil_cuda.cpp)
target_include_directories(
stencil_cuda
Expand Down
13 changes: 11 additions & 2 deletions apps/1d_stencil/stencil_stdexec.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2023 Weile Wei
* Copyright (c) 2023 Weile Wei
* Copyright (c) 2023 The Regents of the University of California,
* through Lawrence Berkeley National Laboratory (subject to receipt of any
* required approvals from the U.S. Dept. of Energy).All rights reserved.
Expand All @@ -27,8 +27,10 @@
//
// This example provides a stdexec implementation for the 1D stencil code.
#include <exec/static_thread_pool.hpp>
#if defined(USE_GPU)
#include <nvexec/multi_gpu_context.cuh>
#include <nvexec/stream_context.cuh>
#endif
#include <stdexec/execution.hpp>

#include "argparse/argparse.hpp"
Expand All @@ -45,7 +47,12 @@ struct args_params_t : public argparse::Args {
bool& no_header = kwarg("no-header", "Do not print csv header row (default: false)").set_default(false);
bool& help = flag("h, help", "print help");
bool& time = kwarg("t, time", "print time").set_default(true);
std::string& sch = kwarg("sch", "stdexec scheduler: [options: cpu, gpu, multigpu]").set_default("cpu");
std::string& sch = kwarg("sch", "stdexec scheduler: [options: cpu"
#if defined (USE_GPU)
", gpu, multigpu"
#endif //USE_GPU
"]").set_default("cpu");

int& nthreads = kwarg("nthreads", "number of threads").set_default(std::thread::hardware_concurrency());
};

Expand Down Expand Up @@ -121,12 +128,14 @@ int benchmark(args_params_t const& args) {
case sch_t::CPU:
solution = step.do_work(exec::static_thread_pool(nthreads).get_scheduler(), size, nt);
break;
#if defined(USE_GPU)
case sch_t::GPU:
solution = step.do_work(nvexec::stream_context().get_scheduler(), size, nt);
break;
case sch_t::MULTIGPU:
solution = step.do_work(nvexec::multi_gpu_stream_context().get_scheduler(), size, nt);
break;
#endif // USE_GPU
default:
std::cerr << "Unknown scheduler type encountered." << std::endl;
break;
Expand Down
2 changes: 1 addition & 1 deletion apps/comm-study/comm-study-no-senders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ auto work(P& A, P& B, P& Y, int N) {

// get sum(Y) - one last memcpy (not USM) D2H
sum +=
std::transform_reduce(std::execution::par_unseq, &Y[0], &Y[N], 0.0, std::plus<T>(), [](T &val){return val * val;});
std::reduce(std::execution::par_unseq, &Y[0], &Y[N], 0.0, std::plus<T>());

return sum / N;
}
Expand Down
142 changes: 0 additions & 142 deletions apps/comm-study/comm-study.cpp

This file was deleted.

14 changes: 11 additions & 3 deletions apps/fft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ project(fft LANGUAGES CXX)

file(GLOB CPP_SOURCES "*.cpp")

# add -cudalib=cublas if -stdpar=gpu
if (STDPAR STREQUAL "gpu")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -cudalib=cublas")
endif()

foreach(source_file ${CPP_SOURCES})
if(NOT STDPAR STREQUAL "gpu")
if("${source_file}" MATCHES ".*stdpar.*gpu.*" OR "${source_file}"
Expand All @@ -18,16 +23,19 @@ foreach(source_file ${CPP_SOURCES})
add_executable(${exec_name} ${_EXCLUDE} ${source_file})

# add dependency on argparse
add_dependencies(${exec_name} argparse magic_enum)
add_dependencies(${exec_name} argparse)

set_source_files_properties(${source_file} PROPERTIES LANGUAGE CXX
LINKER_LANGUAGE CXX)
target_include_directories(
${exec_name}
PRIVATE ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_LIST_DIR}/../../include
${ARGPARSE_INCLUDE_DIR} ${MAGICENUM_INCLUDE_DIR} ${MDSPAN_INCLUDE_DIR})
${ARGPARSE_INCLUDE_DIR} ${MDSPAN_INCLUDE_DIR})

# uncomment only if using nvc++/23.1 - no need if nvc++/23.7
# target_link_directories(${exec_name} PRIVATE /opt/nvidia/hpc_sdk/Linux_x86_64/23.1/math_libs/lib64)

target_link_libraries(${exec_name} PUBLIC ${MPI_LIBS} stdexec)
target_link_libraries(${exec_name} PUBLIC ${MPI_LIBS} stdexec blas)

set_target_properties(
${exec_name}
Expand Down
71 changes: 65 additions & 6 deletions apps/fft/fft-serial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,74 @@

#include "fft.hpp"

//
// serial fft function
//
[[nodiscard]] std::vector<data_t> fft_serial(const data_t *x, const int N, bool debug = false)
{
std::vector<data_t> x_r(N);
std::vector<uint32_t> id(N);

// bit shift
int shift = 32 - ilog2(N);

// twiddle data in x[n]
for (int k = 0; k < N; k++)
{
id[k] = reverse_bits32(k) >> shift;
x_r[k] = x[id[k]];
}

// niterations
int niters = ilog2(N);
// local merge partition size
int lN = 2;

// set cout precision
std::cout << std::fixed << std::setprecision(1);

std::cout << "FFT progress: ";

for (int k = 0; k < niters; k++, lN*=2)
{
std::cout << (100.0 * k)/niters << "%.." << std::flush;

static Timer dtimer;

// number of partitions
int nparts = N/lN;
int tpp = lN/2;

if (debug)
dtimer.start();

// merge
for (int k = 0; k < N/2; k++)
{
// compute indices
int e = (k/tpp)*lN + (k % tpp);
auto o = e + tpp;
auto i = (k % tpp);
auto tmp = x_r[e] + x_r[o] * WNk(N, i * nparts);
x_r[o] = x_r[e] - x_r[o] * WNk(N, i * nparts);
x_r[e] = tmp;
}

if (debug)
std::cout << "This iter time: " << dtimer.stop() << " ms" << std::endl;
}

std::cout << "100%" << std::endl;
return x_r;
}

//
// simulation
//
int main(int argc, char* argv[])
{
// parse params
fft_params_t args = argparse::parse<fft_params_t>(argc, argv);
const fft_params_t args = argparse::parse<fft_params_t>(argc, argv);

// see if help wanted
if (args.help)
Expand Down Expand Up @@ -64,9 +125,6 @@ int main(int argc, char* argv[])
x_n.resize(N);
}

// y[n] = fft(x[n]);
sig_t y_n(x_n);

if (print_sig)
{
std::cout << std::endl << "x[n] = ";
Expand All @@ -81,7 +139,8 @@ int main(int argc, char* argv[])
Timer timer;

// fft radix-2 algorithm
fft_serial(y_n.data(), N, N);
// y[n] = fft(x[n]);
sig_t y_n(std::move(fft_serial(x_n.data(), N, args.debug)));

// stop timer
auto elapsed = timer.stop();
Expand All @@ -101,7 +160,7 @@ int main(int argc, char* argv[])
// validate the recursively computed fft
if (validate)
{
if (x_n.isFFT(y_n))
if (x_n.isFFT(y_n, exec::static_thread_pool(std::thread::hardware_concurrency()).get_scheduler()))
std::cout << "SUCCESS: y[n] == fft(x[n])" << std::endl;
else
std::cout << "FAILED: y[n] != fft(x[n])" << std::endl;
Expand Down
Loading

0 comments on commit b2f02ec

Please sign in to comment.