From 4488134edfd244af53c2c4ea141ca3223e8bd354 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Wed, 19 Jun 2024 07:43:51 +0000 Subject: [PATCH] fix debug link error. fix windows crash --- CMakeLists.txt | 7 +- CMakePresets.json | 1 - ggml-sycl.cpp | 2 +- ggml-sycl/dpct/helper.hpp | 1603 ++++++++++++++++++------------------- ggml.h | 6 + 5 files changed, 808 insertions(+), 811 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c90414afa92be..9cfe08d7b7d59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -665,6 +665,7 @@ if (LLAMA_SYCL) #todo: AOT find_package(IntelSYCL REQUIRED) + find_package(MKL REQUIRED) message(STATUS "SYCL found") @@ -679,11 +680,9 @@ if (LLAMA_SYCL) endif() add_compile_options(-I./) #include DPCT - add_compile_options(-I/${SYCL_INCLUDE_DIR}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") if (LLAMA_SYCL_TARGET STREQUAL "NVIDIA") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") endif() @@ -693,8 +692,10 @@ if (LLAMA_SYCL) list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") if (WIN32) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl sycl7 OpenCL mkl_sycl_blas_dll.lib mkl_intel_ilp64_dll.lib mkl_sequential_dll.lib mkl_core_dll.lib) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) else() + add_compile_options(-I/${SYCL_INCLUDE_DIR}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") if (LLAMA_SYCL_TARGET STREQUAL "INTEL") set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) elseif (LLAMA_SYCL_TARGET STREQUAL "NVIDIA") diff --git a/CMakePresets.json b/CMakePresets.json index 265843c84f032..501b33073c8b8 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -19,7 +19,6 @@ "cacheVariables": { "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", "CMAKE_CXX_COMPILER": "icx", - "CMAKE_C_COMPILER": "icx", "LLAMA_SYCL": "ON", "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." 
} diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 485f06ad331f8..e5ddf4a346c36 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4911,7 +4911,7 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS01; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); queue_ptr main_stream = ctx.stream(); diff --git a/ggml-sycl/dpct/helper.hpp b/ggml-sycl/dpct/helper.hpp index 017fd6ee13268..af484d8333e59 100644 --- a/ggml-sycl/dpct/helper.hpp +++ b/ggml-sycl/dpct/helper.hpp @@ -58,7 +58,7 @@ #define __dpct_noinline__ __attribute__((noinline)) #endif -inline std::string get_device_type_name(const sycl::device &Device) { +inline std::string get_device_type_name(const sycl::device& Device) { auto DeviceType = Device.get_info(); switch (DeviceType) { case sycl::info::device_type::cpu: @@ -74,39 +74,39 @@ inline std::string get_device_type_name(const sycl::device &Device) { } } -inline std::string get_device_backend_and_type(const sycl::device &device) { +inline std::string get_device_backend_and_type(const sycl::device& device) { std::stringstream device_type; sycl::backend backend = device.get_backend(); - device_type << backend << ":" << get_device_type_name(device); + device_type << backend << ":" << get_device_type_name(device); return device_type.str(); } namespace dpct { - typedef sycl::queue *queue_ptr; - typedef sycl::event *event_ptr; - typedef char *device_ptr; + typedef sycl::queue* queue_ptr; + typedef sycl::event* event_ptr; + typedef char* device_ptr; typedef uint8_t byte_t; typedef sycl::buffer buffer_t; /// SYCL default exception handler inline auto exception_handler = [](sycl::exception_list exceptions) - { - for (std::exception_ptr const &e : exceptions) { - try - { - std::rethrow_exception(e); - } - catch (sycl::exception const &e) + for (std::exception_ptr const& e : exceptions) { - std::cerr << "Caught asynchronous SYCL exception:" << std::endl - << e.what() << std::endl - << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; + try + { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) + { + std::cerr << "Caught asynchronous SYCL exception:" << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } } - } - }; + }; enum error_code { @@ -196,7 +196,7 @@ namespace dpct namespace detail { - static void get_version(const sycl::device &dev, int &major, int &minor) + static void get_version(const sycl::device& dev, int& major, int& minor) { // Version string has the following format: // a. OpenCL @@ -206,23 +206,24 @@ namespace dpct ver = dev.get_info(); std::string::size_type i = 0; while (i < ver.size()) { - if (isdigit(ver[i])) - break; - i++; + if (isdigit(ver[i])) + break; + i++; } major = std::stoi(&(ver[i])); while (i < ver.size()) { - if (ver[i] == '.') - break; - i++; + if (ver[i] == '.') + break; + i++; } if (i < ver.size()) { - // a. and b. - i++; - minor = std::stoi(&(ver[i])); - } else { - // c. - minor = 0; + // a. and b. + i++; + minor = std::stoi(&(ver[i])); + } + else { + // c. 
+ minor = 0; } } @@ -231,7 +232,7 @@ namespace dpct { public: generic_error_type() = default; - generic_error_type(T value) : value{value} {} + generic_error_type(T value) : value{ value } {} operator T() const { return value; } private: @@ -245,11 +246,11 @@ namespace dpct { public: pitched_data() : pitched_data(nullptr, 0, 0, 0) {} - pitched_data(void *data, size_t pitch, size_t x, size_t y) + pitched_data(void* data, size_t pitch, size_t x, size_t y) : _data(data), _pitch(pitch), _x(x), _y(y) {} - void *get_data_ptr() { return _data; } - void set_data_ptr(void *data) { _data = data; } + void* get_data_ptr() { return _data; } + void set_data_ptr(void* data) { _data = data; } size_t get_pitch() { return _pitch; } void set_pitch(size_t pitch) { _pitch = pitch; } @@ -261,7 +262,7 @@ namespace dpct void set_y(size_t y) { _y = y; } private: - void *_data; + void* _data; size_t _pitch, _x, _y; }; @@ -269,33 +270,33 @@ namespace dpct { public: // get interface - const char *get_name() const { return _name; } - char *get_name() { return _name; } + const char* get_name() const { return _name; } + char* get_name() { return _name; } template , - std::enable_if_t> || - std::is_same_v, - int> = 0> + std::enable_if_t> || + std::is_same_v, + int> = 0> auto get_max_work_item_sizes() const { if constexpr (std::is_same_v>) return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); else { return _max_work_item_sizes_i; } } template , - std::enable_if_t> || - std::is_same_v, - int> = 0> + std::enable_if_t> || + std::is_same_v, + int> = 0> auto get_max_work_item_sizes() { if constexpr (std::is_same_v>) return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); else { return _max_work_item_sizes_i; @@ -317,24 +318,24 @@ namespace dpct { return _max_register_size_per_work_group; } - template || - std::is_same_v, - int> = 0> + template || + std::is_same_v, + int> = 0> auto get_max_nd_range_size() const { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) return _max_nd_range_size; else return _max_nd_range_size_i; } - template || - std::is_same_v, - int> = 0> + template || + std::is_same_v, + int> = 0> auto get_max_nd_range_size() { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) return _max_nd_range_size; else return _max_nd_range_size_i; @@ -357,7 +358,7 @@ namespace dpct } // set interface - void set_name(const char *name) + void set_name(const char* name) { size_t length = strlen(name); if (length < 256) @@ -376,7 +377,7 @@ namespace dpct _max_work_item_sizes_i[i] = max_work_item_sizes[i]; } [[deprecated]] void - set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) + set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) { for (int i = 0; i < 3; ++i) { @@ -416,7 +417,7 @@ namespace dpct _max_sub_group_size = max_sub_group_size; } void - set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) + set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) { _max_work_items_per_compute_unit = max_work_items_per_compute_unit; } @@ -437,7 +438,7 @@ namespace dpct _memory_bus_width = memory_bus_width; } void - set_max_register_size_per_work_group(int max_register_size_per_work_group) + set_max_register_size_per_work_group(int max_register_size_per_work_group) { _max_register_size_per_work_group = 
max_register_size_per_work_group; } @@ -481,21 +482,21 @@ namespace dpct std::array _uuid; }; - static int get_major_version(const sycl::device &dev) + static int get_major_version(const sycl::device& dev) { int major, minor; detail::get_version(dev, major, minor); return major; } - static int get_minor_version(const sycl::device &dev) + static int get_minor_version(const sycl::device& dev) { int major, minor; detail::get_version(dev, major, minor); return minor; } - static void get_device_info(device_info &out, const sycl::device &dev) + static void get_device_info(device_info& out, const sycl::device& dev) { device_info prop; prop.set_name(dev.get_info().c_str()); @@ -556,17 +557,17 @@ namespace dpct Use 3200000 kHz as memory_clock_rate default value. \ Use 64 bits as memory_bus_width default value.") #else -#warning "get_device_info: querying memory_clock_rate and \ + #warning "get_device_info: querying memory_clock_rate and \ memory_bus_width are not supported by the compiler used. \ Use 3200000 kHz as memory_clock_rate default value. \ Use 64 bits as memory_bus_width default value." #endif - size_t max_sub_group_size = 1; + size_t max_sub_group_size = 1; std::vector sub_group_sizes = dev.get_info(); - for (const auto &sub_group_size : sub_group_sizes) + for (const auto& sub_group_size : sub_group_sizes) { if (max_sub_group_size < sub_group_size) max_sub_group_size = sub_group_size; @@ -576,7 +577,7 @@ namespace dpct prop.set_max_work_items_per_compute_unit( dev.get_info()); - int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + int max_nd_range_size[] = { 0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF }; prop.set_max_nd_range_size(max_nd_range_size); // Estimates max register size per work group, feel free to update the value @@ -594,13 +595,13 @@ namespace dpct typedef std::mutex mutex_type; public: - device_ext() : sycl::device(), _ctx(*this) {} + device_ext() : sycl::device() {} ~device_ext() { std::lock_guard lock(m_mutex); clear_queues(); } - device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) + device_ext(const sycl::device& base) : sycl::device(base) { std::lock_guard lock(m_mutex); init_queues(); @@ -663,12 +664,12 @@ namespace dpct /// Get the number of bytes of free and total memory on the SYCL device. /// \param [out] free_memory The number of bytes of free memory on the SYCL device. /// \param [out] total_memory The number of bytes of total memory on the SYCL device. 
- void get_memory_info(size_t &free_memory, size_t &total_memory) + void get_memory_info(size_t& free_memory, size_t& total_memory) { total_memory = get_device_info().get_global_mem_size(); - const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " - "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " - "use total memory as free memory"; + const char* warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " + "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " + "use total memory as free memory"; #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) if (!has(sycl::aspect::ext_intel_free_memory)) { @@ -685,12 +686,12 @@ namespace dpct #if defined(_MSC_VER) && !defined(__clang__) #pragma message("Querying the number of bytes of free memory is not supported") #else -#warning "Querying the number of bytes of free memory is not supported" + #warning "Querying the number of bytes of free memory is not supported" #endif #endif } - void get_device_info(device_info &out) const + void get_device_info(device_info& out) const { dpct::get_device_info(out, *this); } @@ -709,11 +710,11 @@ namespace dpct init_queues(); } - sycl::queue &in_order_queue() { return *_q_in_order; } + sycl::queue& in_order_queue() { return _q_in_order; } - sycl::queue &out_of_order_queue() { return *_q_out_of_order; } + sycl::queue& out_of_order_queue() { return _q_out_of_order; } - sycl::queue &default_queue() + sycl::queue& default_queue() { return in_order_queue(); } @@ -721,130 +722,120 @@ namespace dpct void queues_wait_and_throw() { std::unique_lock lock(m_mutex); - std::vector> current_queues( - _queues); lock.unlock(); - for (const auto &q : current_queues) + for (auto& q : _queues) { - q->wait_and_throw(); + q.wait_and_throw(); } // Guard the destruct of current_queues to make sure the ref count is safe. 
lock.lock(); } - sycl::queue *create_queue(bool enable_exception_handler = false) + sycl::queue create_queue(bool enable_exception_handler = false) { return create_in_order_queue(enable_exception_handler); } - sycl::queue *create_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { - return create_in_order_queue(context, device, enable_exception_handler); + sycl::queue create_queue(sycl::device device, + bool enable_exception_handler = false) { + return create_in_order_queue(device, enable_exception_handler); } - sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { + sycl::queue create_in_order_queue(bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); return create_queue_impl(enable_exception_handler, - sycl::property::queue::in_order()); + sycl::property::queue::in_order()); } - sycl::queue *create_in_order_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { + sycl::queue create_in_order_queue(sycl::device device, + bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); - return create_queue_impl(context, device, enable_exception_handler, - sycl::property::queue::in_order()); + return create_queue_impl(device, enable_exception_handler, + sycl::property::queue::in_order()); } - sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { + sycl::queue create_out_of_order_queue(bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); return create_queue_impl(enable_exception_handler); } - void destroy_queue(sycl::queue *&queue) + void destroy_queue(sycl::queue queue) { std::lock_guard lock(m_mutex); - _queues.erase(std::remove_if(_queues.begin(), _queues.end(), - [=](const std::shared_ptr &q) -> bool - { - return q.get() == queue; - }), - _queues.end()); - queue = nullptr; + _queues.clear(); } - void set_saved_queue(sycl::queue *q) + void set_saved_queue(sycl::queue q) { std::lock_guard lock(m_mutex); _saved_queue = q; } - sycl::queue *get_saved_queue() const + sycl::queue get_saved_queue() const { std::lock_guard lock(m_mutex); return _saved_queue; } - sycl::context get_context() const { return _ctx; } private: void clear_queues() { _queues.clear(); - _q_in_order = _q_out_of_order = _saved_queue = nullptr; } void init_queues() { _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); _q_out_of_order = create_queue_impl(true); - _saved_queue = &default_queue(); + _saved_queue = default_queue(); } /// Caller should acquire resource \p m_mutex before calling this function. template - sycl::queue *create_queue_impl(bool enable_exception_handler, - Properties... properties) + sycl::queue create_queue_impl(bool enable_exception_handler, + Properties... properties) { sycl::async_handler eh = {}; if (enable_exception_handler) { eh = exception_handler; } - _queues.push_back(std::make_shared( - _ctx, *this, eh, + auto q = sycl::queue( + *this, eh, sycl::property_list( #ifdef DPCT_PROFILING_ENABLED sycl::property::queue::enable_profiling(), #endif - properties...))); + properties...)); + _queues.push_back(q); - return _queues.back().get(); + return _queues.back(); } template - sycl::queue *create_queue_impl(sycl::context context, sycl::device device, - bool enable_exception_handler, - Properties... properties) { + sycl::queue create_queue_impl(sycl::device device, + bool enable_exception_handler, + Properties... 
properties) { sycl::async_handler eh = {}; if (enable_exception_handler) { eh = exception_handler; } - _queues.push_back(std::make_shared( - context, device, eh, + _queues.push_back(sycl::queue( + device, eh, sycl::property_list( - #ifdef DPCT_PROFILING_ENABLED +#ifdef DPCT_PROFILING_ENABLED sycl::property::queue::enable_profiling(), - #endif +#endif properties...))); - return _queues.back().get(); + return _queues.back(); } - void get_version(int &major, int &minor) const + void get_version(int& major, int& minor) const { detail::get_version(*this, major, minor); } - sycl::queue *_q_in_order, *_q_out_of_order; - sycl::queue *_saved_queue; - sycl::context _ctx; - std::vector> _queues; + sycl::queue _q_in_order, _q_out_of_order; + sycl::queue _saved_queue; + std::vector _queues; mutable mutex_type m_mutex; }; @@ -852,13 +843,13 @@ namespace dpct class dev_mgr { public: - device_ext ¤t_device() + device_ext& current_device() { unsigned int dev_id = current_device_id(); check_id(dev_id); return *_devs[dev_id]; } - device_ext &cpu_device() const + device_ext& cpu_device() const { std::lock_guard lock(m_mutex); if (_cpu_device == -1) @@ -870,7 +861,7 @@ namespace dpct return *_devs[_cpu_device]; } } - device_ext &get_device(unsigned int id) const + device_ext& get_device(unsigned int id) const { std::lock_guard lock(m_mutex); check_id(id); @@ -896,7 +887,7 @@ namespace dpct } unsigned int device_count() { return _devs.size(); } - unsigned int get_device_id(const sycl::device &dev) + unsigned int get_device_id(const sycl::device& dev) { unsigned int id = 0; for (auto dev_item : _devs) @@ -912,8 +903,8 @@ namespace dpct template std::enable_if_t< - std::is_invocable_r_v> - select_device(const DeviceSelector &selector = sycl::gpu_selector_v) + std::is_invocable_r_v> + select_device(const DeviceSelector& selector = sycl::gpu_selector_v) { sycl::device selected_device = sycl::device(selector); unsigned int selected_device_id = get_device_id(selected_device); @@ -921,32 +912,32 @@ namespace dpct } /// Returns the instance of device manager singleton. 
- static dev_mgr &instance() + static dev_mgr& instance() { static dev_mgr d_m; return d_m; } - dev_mgr(const dev_mgr &) = delete; - dev_mgr &operator=(const dev_mgr &) = delete; - dev_mgr(dev_mgr &&) = delete; - dev_mgr &operator=(dev_mgr &&) = delete; + dev_mgr(const dev_mgr&) = delete; + dev_mgr& operator=(const dev_mgr&) = delete; + dev_mgr(dev_mgr&&) = delete; + dev_mgr& operator=(dev_mgr&&) = delete; private: mutable std::recursive_mutex m_mutex; - static bool compare_dev(sycl::device &device1, sycl::device &device2) + static bool compare_dev(sycl::device& device1, sycl::device& device2) { sycl::backend backend1 = device1.get_backend(); sycl::backend backend2 = device2.get_backend(); // levelzero backends always come first - if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true; - if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false; + if (backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true; + if (backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false; dpct::device_info prop1; dpct::get_device_info(prop1, device1); dpct::device_info prop2; dpct::get_device_info(prop2, device2); return prop1.get_max_compute_units() > prop2.get_max_compute_units(); } - static int convert_backend_index(std::string & backend) { + static int convert_backend_index(std::string& backend) { if (backend == "ext_oneapi_level_zero:gpu") return 0; if (backend == "opencl:gpu") return 1; if (backend == "ext_oneapi_cuda:gpu") return 2; @@ -956,7 +947,7 @@ namespace dpct printf("convert_backend_index: can't handle backend=%s\n", backend.c_str()); GGML_ASSERT(false); } - static bool compare_backend(std::string &backend1, std::string &backend2) { + static bool compare_backend(std::string& backend1, std::string& backend2) { return convert_backend_index(backend1) < convert_backend_index(backend2); } dev_mgr() @@ -980,26 +971,26 @@ namespace dpct Platforms.pop_back(); auto devices = Platform.get_devices(); std::string backend_type = get_device_backend_and_type(devices[0]); - for (const auto &device : devices) { + for (const auto& device : devices) { backend_devices[backend_type].push_back(device); } } std::vector keys; - for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { + for (auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { keys.push_back(it->first); } std::sort(keys.begin(), keys.end(), compare_backend); - for (auto &key : keys) { + for (auto& key : keys) { std::vector devs = backend_devices[key]; std::sort(devs.begin(), devs.end(), compare_dev); - for (const auto &dev : devs) { + for (const auto& dev : devs) { sycl_all_devs.push_back(dev); } } - for (auto &dev : sycl_all_devs) + for (auto& dev : sycl_all_devs) { if (dev == default_device) { @@ -1029,7 +1020,7 @@ namespace dpct int _cpu_device = -1; }; - static inline sycl::queue &get_default_queue() + static inline sycl::queue& get_default_queue() { return dev_mgr::instance().current_device().default_queue(); } @@ -1044,8 +1035,8 @@ namespace dpct end }; - static pointer_access_attribute get_pointer_attribute(sycl::queue &q, - const void *ptr) + static pointer_access_attribute get_pointer_attribute(sycl::queue& q, + const void* ptr) { switch (sycl::get_pointer_type(ptr, q.get_context())) { @@ -1063,19 +1054,19 @@ namespace dpct inline constexpr std::uint64_t 
get_type_combination_id(ArgT Val) { static_assert((unsigned char)library_data_t::library_data_t_size <= - std::numeric_limits::max() && - "library_data_t size exceeds limit."); + std::numeric_limits::max() && + "library_data_t size exceeds limit."); static_assert(std::is_same_v, "Unsupported ArgT"); return (std::uint64_t)Val; } template inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, - RestT... RestVal) + RestT... RestVal) { static_assert((std::uint8_t)library_data_t::library_data_t_size <= - std::numeric_limits::max() && - "library_data_t size exceeds limit."); + std::numeric_limits::max() && + "library_data_t size exceeds limit."); static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); static_assert(std::is_same_v, "Unsupported FirstT"); return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); @@ -1088,10 +1079,10 @@ namespace dpct // Reserved address space, no real memory allocation happens here. #if defined(__linux__) mapped_address_space = - (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + (byte_t*)mmap(nullptr, mapped_region_size, PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); #elif defined(_WIN64) - mapped_address_space = (byte_t *)VirtualAlloc( + mapped_address_space = (byte_t*)VirtualAlloc( NULL, // NULL specified as the base address parameter mapped_region_size, // Size of allocation MEM_RESERVE, // Allocate reserved pages @@ -1108,7 +1099,7 @@ namespace dpct struct allocation { buffer_t buffer; - byte_t *alloc_ptr; + byte_t* alloc_ptr; size_t size; }; @@ -1123,13 +1114,13 @@ namespace dpct #endif }; - mem_mgr(const mem_mgr &) = delete; - mem_mgr &operator=(const mem_mgr &) = delete; - mem_mgr(mem_mgr &&) = delete; - mem_mgr &operator=(mem_mgr &&) = delete; + mem_mgr(const mem_mgr&) = delete; + mem_mgr& operator=(const mem_mgr&) = delete; + mem_mgr(mem_mgr&&) = delete; + mem_mgr& operator=(mem_mgr&&) = delete; /// Allocate - void *mem_alloc(size_t size) + void* mem_alloc(size_t size) { if (!size) return nullptr; @@ -1141,9 +1132,9 @@ namespace dpct // Allocation sycl::range<1> r(size); buffer_t buf(r); - allocation A{buf, next_free, size}; + allocation A{ buf, next_free, size }; // Map allocation to device pointer - void *result = next_free; + void* result = next_free; m_map.emplace(next_free + size, A); // Update pointer to the next free space. next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); @@ -1152,7 +1143,7 @@ namespace dpct } /// Deallocate - void mem_free(const void *ptr) + void mem_free(const void* ptr) { if (!ptr) return; @@ -1162,7 +1153,7 @@ namespace dpct } /// map: device pointer -> allocation(buffer, alloc_ptr, size) - allocation translate_ptr(const void *ptr) + allocation translate_ptr(const void* ptr) { std::lock_guard lock(m_mutex); auto it = get_map_iterator(ptr); @@ -1170,40 +1161,40 @@ namespace dpct } /// Check if the pointer represents device pointer or not. - bool is_device_ptr(const void *ptr) const + bool is_device_ptr(const void* ptr) const { std::lock_guard lock(m_mutex); return (mapped_address_space <= ptr) && - (ptr < mapped_address_space + mapped_region_size); + (ptr < mapped_address_space + mapped_region_size); } /// Returns the instance of memory manager singleton. 
- static mem_mgr &instance() + static mem_mgr& instance() { static mem_mgr m; return m; } private: - std::map m_map; + std::map m_map; mutable std::mutex m_mutex; - byte_t *mapped_address_space; - byte_t *next_free; + byte_t* mapped_address_space; + byte_t* next_free; const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; const size_t alignment = 256; /// This padding may be defined to some positive value to debug /// out of bound accesses. const size_t extra_padding = 0; - std::map::iterator get_map_iterator(const void *ptr) + std::map::iterator get_map_iterator(const void* ptr) { - auto it = m_map.upper_bound((byte_t *)ptr); + auto it = m_map.upper_bound((byte_t*)ptr); if (it == m_map.end()) { // Not a virtual pointer. throw std::runtime_error("can not get buffer from non-virtual pointer"); } - const allocation &alloc = it->second; + const allocation& alloc = it->second; if (ptr < alloc.alloc_ptr) { // Out of bound. @@ -1225,7 +1216,7 @@ namespace dpct sycl::access::target::device; static constexpr sycl::access_mode mode = (Memory == constant) ? sycl::access_mode::read - : sycl::access_mode::read_write; + : sycl::access_mode::read_write; static constexpr size_t type_size = sizeof(T); using element_t = typename std::conditional::type; @@ -1234,17 +1225,17 @@ namespace dpct using accessor_t = typename std::conditional< Memory == local, sycl::local_accessor, sycl::accessor>::type; - using pointer_t = T *; + using pointer_t = T*; }; - static inline void *dpct_malloc(size_t size, sycl::queue &q) + static inline void* dpct_malloc(size_t size, sycl::queue& q) { return sycl::malloc_device(size, q.get_device(), q.get_context()); } #define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) - static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, - sycl::queue &q) + static inline void* dpct_malloc(size_t& pitch, size_t x, size_t y, size_t z, + sycl::queue& q) { pitch = PITCH_DEFAULT_ALIGN(x); return dpct_malloc(pitch * y * z, q); @@ -1260,8 +1251,8 @@ namespace dpct * @return An event representing the memset operation. 
*/ template - static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, - valueT value, size_t size) + static inline sycl::event dpct_memset(sycl::queue& q, void* dev_ptr, + valueT value, size_t size) { return q.fill(dev_ptr, value, size); } @@ -1277,15 +1268,15 @@ namespace dpct */ template static inline std::vector - dpct_memset(sycl::queue &q, pitched_data data, valueT value, - sycl::range<3> size) + dpct_memset(sycl::queue& q, pitched_data data, valueT value, + sycl::range<3> size) { std::vector event_list; size_t slice = data.get_pitch() * data.get_y(); - unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); + unsigned char* data_surface = (unsigned char*)data.get_data_ptr(); for (size_t z = 0; z < size.get(2); ++z) { - unsigned char *data_ptr = data_surface; + unsigned char* data_ptr = data_surface; for (size_t y = 0; y < size.get(1); ++y) { event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); @@ -1309,16 +1300,16 @@ namespace dpct */ template static inline std::vector - dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, - size_t y) + dpct_memset(sycl::queue& q, void* ptr, size_t pitch, valueT val, size_t x, + size_t y) { return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, - sycl::range<3>(x, y, 1)); + sycl::range<3>(x, y, 1)); } - static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr, - const void *from_ptr, - memcpy_direction dir) + static memcpy_direction deduce_memcpy_direction(sycl::queue& q, void* to_ptr, + const void* from_ptr, + memcpy_direction dir) { switch (dir) { @@ -1332,16 +1323,16 @@ namespace dpct // table[to_attribute][from_attribute] static const memcpy_direction direction_table[static_cast(pointer_access_attribute::end)] - [static_cast(pointer_access_attribute::end)] = - {{memcpy_direction::host_to_host, - memcpy_direction::device_to_host, - memcpy_direction::host_to_host}, - {memcpy_direction::host_to_device, - memcpy_direction::device_to_device, - memcpy_direction::device_to_device}, - {memcpy_direction::host_to_host, - memcpy_direction::device_to_device, - memcpy_direction::device_to_device}}; + [static_cast(pointer_access_attribute::end)] = + { {memcpy_direction::host_to_host, + memcpy_direction::device_to_host, + memcpy_direction::host_to_host}, + {memcpy_direction::host_to_device, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}, + {memcpy_direction::host_to_host, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device} }; return direction_table[static_cast(get_pointer_attribute( q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; } @@ -1351,9 +1342,9 @@ namespace dpct } static sycl::event - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction, - const std::vector &dep_events = {}) + dpct_memcpy(sycl::queue& q, void* to_ptr, const void* from_ptr, size_t size, + memcpy_direction direction, + const std::vector& dep_events = {}) { if (!size) return sycl::event{}; @@ -1363,13 +1354,13 @@ namespace dpct // Get actual copy range and make sure it will not exceed range. 
static inline size_t get_copy_range(sycl::range<3> size, size_t slice, - size_t pitch) + size_t pitch) { return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); } static inline size_t get_offset(sycl::id<3> id, size_t slice, - size_t pitch) + size_t pitch) { return slice * id.get(2) + pitch * id.get(1) + id.get(0); } @@ -1377,51 +1368,51 @@ namespace dpct /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr /// and \p from_range to another specified by \p to_ptr and \p to_range. static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - sycl::range<3> to_range, sycl::range<3> from_range, - sycl::id<3> to_id, sycl::id<3> from_id, - sycl::range<3> size, memcpy_direction direction, - const std::vector &dep_events = {}) + dpct_memcpy(sycl::queue& q, void* to_ptr, const void* from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector& dep_events = {}) { // RAII for host pointer class host_buffer { - void *_buf; + void* _buf; size_t _size; - sycl::queue &_q; - const std::vector &_deps; // free operation depends + sycl::queue& _q; + const std::vector& _deps; // free operation depends public: - host_buffer(size_t size, sycl::queue &q, - const std::vector &deps) + host_buffer(size_t size, sycl::queue& q, + const std::vector& deps) : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} - void *get_ptr() const { return _buf; } + void* get_ptr() const { return _buf; } size_t get_size() const { return _size; } ~host_buffer() { if (_buf) { - _q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(_deps); - cgh.host_task([buf = _buf] { std::free(buf); }); }); + _q.submit([&](sycl::handler& cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); } } }; std::vector event_list; size_t to_slice = to_range.get(1) * to_range.get(0), - from_slice = from_range.get(1) * from_range.get(0); - unsigned char *to_surface = - (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); - const unsigned char *from_surface = - (const unsigned char *)from_ptr + + from_slice = from_range.get(1) * from_range.get(0); + unsigned char* to_surface = + (unsigned char*)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char* from_surface = + (const unsigned char*)from_ptr + get_offset(from_id, from_slice, from_range.get(0)); if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { - return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events)}; + return { dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events) }; } direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); size_t size_slice = size.get(1) * size.get(0); @@ -1430,20 +1421,20 @@ namespace dpct case host_to_host: for (size_t z = 0; z < size.get(2); ++z) { - unsigned char *to_ptr = to_surface; - const unsigned char *from_ptr = from_surface; + unsigned char* to_ptr = to_surface; + const unsigned char* from_ptr = from_surface; if (to_range.get(0) == from_range.get(0) && to_range.get(0) == size.get(0)) { event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, - direction, dep_events)); + direction, dep_events)); } else { for (size_t y = 0; y < size.get(1); ++y) { event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), - direction, dep_events)); + direction, dep_events)); to_ptr += 
to_range.get(0); from_ptr += from_range.get(0); } @@ -1455,15 +1446,15 @@ namespace dpct case host_to_device: { host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, - event_list); + event_list); std::vector host_events; if (to_slice == size_slice) { // Copy host data to a temp host buffer with the shape of target. host_events = dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, - host_to_host, dep_events); + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); } else { @@ -1474,39 +1465,39 @@ namespace dpct // If has padding data, not sure whether it is useless. So fill temp // buffer with it. std::vector{ - dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), - device_to_host, dep_events)}); + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); } // Copy from temp host buffer to device with only one submit. event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), - buf.get_size(), host_to_device, - host_events)); + buf.get_size(), host_to_device, + host_events)); break; } case device_to_host: { host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, - event_list); + event_list); // Copy from host temp buffer to host target with reshaping. event_list = dpct_memcpy( q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, // Copy from device to temp host buffer with only one submit. std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, - buf.get_size(), - device_to_host, dep_events)}); + buf.get_size(), + device_to_host, dep_events)}); break; } case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh){ - cgh.depends_on(dep_events); - cgh.parallel_for( - size, - [=](sycl::id<3> id) { - to_surface[get_offset(id, to_slice, to_range.get(0))] = - from_surface[get_offset(id, from_slice, from_range.get(0))]; - }); })); - break; + event_list.push_back(q.submit([&](sycl::handler& cgh) { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); + break; default: throw std::runtime_error("dpct_memcpy: invalid direction value"); } @@ -1515,26 +1506,26 @@ namespace dpct /// memcpy 2D/3D matrix specified by pitched_data. static inline std::vector - dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, - pitched_data from, sycl::id<3> from_id, sycl::range<3> size, - memcpy_direction direction = automatic) + dpct_memcpy(sycl::queue& q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) { return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), - sycl::range<3>(to.get_pitch(), to.get_y(), 1), - sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, - size, direction); + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); } /// memcpy 2D matrix with pitch. 
static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - size_t to_pitch, size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic) + dpct_memcpy(sycl::queue& q, void* to_ptr, const void* from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) { return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), - sycl::range<3>(from_pitch, y, 1), - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), - sycl::range<3>(x, y, 1), direction); + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); } namespace deprecated @@ -1554,9 +1545,9 @@ namespace dpct using void_pointer = typename std::allocator_traits::void_pointer; using const_void_pointer = typename std::allocator_traits::const_void_pointer; - using reference = typename std::allocator_traits::value_type &; + using reference = typename std::allocator_traits::value_type&; using const_reference = - const typename std::allocator_traits::value_type &; + const typename std::allocator_traits::value_type&; using difference_type = typename std::allocator_traits::difference_type; using size_type = typename std::allocator_traits::size_type; @@ -1577,8 +1568,8 @@ namespace dpct usm_allocator() : _impl(dpct::get_default_queue()) {} ~usm_allocator() {} - usm_allocator(const usm_allocator &other) : _impl(other._impl) {} - usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} + usm_allocator(const usm_allocator& other) : _impl(other._impl) {} + usm_allocator(usm_allocator&& other) : _impl(std::move(other._impl)) {} pointer address(reference r) { return &r; } const_pointer address(const_reference r) { return &r; } pointer allocate(size_type cnt, const_void_pointer hint = nullptr) @@ -1593,14 +1584,14 @@ namespace dpct { return std::allocator_traits::max_size(_impl); } - bool operator==(const usm_allocator &other) const { return _impl == other._impl; } - bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } + bool operator==(const usm_allocator& other) const { return _impl == other._impl; } + bool operator!=(const usm_allocator& other) const { return _impl != other._impl; } }; } // namespace deprecated - inline void dpct_free(void *ptr, - const sycl::queue &q) + inline void dpct_free(void* ptr, + const sycl::queue& q) { if (ptr) { @@ -1609,29 +1600,29 @@ namespace dpct } template - inline auto get_memory(const void *x) + inline auto get_memory(const void* x) { - T *new_x = reinterpret_cast(const_cast(x)); + T* new_x = reinterpret_cast(const_cast(x)); return new_x; } template - inline typename DataType::T2 get_value(const T *s, sycl::queue &q) + inline typename DataType::T2 get_value(const T* s, sycl::queue& q) { using Ty = typename DataType::T2; Ty s_h; if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) - detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host) - .wait(); + detail::dpct_memcpy(q, (void*)&s_h, (const void*)s, sizeof(T), device_to_host) + .wait(); else - s_h = *reinterpret_cast(s); + s_h = *reinterpret_cast(s); return s_h; } } // namespace detail template - inline auto get_value(const T *s, sycl::queue &q) + inline auto get_value(const T* s, sycl::queue& q) { return detail::get_value(s, q); } @@ -1639,13 +1630,13 @@ namespace dpct namespace detail { template - inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int 
n, int k, - const void *alpha, const void *a, int lda, const void *b, - int ldb, const void *beta, void *c, int ldc) + inline void gemm_impl(sycl::queue& q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void* alpha, const void* a, int lda, const void* b, + int ldb, const void* beta, void* c, int ldc) { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); @@ -1682,11 +1673,11 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) + inline void gemm_batch_impl(sycl::queue& q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void* alpha, const void** a, int lda, + const void** b, int ldb, const void* beta, void** c, + int ldc, int batch_size) { struct matrix_info_t { @@ -1697,11 +1688,11 @@ namespace dpct std::int64_t groupsize_info; }; - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); + matrix_info_t* matrix_info = + (matrix_info_t*)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1718,28 +1709,28 @@ namespace dpct q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), + reinterpret_cast(a), matrix_info->ld_info, + reinterpret_cast(b), matrix_info->ld_info + 1, + matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); + q.submit([&](sycl::handler& cgh) + { + cgh.depends_on(e); + cgh.host_task([=] { std::free(matrix_info); }); }); } template inline void - gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, - int k, const void *alpha, const void *a, int lda, - long long int stride_a, const void *b, int ldb, - long long int stride_b, const void *beta, void *c, - int ldc, long long int stride_c, int batch_size) - { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + gemm_batch_impl(sycl::queue& q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, + int k, const void* alpha, const void* a, int lda, + long long int stride_a, const void* b, int ldb, + long long int stride_b, const void* beta, void* c, + int ldc, long long int stride_c, int batch_size) + { + Ts alpha_value = 
dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); @@ -1753,9 +1744,9 @@ namespace dpct template inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op) + const BinaryOperation binary_op) { - sycl::vec v0{a}, v1{b}; + sycl::vec v0{ a }, v1{ b }; auto v2 = v0.as(); auto v3 = v1.as(); auto v4 = @@ -1764,9 +1755,9 @@ namespace dpct return v0; } - static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction = automatic, - sycl::queue &q = dpct::get_default_queue()) + static void async_dpct_memcpy(void* to_ptr, const void* from_ptr, size_t size, + memcpy_direction direction = automatic, + sycl::queue& q = dpct::get_default_queue()) { detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); } @@ -1779,16 +1770,16 @@ namespace dpct template T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, - unsigned int logical_sub_group_size = 32) + unsigned int logical_sub_group_size = 32) { unsigned int id = g.get_local_linear_id(); unsigned int start_index = id / logical_sub_group_size * logical_sub_group_size; unsigned int target_offset = (id % logical_sub_group_size) ^ mask; return sycl::select_from_group(g, x, - target_offset < logical_sub_group_size - ? start_index + target_offset - : id); + target_offset < logical_sub_group_size + ? start_index + target_offset + : id); } template @@ -1796,14 +1787,14 @@ namespace dpct { return sycl::vec(val) .template as, int8_t, uint8_t>, 4>>() + std::conditional_t, int8_t, uint8_t>, 4>>() .template convert(); } template using dot_product_acc_t = - std::conditional_t && std::is_unsigned_v, - uint32_t, int32_t>; + std::conditional_t&& std::is_unsigned_v, + uint32_t, int32_t>; template inline auto dp4a(T1 a, T2 b, T3 c) @@ -1830,7 +1821,7 @@ namespace dpct template inline T vectorized_min(T a, T b) { - sycl::vec v0{a}, v1{b}; + sycl::vec v0{ a }, v1{ b }; auto v2 = v0.template as(); auto v3 = v1.template as(); auto v4 = sycl::min(v2, v3); @@ -1844,13 +1835,13 @@ namespace dpct inline double pow(const double a, const double b) { return sycl::pow(a, b); } template inline typename std::enable_if_t, T> - pow(const T a, const U b) + pow(const T a, const U b) { return sycl::pow(a, static_cast(b)); } template inline typename std::enable_if_t, double> - pow(const T a, const U b) + pow(const T a, const U b) { return sycl::pow(static_cast(a), static_cast(b)); } @@ -1977,10 +1968,10 @@ namespace dpct } inline void - has_capability_or_fail(const sycl::device &dev, - const std::initializer_list &props) + has_capability_or_fail(const sycl::device& dev, + const std::initializer_list& props) { - for (const auto &it : props) + for (const auto& it : props) { if (dev.has(it)) continue; @@ -1988,13 +1979,13 @@ namespace dpct { case sycl::aspect::fp64: throw std::runtime_error("'double' is not supported in '" + - dev.get_info() + - "' device"); + dev.get_info() + + "' device"); break; case sycl::aspect::fp16: throw std::runtime_error("'half' is not supported in '" + - dev.get_info() + - "' device"); + dev.get_info() + + "' device"); break; default: #define __SYCL_ASPECT(ASPECT, ID) \ @@ -2003,15 +1994,15 @@ namespace dpct #define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) #define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string - { - switch 
(AspectNum) { + switch (AspectNum) + { #include #include - default: - return "unknown aspect"; - } - }; + default: + return "unknown aspect"; + } + }; #undef __SYCL_ASPECT_DEPRECATED_ALIAS #undef __SYCL_ASPECT_DEPRECATED #undef __SYCL_ASPECT @@ -2028,20 +2019,20 @@ namespace dpct return dev_mgr::instance().current_device_id(); } - static inline device_ext &get_current_device() + static inline device_ext& get_current_device() { return dev_mgr::instance().current_device(); } - static inline sycl::queue &get_in_order_queue() + static inline sycl::queue& get_in_order_queue() { return dev_mgr::instance().current_device().in_order_queue(); } static sycl::event - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction, - const std::vector &dep_events = {}) + dpct_memcpy(sycl::queue& q, void* to_ptr, const void* from_ptr, size_t size, + memcpy_direction direction, + const std::vector& dep_events = {}) { if (!size) return sycl::event{}; @@ -2051,13 +2042,13 @@ namespace dpct // Get actual copy range and make sure it will not exceed range. static inline size_t get_copy_range(sycl::range<3> size, size_t slice, - size_t pitch) + size_t pitch) { return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); } static inline size_t get_offset(sycl::id<3> id, size_t slice, - size_t pitch) + size_t pitch) { return slice * id.get(2) + pitch * id.get(1) + id.get(0); } @@ -2065,51 +2056,51 @@ namespace dpct /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr /// and \p from_range to another specified by \p to_ptr and \p to_range. static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - sycl::range<3> to_range, sycl::range<3> from_range, - sycl::id<3> to_id, sycl::id<3> from_id, - sycl::range<3> size, memcpy_direction direction, - const std::vector &dep_events = {}) + dpct_memcpy(sycl::queue& q, void* to_ptr, const void* from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector& dep_events = {}) { // RAII for host pointer class host_buffer { - void *_buf; + void* _buf; size_t _size; - sycl::queue &_q; - const std::vector &_deps; // free operation depends + sycl::queue& _q; + const std::vector& _deps; // free operation depends public: - host_buffer(size_t size, sycl::queue &q, - const std::vector &deps) + host_buffer(size_t size, sycl::queue& q, + const std::vector& deps) : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} - void *get_ptr() const { return _buf; } + void* get_ptr() const { return _buf; } size_t get_size() const { return _size; } ~host_buffer() { if (_buf) { - _q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(_deps); - cgh.host_task([buf = _buf] { std::free(buf); }); }); + _q.submit([&](sycl::handler& cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); } } }; std::vector event_list; size_t to_slice = to_range.get(1) * to_range.get(0), - from_slice = from_range.get(1) * from_range.get(0); - unsigned char *to_surface = - (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); - const unsigned char *from_surface = - (const unsigned char *)from_ptr + + from_slice = from_range.get(1) * from_range.get(0); + unsigned char* to_surface = + (unsigned char*)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char* from_surface = + (const unsigned char*)from_ptr + 
get_offset(from_id, from_slice, from_range.get(0)); if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { - return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events)}; + return { dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events) }; } direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); size_t size_slice = size.get(1) * size.get(0); @@ -2118,20 +2109,20 @@ namespace dpct case host_to_host: for (size_t z = 0; z < size.get(2); ++z) { - unsigned char *to_ptr = to_surface; - const unsigned char *from_ptr = from_surface; + unsigned char* to_ptr = to_surface; + const unsigned char* from_ptr = from_surface; if (to_range.get(0) == from_range.get(0) && to_range.get(0) == size.get(0)) { event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, - direction, dep_events)); + direction, dep_events)); } else { for (size_t y = 0; y < size.get(1); ++y) { event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), - direction, dep_events)); + direction, dep_events)); to_ptr += to_range.get(0); from_ptr += from_range.get(0); } @@ -2143,15 +2134,15 @@ namespace dpct case host_to_device: { host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, - event_list); + event_list); std::vector host_events; if (to_slice == size_slice) { // Copy host data to a temp host buffer with the shape of target. host_events = dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, - host_to_host, dep_events); + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); } else { @@ -2162,40 +2153,40 @@ namespace dpct // If has padding data, not sure whether it is useless. So fill temp // buffer with it. std::vector{ - dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), - device_to_host, dep_events)}); + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); } // Copy from temp host buffer to device with only one submit. event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), - buf.get_size(), host_to_device, - host_events)); + buf.get_size(), host_to_device, + host_events)); break; } case device_to_host: { host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, - event_list); + event_list); // Copy from host temp buffer to host target with reshaping. event_list = dpct_memcpy( q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, // Copy from device to temp host buffer with only one submit. 
std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, - buf.get_size(), - device_to_host, dep_events)}); + buf.get_size(), + device_to_host, dep_events)}); break; } case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(dep_events); - cgh.parallel_for( - size, - [=](sycl::id<3> id) { - to_surface[get_offset(id, to_slice, to_range.get(0))] = - from_surface[get_offset(id, from_slice, from_range.get(0))]; - }); })); - break; + event_list.push_back(q.submit([&](sycl::handler& cgh) + { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); + break; default: throw std::runtime_error("dpct_memcpy: invalid direction value"); } @@ -2204,34 +2195,34 @@ namespace dpct /// memcpy 2D/3D matrix specified by pitched_data. static inline std::vector - dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, - pitched_data from, sycl::id<3> from_id, sycl::range<3> size, - memcpy_direction direction = automatic) + dpct_memcpy(sycl::queue& q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) { return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), - sycl::range<3>(to.get_pitch(), to.get_y(), 1), - sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, - size, direction); + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); } /// memcpy 2D matrix with pitch. static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - size_t to_pitch, size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic) + dpct_memcpy(sycl::queue& q, void* to_ptr, const void* from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) { return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), - sycl::range<3>(from_pitch, y, 1), - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), - sycl::range<3>(x, y, 1), direction); + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); } - inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, const void *b, library_data_t b_type, int ldb, - const void *beta, void *c, library_data_t c_type, int ldc, - library_data_t scaling_type) + inline void gemm(sycl::queue& q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void* alpha, const void* a, library_data_t a_type, + int lda, const void* b, library_data_t b_type, int ldb, + const void* beta, void* c, library_data_t c_type, int ldc, + library_data_t scaling_type) { if (scaling_type == library_data_t::real_float && c_type == library_data_t::complex_float) @@ -2239,7 +2230,7 @@ namespace dpct scaling_type = library_data_t::complex_float; } else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) + c_type == library_data_t::complex_double) { scaling_type = library_data_t::complex_double; } @@ -2248,114 +2239,114 @@ namespace dpct detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) { - case detail::get_type_combination_id( - 
library_data_t::real_float, library_data_t::real_float, + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, - lda, b, ldb, beta, c, ldc); - break; - } + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, 
ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, - a, lda, b, ldb, &beta_half, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, + a, lda, b, ldb, &beta_half, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); - break; - } + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); + break; + } #endif // __INTEL_MKL__ - default: - throw std::runtime_error("the combination of data type is unsupported"); + default: + throw std::runtime_error("the combination of data type is unsupported"); } } // gemm() @@ -2379,13 +2370,13 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. 
- inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) + inline void gemm_batch(sycl::queue& q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void* alpha, const void* a[], + library_data_t a_type, int lda, const void* b[], + library_data_t b_type, int ldb, const void* beta, + void* c[], library_data_t c_type, int ldc, + int batch_size, library_data_t scaling_type) { if (scaling_type == library_data_t::real_float && c_type == library_data_t::complex_float) @@ -2393,7 +2384,7 @@ namespace dpct scaling_type = library_data_t::complex_float; } else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) + c_type == library_data_t::complex_double) { scaling_type = library_data_t::complex_double; } @@ -2402,124 +2393,124 @@ namespace dpct detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, 
+ batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, + a, lda, b, ldb, &beta_float, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } #endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, 
library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, + batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); } } @@ -2546,14 +2537,14 @@ namespace dpct /// \param [in] stride_c Stride between the different C matrices. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, long long int stride_a, const void *b, - library_data_t b_type, int ldb, long long int stride_b, - const void *beta, void *c, library_data_t c_type, - int ldc, long long int stride_c, int batch_size, - library_data_t scaling_type) + inline void gemm_batch(sycl::queue& q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void* alpha, const void* a, library_data_t a_type, + int lda, long long int stride_a, const void* b, + library_data_t b_type, int ldb, long long int stride_b, + const void* beta, void* c, library_data_t c_type, + int ldc, long long int stride_c, int batch_size, + library_data_t scaling_type) { if (scaling_type == library_data_t::real_float && c_type == library_data_t::complex_float) @@ -2561,7 +2552,7 @@ namespace dpct scaling_type = library_data_t::complex_float; } else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) + c_type == library_data_t::complex_double) { scaling_type = library_data_t::complex_double; } @@ -2570,138 +2561,138 @@ namespace dpct detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - 
library_data_t::complex_float, library_data_t::complex_float, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - stride_a, b, ldb, stride_b, beta, c, ldc, - stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, 
stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } #endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, - &beta_half, c, ldc, stride_c, batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, + &beta_half, c, ldc, stride_c, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); } } static inline void - async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, - size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic, - sycl::queue &q = get_default_queue()) + async_dpct_memcpy(void* to_ptr, size_t to_pitch, const void* from_ptr, + size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic, + sycl::queue& q = get_default_queue()) { detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, - direction); + direction); } using err0 = detail::generic_error_type; using err1 = detail::generic_error_type; - static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) { + static inline void dpct_free(void* ptr, sycl::queue& q = get_default_queue()) { detail::dpct_free(ptr, q); } @@ -2713,12 +2704,12 @@ namespace dpct using element_t = typename memory_t::element_t; using pointer_t = typename memory_t::pointer_t; using accessor_t = typename memory_t::template accessor_t<3>; - 
accessor(pointer_t data, const sycl::range<3> &in_range) + accessor(pointer_t data, const sycl::range<3>& in_range) : _data(data), _range(in_range) {} template - accessor(typename std::enable_if::type &acc) + accessor(typename std::enable_if::type& acc) : accessor(acc, acc.get_range()) {} - accessor(const accessor_t &acc, const sycl::range<3> &in_range) + accessor(const accessor_t& acc, const sycl::range<3>& in_range) : accessor(acc.get_pointer(), in_range) {} accessor operator[](size_t index) const { sycl::range<2> sub(_range.get(1), _range.get(2)); @@ -2737,12 +2728,12 @@ namespace dpct using element_t = typename memory_t::element_t; using pointer_t = typename memory_t::pointer_t; using accessor_t = typename memory_t::template accessor_t<2>; - accessor(pointer_t data, const sycl::range<2> &in_range) + accessor(pointer_t data, const sycl::range<2>& in_range) : _data(data), _range(in_range) {} template - accessor(typename std::enable_if::type &acc) + accessor(typename std::enable_if::type& acc) : accessor(acc, acc.get_range()) {} - accessor(const accessor_t &acc, const sycl::range<2> &in_range) + accessor(const accessor_t& acc, const sycl::range<2>& in_range) : accessor(acc.get_pointer(), in_range) {} pointer_t operator[](size_t index) const { @@ -2762,18 +2753,18 @@ namespace dpct public: using accessor_t = typename detail::memory_traits::template accessor_t; + T>::template accessor_t; using value_t = typename detail::memory_traits::value_t; using dpct_accessor_t = dpct::accessor; device_memory() : device_memory(sycl::range(1)) {} /// Constructor of 1-D array with initializer list - device_memory(const sycl::range &in_range, - std::initializer_list &&init_list) + device_memory(const sycl::range& in_range, + std::initializer_list&& init_list) : device_memory(in_range) { assert(init_list.size() <= in_range.size()); - _host_ptr = (value_t *)std::malloc(_size); + _host_ptr = (value_t*)std::malloc(_size); std::memset(_host_ptr, 0, _size); std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); } @@ -2781,23 +2772,23 @@ namespace dpct /// Constructor of 2-D array with initializer list template device_memory( - const typename std::enable_if>::type &in_range, - std::initializer_list> &&init_list) + const typename std::enable_if>::type& in_range, + std::initializer_list>&& init_list) : device_memory(in_range) { assert(init_list.size() <= in_range[0]); - _host_ptr = (value_t *)std::malloc(_size); + _host_ptr = (value_t*)std::malloc(_size); std::memset(_host_ptr, 0, _size); auto tmp_data = _host_ptr; for (auto sub_list : init_list) { assert(sub_list.size() <= in_range[1]); std::memcpy(tmp_data, sub_list.begin(), - sub_list.size() * sizeof(T)); + sub_list.size() * sizeof(T)); tmp_data += in_range[1]; } } /// Constructor with range - device_memory(const sycl::range &range_in) + device_memory(const sycl::range& range_in) : _size(range_in.size() * sizeof(T)), _range(range_in), _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) { static_assert( @@ -2826,7 +2817,7 @@ namespace dpct void init() { init(dpct::get_default_queue()); } /// Allocate memory with specified queue, and init memory if has initial /// value. - void init(sycl::queue &q) { + void init(sycl::queue& q) { if (_device_ptr) return; if (!_size) @@ -2834,21 +2825,21 @@ namespace dpct allocate_device(q); if (_host_ptr) detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, - host_to_device); + host_to_device); } /// The variable is assigned to a device pointer. 
- void assign(value_t *src, size_t size) { + void assign(value_t* src, size_t size) { this->~device_memory(); new (this) device_memory(src, size); } /// Get memory pointer of the memory object, which is virtual pointer when /// usm is not used, and device pointer when usm is used. - value_t *get_ptr() { return get_ptr(get_default_queue()); } + value_t* get_ptr() { return get_ptr(get_default_queue()); } /// Get memory pointer of the memory object, which is virtual pointer when /// usm is not used, and device pointer when usm is used. - value_t *get_ptr(sycl::queue &q) { + value_t* get_ptr(sycl::queue& q) { init(q); return _device_ptr; } @@ -2857,7 +2848,7 @@ namespace dpct size_t get_size() { return _size; } template - typename std::enable_if::type &operator[](size_t index) { + typename std::enable_if::type& operator[](size_t index) { init(); return _device_ptr[index]; } @@ -2866,39 +2857,39 @@ namespace dpct /// when usm is used and dimension is greater than 1. template typename std::enable_if::type - get_access([[maybe_unused]] sycl::handler &cgh) { - return dpct_accessor_t((T *)_device_ptr, _range); + get_access([[maybe_unused]] sycl::handler& cgh) { + return dpct_accessor_t((T*)_device_ptr, _range); } private: - device_memory(value_t *memory_ptr, size_t size) + device_memory(value_t* memory_ptr, size_t size) : _size(size), _range(size / sizeof(T)), _reference(true), _device_ptr(memory_ptr) {} - void allocate_device(sycl::queue &q) { - #ifndef DPCT_USM_LEVEL_NONE + void allocate_device(sycl::queue& q) { +#ifndef DPCT_USM_LEVEL_NONE if (Memory == shared) { - _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), - q.get_context()); + _device_ptr = (value_t*)sycl::malloc_shared(_size, q.get_device(), + q.get_context()); return; } - #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY +#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY if (Memory == constant) { - _device_ptr = (value_t *)sycl::malloc_device( + _device_ptr = (value_t*)sycl::malloc_device( _size, q.get_device(), q.get_context(), sycl::ext::oneapi::property::usm::device_read_only()); return; } - #endif - #endif - _device_ptr = (value_t *)detail::dpct_malloc(_size, q); +#endif +#endif + _device_ptr = (value_t*)detail::dpct_malloc(_size, q); } size_t _size; sycl::range _range; bool _reference; - value_t *_host_ptr; - value_t *_device_ptr; + value_t* _host_ptr; + value_t* _device_ptr; }; template class device_memory : public device_memory { @@ -2909,12 +2900,12 @@ namespace dpct typename detail::memory_traits::template accessor_t<0>; /// Constructor with initial value. 
-        device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
+        device_memory(const value_t& val) : base(sycl::range<1>(1), { val }) {}
        /// Default constructor
        device_memory() : base(1) {}
    };
-  } // namespace detail
+    } // namespace detail

    template
    using global_memory = detail::device_memory;
@@ -2925,54 +2916,54 @@ namespace dpct
    template
-    inline T atomic_fetch_add(T *addr, T operand) {
-        auto atm =
-            sycl::atomic_ref(addr[0]);
-        return atm.fetch_add(operand);
+        sycl::access::address_space addressSpace =
+        sycl::access::address_space::global_space,
+        sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+        sycl::memory_scope memoryScope = sycl::memory_scope::device>
+    inline T atomic_fetch_add(T* addr, T operand) {
+        auto atm =
+            sycl::atomic_ref(addr[0]);
+        return atm.fetch_add(operand);
    }

    template
-    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
-        auto atm =
-            sycl::atomic_ref(addr[0]);
-        return atm.fetch_add(operand);
+        sycl::access::address_space::global_space,
+        sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+        sycl::memory_scope memoryScope = sycl::memory_scope::device,
+        typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1* addr, T2 operand) {
+        auto atm =
+            sycl::atomic_ref(addr[0]);
+        return atm.fetch_add(operand);
    }

    template
-    inline T atomic_fetch_add(T *addr, T operand,
-                              sycl::memory_order memoryOrder) {
-        switch (memoryOrder) {
+        sycl::access::address_space::global_space>
+    inline T atomic_fetch_add(T* addr, T operand,
+                              sycl::memory_order memoryOrder) {
+        switch (memoryOrder) {
        case sycl::memory_order::relaxed:
            return atomic_fetch_add(addr, operand);
+                                    sycl::memory_scope::device>(addr, operand);
        case sycl::memory_order::acq_rel:
            return atomic_fetch_add(addr, operand);
+                                    sycl::memory_scope::device>(addr, operand);
        case sycl::memory_order::seq_cst:
            return atomic_fetch_add(addr, operand);
+                                    sycl::memory_scope::device>(addr, operand);
        default:
            assert(false && "Invalid memory_order for atomics. Valid memory_order for "
-                           "atomics are: sycl::memory_order::relaxed, "
-                           "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+                "atomics are: sycl::memory_order::relaxed, "
+                "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
        }
    }

    template
-    inline T1 atomic_fetch_add(T1 *addr, T2 operand,
-                               sycl::memory_order memoryOrder) {
-        atomic_fetch_add(addr, operand, memoryOrder);
+        sycl::access::address_space::global_space,
+        typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1* addr, T2 operand,
+                               sycl::memory_order memoryOrder) {
+        atomic_fetch_add(addr, operand, memoryOrder);
    }
} // COPY from DPCT head files

diff --git a/ggml.h b/ggml.h
index 13502a3622fc4..2e8fd0dbc2e31 100644
--- a/ggml.h
+++ b/ggml.h
@@ -312,6 +312,12 @@
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

+#define GGML_TENSOR_BINARY_OP_LOCALS01 \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
#ifdef __cplusplus
extern "C" {
#endif
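
Note (illustration, not part of the patch): the new GGML_TENSOR_BINARY_OP_LOCALS01 macro added above mirrors GGML_TENSOR_BINARY_OP_LOCALS but declares locals only for src0 and src1, with no dst locals, which is the variant ggml_sycl_cpy is switched to in this patch. A minimal sketch of what the macro brings into scope, assuming the usual GGML_TENSOR_LOCALS expansion already present in ggml.h; the function name below is hypothetical.

    #include "ggml.h"

    // Sketch only: shows which locals GGML_TENSOR_BINARY_OP_LOCALS01 declares.
    static void example_src_only_op(const struct ggml_tensor * src0,
                                    const struct ggml_tensor * src1) {
        GGML_TENSOR_BINARY_OP_LOCALS01;

        // ne00..ne03 / nb00..nb03 describe src0, ne10..ne13 / nb10..nb13 describe src1;
        // unlike GGML_TENSOR_BINARY_OP_LOCALS, no ne0..ne3 / nb0..nb3 for dst exist here.
        GGML_ASSERT(ne00*ne01*ne02*ne03 == ne10*ne11*ne12*ne13); // same element count
        GGML_ASSERT(nb00 > 0 && nb10 > 0);                       // byte strides are set
    }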