From a76c56fa1a3d27467eb97468d8c3b2fe1243b61a Mon Sep 17 00:00:00 2001
From: lhez
Date: Fri, 13 Dec 2024 12:23:52 -0800
Subject: [PATCH] Introducing experimental OpenCL backend with support for Qualcomm Adreno GPUs (#10693)

* [cl][adreno] Add Adreno GPU support

Add new OpenCL backend to support Adreno GPUs

---------

Co-authored-by: Skyler Szot
Co-authored-by: Shangqing Gu
Co-authored-by: Alexander Angus
Co-authored-by: Hongqiang Wang
Co-authored-by: Max Krasnyansky

* [cl][ci] Add workflow for CL

* [cl][adreno] Fix memory leak for non SMALL_ALLOC path

* opencl: integrate backend dyn.load interface and fix compiler and format warnings

* opencl: remove small-alloc support and fix build errors for non-opencl platforms

* opencl: fixed merge conflict (MUSA added twice in cmake)

* opencl-ci: use RUNNER_TEMP instead of github.workspace

* opencl: fix embed tool invocation with python3

* opencl: CI workflow fixes

* opencl: Clean up small-alloc in CMake files

* opencl: cleanup ggml-opencl2 header file

* opencl: use ulong for offsets and strides in ADD kernel

* opencl: use cl_ulong for all offsets

* opencl: use cl_ulong for sizes and strides

* opencl: use `GGML_LOG_xxx` instead of `fprintf(stderr, ...)`

* opencl: rename backend `opencl2` -> `opencl`

* opencl: rename kernel files `ggml-opencl2` -> `ggml-opencl`

* opencl: make OpenCL required, remove redundant lib and inc directories

* `ggml-base`, `..` and `.` are added by `ggml_add_backend_library`

* opencl: rename backend - funcs, structs, etc `opencl2` -> `opencl`

* opencl: remove copyright marker since main license already covers

* opencl: replace some more OPENCL2 leftovers

* opencl: remove limits on `tensor_extra`

* opencl: use pools for `tensor_extra`

* opencl: fix compiler warnings with GCC and Clang

Still getting the warning about clCreateCmdQueue being obsolete. Will fix that separately.

* opencl: fail gracefully if opencl devices are not available

Also for unsupported GPUs.
* opencl: fix MSVC builds (string length error) * opencl: check for various requirements, allow deprecated API * opencl: update log message for unsupported GPUs --------- Co-authored-by: Skyler Szot Co-authored-by: Shangqing Gu Co-authored-by: Alexander Angus Co-authored-by: Hongqiang Wang Co-authored-by: Max Krasnyansky --- .github/workflows/build.yml | 26 +- ggml/CMakeLists.txt | 5 + ggml/include/ggml-opencl.h | 26 + ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-backend-reg.cpp | 8 + ggml/src/ggml-opencl/CMakeLists.txt | 147 + ggml/src/ggml-opencl/ggml-opencl.cpp | 4004 +++++++++++++++++ ggml/src/ggml-opencl/kernels/embed_kernel.py | 26 + ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 2683 +++++++++++ .../ggml-opencl/kernels/ggml-opencl_cvt.cl | 106 + .../kernels/ggml-opencl_gemv_noshuffle.cl | 265 ++ .../ggml-opencl_gemv_noshuffle_general.cl | 271 ++ .../src/ggml-opencl/kernels/ggml-opencl_mm.cl | 1225 +++++ .../kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl | 130 + .../kernels/ggml-opencl_transpose_16.cl | 32 + .../kernels/ggml-opencl_transpose_32.cl | 25 + .../kernels/ggml-opencl_transpose_32_16.cl | 35 + 17 files changed, 9014 insertions(+), 1 deletion(-) create mode 100644 ggml/include/ggml-opencl.h create mode 100644 ggml/src/ggml-opencl/CMakeLists.txt create mode 100644 ggml/src/ggml-opencl/ggml-opencl.cpp create mode 100644 ggml/src/ggml-opencl/kernels/embed_kernel.py create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 886d33d2d5640..022d31fb2434b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -662,6 +662,8 @@ jobs: defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' - build: 'msvc-arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' + - build: 'llvm-arm64-opencl-adreno' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' steps: - name: Clone @@ -703,6 +705,28 @@ jobs: run: | choco install ninja + - name: Install OpenCL Headers and Libs + id: install_opencl + if: ${{ matrix.build == 'llvm-arm64-opencl-adreno' }} + run: | + git clone https://github.com/KhronosGroup/OpenCL-Headers + cd OpenCL-Headers + mkdir build && cd build + cmake .. ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build . --target install + git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader + cd OpenCL-ICD-Loader + mkdir build-arm64-release && cd build-arm64-release + cmake .. 
` + -A arm64 ` + -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build . --target install --config release + - name: Build id: cmake_build run: | @@ -732,7 +756,7 @@ jobs: - name: Test id: cmake_test # not all machines have native AVX-512 - if: ${{ matrix.build != 'msvc-arm64' && matrix.build != 'llvm-arm64' && matrix.build != 'kompute-x64' && matrix.build != 'vulkan-x64' && (matrix.build != 'avx512-x64' || env.HAS_AVX512F == '1') }} + if: ${{ matrix.build != 'msvc-arm64' && matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' && matrix.build != 'kompute-x64' && matrix.build != 'vulkan-x64' && (matrix.build != 'avx512-x64' || env.HAS_AVX512F == '1') }} run: | cd build ctest -L main -C Release --verbose --timeout 900 diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index c91e931634255..3442142adbb43 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -179,6 +179,11 @@ set (GGML_SYCL_TARGET "INTEL" CACHE STRING set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING "ggml: sycl device architecture") +option(GGML_OPENCL "ggml: use OpenCL" OFF) +option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) +option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) +option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) + # extra artifacts option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) diff --git a/ggml/include/ggml-opencl.h b/ggml/include/ggml-opencl.h new file mode 100644 index 0000000000000..6b61771358f87 --- /dev/null +++ b/ggml/include/ggml-opencl.h @@ -0,0 +1,26 @@ +#ifndef GGML_OPENCL_H +#define GGML_OPENCL_H + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend API +// +GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void); +GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void); + +#ifdef __cplusplus +} +#endif + +#endif // GGML_OPENCL_H diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 349f4c57fdf51..bf5ee5fc273fc 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -308,6 +308,7 @@ ggml_add_backend(MUSA) ggml_add_backend(RPC) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) +ggml_add_backend(OpenCL) foreach (target ggml-base ggml) target_include_directories(${target} PUBLIC $ $) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index b2eded903cf91..66927148abcb9 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -46,6 +46,10 @@ #include "ggml-vulkan.h" #endif +#ifdef GGML_USE_OPENCL +#include "ggml-opencl.h" +#endif + #ifdef GGML_USE_BLAS #include "ggml-blas.h" #endif @@ -146,6 +150,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_VULKAN register_backend(ggml_backend_vk_reg()); #endif +#ifdef GGML_USE_OPENCL + register_backend(ggml_backend_opencl_reg()); +#endif #ifdef GGML_USE_CANN register_backend(ggml_backend_cann_reg()); #endif @@ -539,6 +546,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); 
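+    // Usage sketch (for illustration): once the backend is available -- either
+    // registered statically via ggml_backend_opencl_reg() above when ggml is
+    // built with -DGGML_OPENCL=ON, or picked up below by
+    // ggml_backend_load_best("opencl", ...) as a dynamically loaded module --
+    // an application can create it through the new public header:
+    //
+    //     #include "ggml-opencl.h"
+    //     ggml_backend_t backend = ggml_backend_opencl_init();
+    //     if (backend && ggml_backend_is_opencl(backend)) {
+    //         // ... allocate buffers and run graphs on the GPU ...
+    //     }
+    //     ggml_backend_free(backend);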
ggml_backend_load_best("vulkan", silent, dir_path); + ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path); } diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt new file mode 100644 index 0000000000000..45328a6579320 --- /dev/null +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -0,0 +1,147 @@ +find_package(OpenCL REQUIRED) +find_package(Python3 REQUIRED) + +set(TARGET_NAME ggml-opencl) + +ggml_add_backend_library(${TARGET_NAME} + ggml-opencl.cpp + ../../include/ggml-opencl.h) +target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES}) +target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS}) + +if (GGML_OPENCL_PROFILING) + message(STATUS "OpenCL profiling enabled (increases CPU overhead)") + add_compile_definitions(GGML_OPENCL_PROFILING) +endif () + +add_compile_definitions(GGML_OPENCL_SOA_Q) + +if (GGML_OPENCL_USE_ADRENO_KERNELS) + message(STATUS "OpenCL will use matmul kernels optimized for Adreno") + add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS) +endif () + +if (GGML_OPENCL_EMBED_KERNELS) + add_compile_definitions(GGML_OPENCL_EMBED_KERNELS) + + set(OPENCL_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl.cl.h") + set(OPENCL_MM_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mm.cl.h") + set(OPENCL_CVT_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_cvt.cl.h") + + set(OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle.cl.h") + set(OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle_general.cl.h") + set(OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h") + set(OPENCL_TRANSPOSE_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_16.cl.h") + set(OPENCL_TRANSPOSE_32_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32.cl.h") + set(OPENCL_TRANSPOSE_32_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32_16.cl.h") + + set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py") + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + + include_directories("${CMAKE_BINARY_DIR}/autogenerated") + + # Python must be accessible from command line + add_custom_command( + OUTPUT ${OPENCL_CL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl.cl + ${OPENCL_CL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_MM_CL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mm.cl + ${OPENCL_MM_CL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_mm.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_mm.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_CVT_CL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_cvt.cl + ${OPENCL_CVT_CL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_cvt.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_cvt.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle.cl + 
${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_gemv_noshuffle.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_gemv_noshuffle.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle_general.cl + ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_gemv_noshuffle_general.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_gemv_noshuffle_general.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl + ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_mul_mat_Ab_Bi_8x4.cl.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_TRANSPOSE_16_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_16.cl + ${OPENCL_TRANSPOSE_16_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_transpose_16.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_transpose_16.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_TRANSPOSE_32_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32.cl + ${OPENCL_TRANSPOSE_32_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_transpose_32.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_transpose_32.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32_16.cl + ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_transpose_32_16.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_transpose_32_16.cl.h" + ) + + target_sources(${TARGET_NAME} PRIVATE + ${OPENCL_CL_SOURCE_EMBED} + ${OPENCL_MM_CL_SOURCE_EMBED} + ${OPENCL_CVT_CL_SOURCE_EMBED} + ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED} + ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED} + ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED} + ${OPENCL_TRANSPOSE_16_SOURCE_EMBED} + ${OPENCL_TRANSPOSE_32_SOURCE_EMBED} + ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}) +else () + # copy ggml-opencl.cl to bin directory + configure_file(kernels/ggml-opencl.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl.cl COPYONLY) + configure_file(kernels/ggml-opencl_mm.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mm.cl COPYONLY) + configure_file(kernels/ggml-opencl_cvt.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_cvt.cl COPYONLY) + + configure_file(kernels/ggml-opencl_gemv_noshuffle.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle.cl COPYONLY) + configure_file(kernels/ggml-opencl_gemv_noshuffle_general.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle_general.cl COPYONLY) + configure_file(kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mul_mat_Ab_Bi_8x4.cl COPYONLY) + configure_file(kernels/ggml-opencl_transpose_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_16.cl COPYONLY) + configure_file(kernels/ggml-opencl_transpose_32.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32.cl COPYONLY) + configure_file(kernels/ggml-opencl_transpose_32_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32_16.cl COPYONLY) 
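+    # Note: with GGML_OPENCL_EMBED_KERNELS the generated *.cl.h files above are
+    # #include'd into C++ string literals inside ggml-opencl.cpp, e.g.
+    #     const std::string kernel_src {
+    #         #include "ggml-opencl.cl.h"
+    #     };
+    # This branch instead copies the .cl sources to the runtime output directory
+    # so that read_file() can load them when the backend initializes.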
+endif () diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp new file mode 100644 index 0000000000000..c77d629f08760 --- /dev/null +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -0,0 +1,4004 @@ +#define CL_TARGET_OPENCL_VERSION 220 +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS + +// suppress warnings in CL headers for GCC and Clang +#pragma GCC diagnostic ignored "-Woverlength-strings" +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wgnu-anonymous-struct" +#endif + +#include "ggml-opencl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml.h" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define UNUSED(x) (void)(x) + +#define CL_CHECK(err) \ + do { \ + cl_int err_ = (err); \ + if (err_ != CL_SUCCESS) { \ + GGML_LOG_ERROR("ggml_opencl: %s error %d at %s:%d\n", \ + #err, err_, __FILE__, __LINE__); \ + GGML_ASSERT(0); \ + } \ + } while (0) + +//------------------------------------------------------------------------------ +// OpenCL +//------------------------------------------------------------------------------ + +bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor); + +enum GPU_FAMILY { + ADRENO, + INTEL, + UNKNOWN, +}; + +enum ADRENO_GPU_GEN { + ADRENO_UNKNOWN, + A7X, + A8X, + X1E, +}; + +static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { + if (strstr(device_name, "730") || + strstr(device_name, "740") || + strstr(device_name, "750")) { + return ADRENO_GPU_GEN::A7X; + } + + if (strstr(device_name, "830")) { + return ADRENO_GPU_GEN::A8X; + } + + if (strstr(device_name, "X1")) { + return ADRENO_GPU_GEN::X1E; + } + + return ADRENO_GPU_GEN::ADRENO_UNKNOWN; +} + +static int get_adreno_cl_compiler_version(const char *driver_version) { + std::string driver_ver_str(driver_version); + size_t compiler_ver_pos = driver_ver_str.find("E031"); + size_t compiler_ver_len = 13; + size_t compiler_ver_offset = 5; + + if (compiler_ver_pos == std::string::npos) { + compiler_ver_pos = driver_ver_str.find("DX"); + if (compiler_ver_pos == std::string::npos) { + return -1; + } + compiler_ver_len = 11; + compiler_ver_offset = 3; + } + + std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len); + std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2); + return std::atoi(major_ver_str.c_str()); +} + +// backend device context +struct ggml_backend_opencl_device_context { + cl_platform_id platform; + std::string platform_name; + + cl_device_id device; + std::string device_name; +}; + +// backend context +struct ggml_backend_opencl_context { + cl_device_id device; + std::string device_name; + + std::string driver_version; + + GPU_FAMILY gpu_family; + ADRENO_GPU_GEN adreno_gen; + + cl_int alignment; + size_t max_alloc_size; + bool fp16_support; + + int adreno_wave_size; + + cl_context context; + cl_command_queue queue; + + cl_program program; + cl_program program_1; + cl_program program_2; + + cl_kernel kernel_add, kernel_add_row; + cl_kernel kernel_mul, kernel_mul_row; + cl_kernel kernel_scale; + cl_kernel kernel_silu, kernel_silu_4; + cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_relu; + cl_kernel kernel_clamp; + cl_kernel kernel_norm; + cl_kernel kernel_rms_norm; + cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; + cl_kernel kernel_soft_max, 
kernel_soft_max_4; + cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; + cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_mul_mat_f32_f32; + cl_kernel kernel_mul_mat_f16_f16; + cl_kernel kernel_mul_mat_f16_f32_1row; + cl_kernel kernel_mul_mat_f16_f32; + cl_kernel kernel_mul_mat_f16_f32_l4; + cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat; + cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; + cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0, + kernel_mul_mat_q4_0_f32_flat_img_v0; + cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q6_K_f32; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Transpose kernels + cl_program program_transpose_32; + cl_program program_transpose_32_16; + cl_program program_transpose_16; + cl_kernel kernel_transpose_32; + cl_kernel kernel_transpose_32_16; + cl_kernel kernel_transpose_16; + + cl_mem A_s_d_max; // max scale buffer size for transpose + cl_mem A_q_d_max; // max weight buffer size for transpose + cl_mem B_d_max; // max activation buffer size for transpose + + // Gemm and Gemv related programs, kernels, etc + cl_program program_CL_gemm; + cl_program program_CL_gemv_general; + cl_program program_CL_gemv_4096_1_11008; + cl_program program_CL_gemv_4096_1_4096; + cl_program program_CL_gemv_11008_1_4096; + cl_program program_CL_gemv_32000_1_4096; + cl_kernel CL_mul_mat_Ab_Bi_8x4; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS +}; + +static ggml_backend_device g_ggml_backend_opencl_device; +static ggml_backend_opencl_device_context g_ggml_ctx_dev_main { + /*.platform =*/ nullptr, + /*.platform_nane =*/ "", + /*.device =*/ nullptr, + /*.device_name =*/ "", +}; + +static int ggml_backend_opencl_n_devices = 0; + +// Profiling +#ifdef GGML_OPENCL_PROFILING +struct ProfilingInfo { + std::string op_name; + std::string kernel_name; + // Kernel execution time in nanoseconds. + cl_ulong duration_ns; + // Global and local work sizes. + size_t global_size[3]; + size_t local_size[3]; + // Op output size. 
+ size_t output_size[4]; +}; + +std::vector g_profiling_info; +#endif + +inline std::string read_file(const std::string &path) { + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; +} + +static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { + cl_program p; + char *program_log; + size_t program_size; + size_t log_size; + int err; + + program_size = strlen(program_buffer); + + p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + GGML_LOG_ERROR("OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL); + if(err < 0) { + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); + GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log); + free(program_log); + exit(1); + } + + return p; +} + +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { + static bool initialized = false; + static ggml_backend_opencl_context *backend_ctx = nullptr; + + if (initialized) { + return backend_ctx; + } + + ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context; + GGML_ASSERT(dev_ctx); + GGML_ASSERT(dev_ctx->platform == nullptr); + GGML_ASSERT(dev_ctx->device == nullptr); + GGML_ASSERT(backend_ctx == nullptr); + + initialized = true; + backend_ctx = new ggml_backend_opencl_context(); + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + + cl_int err; + +#ifdef GGML_PROFILE_OPENCL + GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); +#endif + + struct cl_device; + struct cl_platform { + cl_platform_id id; + unsigned number; + char name[128]; + char vendor[128]; + struct cl_device * devices; + unsigned n_devices; + struct cl_device * default_device; + }; + + struct cl_device { + struct cl_platform * platform; + cl_device_id id; + unsigned number; + cl_device_type type; + char name[128]; + }; + + enum { NPLAT = 16, NDEV = 16 }; + + struct cl_platform platforms[NPLAT]; + unsigned n_platforms = 0; + struct cl_device devices[NDEV]; + unsigned n_devices = 0; + struct cl_device * default_device = NULL; + + cl_platform_id platform_ids[NPLAT]; + if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { + GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); + return backend_ctx; + } + + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + p->number = i; + p->id = platform_ids[i]; + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL)); + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL)); + + cl_device_id device_ids[NDEV]; + cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices); + if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) { + p->n_devices = 0; + } else { + CL_CHECK(clGetDeviceIDsError); + } + p->devices = p->n_devices > 0 ? 
&devices[n_devices] : NULL; + p->default_device = NULL; + + for (unsigned j = 0; j < p->n_devices; j++) { + struct cl_device * d = &devices[n_devices]; + d->number = n_devices++; + d->id = device_ids[j]; + d->platform = p; + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL)); + + if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) { + p->default_device = d; + } + } + + if (default_device == NULL && p->default_device != NULL) { + default_device = p->default_device; + } + } + + if (n_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); + return backend_ctx; + } + + char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); + char * user_device_string = getenv("GGML_OPENCL_DEVICE"); + int user_platform_number = -1; + int user_device_number = -1; + + unsigned n; + if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { + user_platform_number = (int)n; + } + if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) { + user_device_number = (int)n; + } + if (user_platform_number != -1 && user_device_number != -1) { + cl_platform* platform = &platforms[user_platform_number]; + if ((unsigned)user_device_number >= platform->n_devices) { + GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); + exit(1); + } + default_device = &platform->devices[user_device_number]; + } else { + + struct cl_device * selected_devices = devices; + unsigned n_selected_devices = n_devices; + + if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + if (strstr(p->name, user_platform_string) != NULL || + strstr(p->vendor, user_platform_string) != NULL) { + user_platform_number = (int)i; + break; + } + } + if (user_platform_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string); + exit(1); + } + } + if (user_platform_number != -1) { + struct cl_platform * p = &platforms[user_platform_number]; + selected_devices = p->devices; + n_selected_devices = p->n_devices; + default_device = p->default_device; + if (n_selected_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + exit(1); + } + } + + if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { + for (unsigned i = 0; i < n_selected_devices; i++) { + struct cl_device * d = &selected_devices[i]; + if (strstr(d->name, user_device_string) != NULL) { + user_device_number = d->number; + break; + } + } + if (user_device_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string); + exit(1); + } + } + if (user_device_number != -1) { + selected_devices = &devices[user_device_number]; + n_selected_devices = 1; + default_device = &selected_devices[0]; + } + + GGML_ASSERT(n_selected_devices > 0); + + if (default_device == NULL) { + default_device = &selected_devices[0]; + } + } + + GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name); + GGML_LOG_INFO("ggml_opencl: selecting device: '%s'\n", default_device->name); + if (default_device->type != CL_DEVICE_TYPE_GPU) { + GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name); + } + + 
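+    // A note on the selection logic above: GGML_OPENCL_PLATFORM and GGML_OPENCL_DEVICE
+    // accept either a numeric index or a substring. For example (hypothetical values),
+    //     GGML_OPENCL_PLATFORM=0 GGML_OPENCL_DEVICE=0
+    // selects device 0 of platform 0, while a non-numeric value such as
+    //     GGML_OPENCL_PLATFORM=QUALCOMM
+    // is matched as a substring of the platform name or vendor. If neither variable
+    // is set, the first GPU device enumerated is used.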
dev_ctx->platform = default_device->platform->id; + dev_ctx->device = default_device->id; + backend_ctx->device = default_device->id; + + if (strstr(default_device->name, "Adreno")) { + backend_ctx->gpu_family = GPU_FAMILY::ADRENO; + backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); + + // Default wave size is 128, A8x uses 64. + if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A8X) { + backend_ctx->adreno_wave_size = 64; + } else if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A7X || + backend_ctx->adreno_gen == ADRENO_GPU_GEN::X1E) { + backend_ctx->adreno_wave_size = 128; + } else { + backend_ctx->adreno_wave_size = 128; + GGML_LOG_WARN("ggml_opencl: Unsupported Adreno GPU: %s, " + "using wave size %d, " + "may not work as expected\n", + backend_ctx->device_name.c_str(), backend_ctx->adreno_wave_size); + } + } else if (strstr(default_device->name, "Intel")) { + backend_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name); + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return backend_ctx; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return backend_ctx; + } +#endif + + // Populate backend device name + dev_ctx->platform_name = default_device->platform->name; + dev_ctx->device_name = default_device->name; + backend_ctx->device_name = default_device->name; + + // A local ref of cl_device_id for convenience + cl_device_id device = backend_ctx->device; + + // Check device OpenCL version, OpenCL 2.0 or above is required + size_t device_ver_str_size; + clGetDeviceInfo(device, CL_DEVICE_VERSION, 0, NULL, &device_ver_str_size); + char *device_ver_buffer = (char *)alloca(device_ver_str_size + 1); + clGetDeviceInfo(device, CL_DEVICE_VERSION, device_ver_str_size, device_ver_buffer, NULL); + device_ver_buffer[device_ver_str_size] = '\0'; + GGML_LOG_INFO("ggml_opencl: device OpenCL version: %s\n", device_ver_buffer); + + if (strstr(device_ver_buffer, "OpenCL 2") == NULL && + strstr(device_ver_buffer, "OpenCL 3") == NULL) { + GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); + return backend_ctx; + } + + // Check driver version + size_t driver_version_str_size; + clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size); + char *driver_version = (char *)alloca(driver_version_str_size + 1); + clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); + driver_version[driver_version_str_size] = '\0'; + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); + backend_ctx->driver_version = driver_version; + + int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + bool has_vector_subgroup_broadcast = + adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17; + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + has_vector_subgroup_broadcast ? 
"true" : "false"); + + size_t ext_str_size; + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 + backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + + // fp16 is required + if (!backend_ctx->fp16_support) { + GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); + return backend_ctx; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (strstr(device_ver_buffer, "OpenCL 3") && + strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return backend_ctx; + } + + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &backend_ctx->alignment, NULL)); + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); + + clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + + // Check SVM. + cl_device_svm_capabilities svm_caps; + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_context_properties properties[] = { + (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0 + }; + + CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err)); + + // A local ref of cl_context for convenience + cl_context context = backend_ctx->context; + + //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), + // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? 
err : + // (queue = clCreateCommandQueue(context, device, 0, &err), err) + //))); + cl_command_queue_properties command_queue_props = 0; +#ifdef GGML_OPENCL_PROFILING + command_queue_props |= CL_QUEUE_PROFILING_ENABLE; +#endif + CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "ggml-opencl.cl.h" + }; +#else + const std::string kernel_src = read_file("ggml-opencl.cl"); +#endif + + std::string compile_opts = + "-cl-std=CL2.0 -cl-mad-enable -cl-unsafe-math-optimizations " + "-cl-finite-math-only -cl-fast-relaxed-math "; + backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts); + + // Non matmul kernels. + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err)); + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err)); + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err)); + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); + 
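+    // Note on the CL_CHECK pattern used for kernel creation in this file: CL_CHECK
+    // expects a cl_int status, so calls that report errors through an out-parameter
+    // are wrapped with the comma operator, e.g.
+    //     CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err));
+    // performs the assignment first, then hands `err` to the macro for checking.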
CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err)); + + // Matmul kernels. + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err)); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); + + // Load additional mulmat kernels. +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_1 { + #include "ggml-opencl_mm.cl.h" + }; +#else + const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl"); +#endif + backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err)); + + // Load additional data conversion kernels. 
+#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_2 { + #include "ggml-opencl_cvt.cl.h" + }; +#else + const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl"); +#endif + backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); + + // Kernels for Adreno +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string transpose_32_src { + #include "ggml-opencl_transpose_32.cl.h" + }; +#else + const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl"); +#endif + backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err)); + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string transpose_32_16_src { + #include "ggml-opencl_transpose_32_16.cl.h" + }; +#else + const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl"); +#endif + backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err)); + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string transpose_16_src { + #include "ggml-opencl_transpose_16.cl.h" + }; +#else + const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl"); +#endif + backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err)); + + // Gemv general + std::string CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "ggml-opencl_gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "ggml-opencl_gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + 
CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + + // Gemm +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + + // Allocate intermediate buffers and images + size_t max_A_q_d_bytes = 311164928; + size_t max_A_s_d_bytes = 38895616; + size_t max_B_d_bytes = 45088768; + + CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err)); + CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err)); + CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // For now we support a single devices + ggml_backend_opencl_n_devices = 1; + + return backend_ctx; +} + +static void ggml_cl2_free(void) { +#ifdef GGML_OPENCL_PROFILING + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + + float total_kernel_time = 0; + fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n"); + for (const ProfilingInfo & info : 
g_profiling_info) {
+        total_kernel_time += info.duration_ns/1.e6f;
+        fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
+            info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
+            info.global_size[0], info.global_size[1], info.global_size[2],
+            info.local_size[0], info.local_size[1], info.local_size[2],
+            info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
+    }
+    fclose(fperf);
+
+    GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Tensor extra management
+//------------------------------------------------------------------------------
+struct ggml_tensor_extra_cl {
+    // The buffer object that holds the data.
+    cl_mem data_device;
+    // The offset into the buffer object. This is primarily for scratch buffer
+    // and view operation.
+    // NB: this offset no longer includes view offset (view_offs). Whenever this
+    // offset is used, view_offs should be considered.
+    cl_ulong offset;
+    // The actual size of the cl_mem object. This is needed when returning the
+    // block to the pool.
+    size_t actual_size;
+
+    void reset() {
+        data_device = nullptr;
+        offset = 0;
+        actual_size = 0;
+    }
+};
+
+// Additional tensor extra structs for quantized tensors.
+// These tensors are loaded from files and should not be allocated in scratch --
+// they should always be allocated from the pool. Hence, they do not have an
+// `offset`, which would indicate their location in the scratch buffer.
+struct ggml_tensor_extra_cl_q4_0 {
+    // Quantized values.
+    cl_mem q = nullptr;
+    // Quantized values in image1d_buffer_t.
+    cl_mem q_img = nullptr;
+    // Scales.
+    cl_mem d = nullptr;
+    // Scales in image1d_buffer_t.
+    cl_mem d_img = nullptr;
+    // Size of quantized values.
+    size_t size_q = 0;
+    // Size of scales.
+    size_t size_d = 0;
+
+    ~ggml_tensor_extra_cl_q4_0() {
+        reset();
+    }
+
+    void reset() {
+        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
+        // They must be properly released so that the original buffer can be
+        // properly released to avoid memory leak.
+        if (q != nullptr) {
+            CL_CHECK(clReleaseMemObject(q));
+            q = nullptr;
+        }
+        if (d != nullptr) {
+            CL_CHECK(clReleaseMemObject(d));
+            d = nullptr;
+        }
+        // Currently, q_img and d_img are only initialized when SMALL_ALLOC is
+        // enabled. They point to the images in ggml_backend_opencl_buffer_context.
+        // So, there is no need to release them here.
+        // TODO: initialize them for the non SMALL_ALLOC path, or remove them.
+ q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +//------------------------------------------------------------------------------ +// Backend API +//------------------------------------------------------------------------------ + +// +// backend +// +static const char * ggml_backend_opencl_name(ggml_backend_t backend) { + return "OpenCL"; + + UNUSED(backend); +} + +static void ggml_backend_opencl_free(ggml_backend_t backend) { + ggml_cl2_free(); + + GGML_UNUSED(backend); +} + +static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + GGML_UNUSED(backend); + GGML_UNUSED(src); + GGML_UNUSED(dst); + return false; +} + +static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { + GGML_UNUSED(backend); +} + +static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + bool ok = ggml_cl_compute_forward(backend, node); + if (!ok) { + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } + + return GGML_STATUS_SUCCESS; +} + +static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + GGML_UNUSED(dev); + + switch (op->op) { + case GGML_OP_NONE: + return true; + case GGML_OP_GET_ROWS: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + case GGML_TYPE_Q4_0: +#ifdef GGML_OPENCL_SOA_Q + // We do not support flattened Q4_0 (and possibly other Q's) + return false; +#else // GGML_OPENCL_SOA_Q + return true; +#endif // GGML_OPENCL_SOA_Q + default: + return false; + } + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + default: + return false; + } + case GGML_OP_ADD: + case GGML_OP_SCALE: + case GGML_OP_MUL: + return true; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + return ggml_is_contiguous(op->src[0]); + default: + return false; + } + case GGML_OP_CLAMP: + case GGML_OP_SOFT_MAX: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_F16) { + return true; + } else if (op->src[0]->type == GGML_TYPE_F32) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type 
== GGML_TYPE_Q6_K) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } + return false; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + case GGML_OP_DIAG_MASK_INF: + return op->ne[3] == 1; + case GGML_OP_ROPE: + return true; + default: + return false; + } +} + +// Forward declaration - implementation appears later in the file. +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type); + +static ggml_guid_t ggml_backend_opencl_guid() { + static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe }; + return &guid; +} + +static ggml_backend_i ggml_backend_opencl_i = { + /* .get_name = */ ggml_backend_opencl_name, + /* .free = */ ggml_backend_opencl_free, + /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ + /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ + /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .synchronize = */ NULL, /* ggml_backend_opencl_synchronize */ + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_opencl_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +ggml_backend_t ggml_backend_opencl_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx + }; + + return backend; +} + +bool ggml_backend_is_opencl(ggml_backend_t backend) { + return backend && backend->iface.get_name == ggml_backend_opencl_name; +} + +// +// buffer +// +struct ggml_backend_opencl_buffer_context { + // A buffer context can hold multiple cl_mem objects. This is for flattening + // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where + // each tensor is allocated a separate buffer. When flattening is enabled + // with small allocation, each tensor is backed by two cl_mem objects (for + // quants and scales) packed into a backend_opencl_buffer. 
+    ggml_backend_opencl_buffer_context(cl_mem buf)
+        : name("OpenCL") {
+        buffer.push_back(buf);
+    }
+
+    ~ggml_backend_opencl_buffer_context() {
+        for (cl_mem buf : buffer) {
+            CL_CHECK(clReleaseMemObject(buf));
+        }
+        for (cl_mem im : img) {
+            CL_CHECK(clReleaseMemObject(im));
+        }
+
+        // Delete all extras to trigger their destructors
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
+            delete e;
+        }
+    }
+
+    ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
+        ggml_tensor_extra_cl * extra;
+        if (temp_tensor_extras.empty()) {
+            extra = new ggml_tensor_extra_cl();
+        } else {
+            extra = temp_tensor_extras.back();
+            temp_tensor_extras.pop_back();
+        }
+
+        temp_tensor_extras_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
+    ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() {
+        ggml_tensor_extra_cl_q4_0 * extra;
+        if (temp_tensor_extras_q4_0.empty()) {
+            extra = new ggml_tensor_extra_cl_q4_0();
+        } else {
+            extra = temp_tensor_extras_q4_0.back();
+            temp_tensor_extras_q4_0.pop_back();
+        }
+
+        temp_tensor_extras_q4_0_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
+    void reset() {
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
+            temp_tensor_extras.push_back(e);
+        }
+        temp_tensor_extras_in_use.clear();
+
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
+            temp_tensor_extras_q4_0.push_back(e);
+        }
+        temp_tensor_extras_q4_0_in_use.clear();
+    }
+
+    // Pools for extras. Available extras are in `temp_tensor_extras`. Extras
+    // being used are in `temp_tensor_extras_in_use`. At the first run, new
+    // extras get created and put in `in_use`. When the buffer is reset via
+    // the `reset` callback, all extras in `in_use` get moved to the available
+    // extras for reuse.
+    std::vector<ggml_tensor_extra_cl *>      temp_tensor_extras;
+    std::vector<ggml_tensor_extra_cl *>      temp_tensor_extras_in_use;
+    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
+    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
+
+    // The buffer_context is initially created by ggml_backend_buft_alloc_buffer
+    // before any tensor is initialized (at the beginning of alloc_tensor_range).
+    // Hence, there is always a buffer object in this vector. When each tensor is
+    // being initialized, this original buffer object will be released if both
+    // flattening and small allocation are enabled, and additional buffer
+    // objects will be created in init_tensor to represent flattened quantized
+    // weights.
+    std::vector<cl_mem> buffer;
+    // These are image1d_buffer_t objects that wrap around the quants and scales.
+    // For Q4_0 quantization, there should be two of them - one for quants and
+    // one for scales. They should be populated only when flattening and small
+    // allocation are enabled.
+    std::vector<cl_mem> img;
+    std::string name;
+};
+
+static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
+
+static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    delete ctx;
+}
+
+static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return cl_ptr_base;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+
+    ggml_cl2_init(buffer->buft->device);
+
+    if (tensor->view_src != nullptr) {
+        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+
+        ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra;
+        GGML_ASSERT(view_extra && "view_extra is nullptr?");
+
+        // Reuse extra of the parent tensor. The offset of this view tensor
+        // becomes `extra->offset + view_offs` and needs to be calculated when
+        // it is used. This change is needed because of the change to
+        // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640.
+        // `buffer` passed in here will always be `tensor->buffer`. It is OK
+        // to allocate extras from the same buffer context for ordinary
+        // intermediate tensors. But for views into kv cache tensors, doing so
+        // would mess up the extras used by kv cache.
+        // Before #7640, `buffer` is for intermediate tensors, which is always
+        // different from that of kv cache tensors.
+        //
+        // NB: now extra->offset no longer accounts for view_offs.
+        // NB: this should not apply to weight tensors (for end-to-end runs, but
+        // may apply for test-backend-ops).
+        // FIXME: if any unexpected results are seen, double check the offset -
+        // there could be other places that need fixing.
+        tensor->extra = view_extra;
+    } else {
+        {
+            size_t offset = (char *)tensor->data - (char *)cl_ptr_base;
+
+            ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra();
+            extra->offset = offset;
+            extra->data_device = ctx->buffer[0];
+            extra->actual_size = ggml_nbytes(tensor);
+
+            tensor->extra = extra;
+        }
+    }
+}
+
+// The optimized gemm and gemv kernels are used for large matrices without batch.
+// `tensor` is the quantized weight matrix.
+inline bool use_adreno_kernels(const ggml_tensor *tensor) {
+    return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 &&
+            tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
+
+    cl_context context = backend_ctx->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+#ifdef GGML_OPENCL_SOA_Q
+    // We separate the quantized bits and scale from block_q4_0 by using an
+    // additional kernel, where each thread handles a block. We first read the
+    // original weights into a temporary buffer, then create two separate
+    // buffers for quantized bits and scales, which are then populated by the
+    // conversion kernel.
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        // Tensors should have been preallocated, therefore they should
+        // already have ggml_tensor_extra_cl as extra.
+        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
+        GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized");
+
+        // Allocate the new extra and create aliases from the original.
+        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+        ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0();
+
+        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
+        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
+        GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+        CL_CHECK(clEnqueueWriteBuffer(
+            queue, data_device, CL_TRUE, 0,
+            ggml_nbytes(tensor), data, 0, NULL, NULL));
+
+        // We always take the specified offset arg into account, although for
+        // weights the offset arg should be 0 (we do not assert this).
+        //GGML_ASSERT(offset == 0);
+
+        // We create subbuffers from the original tensor buffer for scales and
+        // quants - i.e., scales and quants are aliases into the buffer object
+        // that backs the original tensor. This is a cleaner way to adapt to the
+        // new memory management.
+        // In the old code, we allocate new buffers for scales and quants
+        // respectively, which could still be done but would result in double
+        // allocation; properly deallocating the preallocated buffer that backs
+        // the tensors is tricky and would leak the backend specific information
+        // into the general backend code.
+        // Does this create misaligned subbuffers (alignment is 1024) in certain
+        // cases?
+        cl_buffer_region region;
+
+        // The original tensor memory is divided into scales and quants, i.e.,
+        // we first store scales, then quants.
+        // Create subbuffer for scales.
+        region.origin = extra_orig->offset + tensor->view_offs + offset;
+        region.size = size_d;
+        extra->d = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+        // Create subbuffer for quants.
+        region.origin = extra_orig->offset + tensor->view_offs + offset + size_d;
+        region.size = size_q;
+        extra->q = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+        //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+
+        // The optimized kernels need weights in natural order, so unshuffle.
+        if (use_adreno_kernels(tensor)) {
+            kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle;
+        }
+    #else
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clReleaseMemObject(data_device));
+
+        tensor->extra = extra;
+
+        // transpose the weights and scales
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        // Only do transpose for large, non batched matrix
+        // TODO: use preallocated images instead of sub-buffer then image
+        if (use_adreno_kernels(tensor)) {
+        // <----------------------------------------------------------------------------------> //
+        // start transpose
+        // <----------------------------------------------------------------------------------> //
+        int M = tensor->ne[1];   // ne01
+        int K = tensor->ne[0];   // ne00
+
+        // transpose is out of place, so we need to allocate transposed buffers
+        // <----------------------------------------------------------------------------------> //
+        // use sub_buffer of max buffer size instead
+
+        size_t q_size_bytes = K * M / 8 * sizeof(float);
+        cl_buffer_region region;
+        region.origin = 0;
+        region.size = q_size_bytes;
+        cl_mem qT_d = clCreateSubBuffer(
+            backend_ctx->A_q_d_max,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &err);
+        // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
+        CL_CHECK(err);
+
+        // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float);
+        size_t d_size_bytes = M * (K / 32) * 2;
+        region.origin = 0;
+        region.size = d_size_bytes;
+        cl_mem dT_d = clCreateSubBuffer(
+            backend_ctx->A_s_d_max,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &err);
+        // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
+        CL_CHECK(err);
+
+        // <----------------------------------------------------------------------------------> //
+
+
+        // create images from the buffers
+        // <----------------------------------------------------------------------------------> //
+        cl_mem q_d_image1D;
+        cl_mem d_d_image1D;
+        cl_mem qT_d_image1D;
+        cl_mem dT_d_image1D;
+
+        cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        cl_image_desc img_desc_1d;
+
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 8 / 4;
+        img_desc_1d.buffer = extra->q;
+        q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 8 / 4;
+        img_desc_1d.buffer = qT_d;
+        qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 32 / 4 / 2;
+        img_desc_1d.buffer = extra->d;
+        d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 32 / 4 / 2;
+        img_desc_1d.buffer = dT_d;
+        dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+        // <----------------------------------------------------------------------------------> //
+
+        // set up and call the transpose kernels
+        // <----------------------------------------------------------------------------------> //
+        // weights
+        int height_q = M / 8;
+        int width_q = K / 8 / 4;
+        kernel = backend_ctx->kernel_transpose_16;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q));
+
+        size_t local_size_q[3] = {4, 16, 1};
+        size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+
+        // scales
+        int height_s = M / 8;
+        int width_s = K / 32 / 8;
+
+        kernel = backend_ctx->kernel_transpose_16;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));
+
+        size_t local_size_s[3] = {4, 16, 1};
+        size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        // <----------------------------------------------------------------------------------> //
+
+        // copy transposed buffer contents to original buffers
+        // <----------------------------------------------------------------------------------> //
+        // weights
+        CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+
+        // scales
+        CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        // <----------------------------------------------------------------------------------> //
+
+        // deallocate transpose buffers
+        // <----------------------------------------------------------------------------------> //
+        CL_CHECK(clReleaseMemObject(qT_d));
+        CL_CHECK(clReleaseMemObject(dT_d));
+
+        // deallocate temporary images
+        CL_CHECK(clReleaseMemObject(q_d_image1D));
+        CL_CHECK(clReleaseMemObject(d_d_image1D));
+        CL_CHECK(clReleaseMemObject(qT_d_image1D));
+        CL_CHECK(clReleaseMemObject(dT_d_image1D));
+        // <----------------------------------------------------------------------------------> //
+        // end transpose
+        // <----------------------------------------------------------------------------------> //
+        }
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+        return;
+    }
+#endif // GGML_OPENCL_SOA_Q
+
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+    GGML_ASSERT(extra);
+
+    CL_CHECK(clEnqueueWriteBuffer(
+        queue, extra->data_device, CL_TRUE, extra->offset + offset,
+        size, data, 0, NULL, NULL));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *
tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->extra); + + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + + // Make sure all previously submitted commands are finished. + CL_CHECK(clFinish(queue)); + +#ifdef GGML_OPENCL_SOA_Q + // In end-to-end runs, get_tensor is usually used to get back the logits, + // where we can simply do clEnqueueReadBuffer since they are f32. + // However, in test-backend-ops, the GPU graph is copied to the CPU backend, + // which requires reading back quantized weight tensors. + // To properly support this, we need to restore block_q4_0 struct arrays + // from the flattened buffers. + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + + CL_CHECK(clEnqueueReadBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_dev_t dev = buffer->buft->device; + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + cl_command_queue queue = backend_ctx->queue; + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + for (cl_mem buf : ctx->buffer) { + CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); + } + CL_CHECK(clFinish(queue)); +} + +static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ctx->reset(); +} + +static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { + /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, + /* .get_base = */ ggml_backend_opencl_buffer_get_base, + /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_opencl_buffer_clear, + /* .reset = */ ggml_backend_opencl_buffer_reset, +}; + +// +// buffer type +// + +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) { + return "OpenCL"; + + GGML_UNUSED(buffer_type); +} + +static ggml_backend_buffer_t 
ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + + // clCreateBuffer returns -61 for size 0 + size = std::max(size, (size_t)1); + + cl_int err; + cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS) { + GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); + return nullptr; + } + + ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem); + + return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size); +} + +static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { + // FIXME: not thread safe, device may not be initialized yet + static cl_uint alignment = -1; + if (alignment == (cl_uint)-1) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + alignment = backend_ctx->alignment; + } + return alignment; +} + +static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { + static size_t max_size = -1; + if (max_size == (size_t)-1) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + max_size = backend_ctx->max_alloc_size; + } + return max_size; +} + +static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_opencl(backend); + + UNUSED(buft); +} + +static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { + /* .get_name = */ ggml_backend_opencl_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, + /* .is_host = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() { + static ggml_backend_buffer_type buffer_type = { + /* .iface = */ ggml_backend_opencl_buffer_type_interface, + /* .device = */ &g_ggml_backend_opencl_device, + /* .context = */ nullptr, + }; + + return &buffer_type; +} + +// +// backend device +// + +static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) { + return "GPUOpenCL"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + return dev_ctx->device_name.c_str(); +} + +static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + *free = 1; + *total = 1; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_opencl_device_get_name(dev); + props->description = ggml_backend_opencl_device_get_description(dev); + props->type = ggml_backend_opencl_device_get_type(dev); + ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = ggml_backend_dev_caps { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ 
false, + }; +} + +static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx, + }; + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_opencl_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + return ggml_opencl_supports_op(dev, op); +} + +static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name; + + GGML_UNUSED(dev); +} + +static struct ggml_backend_device_i ggml_backend_opencl_device_i = { + /* .get_name = */ ggml_backend_opencl_device_get_name, + /* .get_description = */ ggml_backend_opencl_device_get_description, + /* .get_memory = */ ggml_backend_opencl_device_get_memory, + /* .get_type = */ ggml_backend_opencl_device_get_type, + /* .get_props = */ ggml_backend_opencl_device_get_props, + /* .init_backend = */ ggml_backend_opencl_device_init, + /* .get_buffer_type = */ ggml_backend_opencl_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_opencl_device_supports_op, + /* .supports_buft = */ ggml_backend_opencl_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// Backend registry + +static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { + return "OpenCL"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { + return ggml_backend_opencl_n_devices; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + return &g_ggml_backend_opencl_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { + /* .get_name = */ ggml_backend_opencl_reg_get_name, + /* .device_count = */ ggml_backend_opencl_reg_device_count, + /* .device_get = */ ggml_backend_opencl_reg_device_get, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_opencl_reg(void) { + // TODO: make this thread-safe somehow? 
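+    // Sketch of one possible way to address the TODO above (not applied in
+    // this patch): guard the one-time setup with std::call_once so that
+    // concurrent first calls cannot race, e.g.
+    //
+    //     static std::once_flag once;   // would require <mutex>
+    //     std::call_once(once, [] {
+    //         // ... same initialization as below ...
+    //     });
+    //
+    // The current code keeps the simple `initialized` flag instead.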
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    if (!initialized) {
+        reg = ggml_backend_reg {
+            /* .api_version = */ GGML_BACKEND_API_VERSION,
+            /* .iface       = */ ggml_backend_opencl_reg_i,
+            /* .context     = */ NULL,
+        };
+
+        g_ggml_backend_opencl_device = ggml_backend_device {
+            /* .iface   = */ ggml_backend_opencl_device_i,
+            /* .reg     = */ &reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+
+        ggml_cl2_init(&g_ggml_backend_opencl_device);
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg)
+
+//------------------------------------------------------------------------------
+// Debugging utils
+//------------------------------------------------------------------------------
+#if 0
+#define QK4_0 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2,
+    "wrong q4_0 block size/padding");
+
+#include <math.h>
+#ifdef __cplusplus
+#include "half.hpp"
+#endif
+
+static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) {
+    void * buf = malloc(ggml_nbytes(tensor));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+#ifdef GGML_OPENCL_SOA_Q
+    void * buf_q;
+    void * buf_d;
+#endif
+
+#ifdef GGML_USE_OPENCL
+    // Make sure everything is done.
+    CL_CHECK(clFinish(queue));
+
+#ifdef GGML_OPENCL_SOA_Q
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra;
+        GGML_ASSERT(extra);
+
+        size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2;
+        size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t);
+        GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor));
+        buf_q = malloc(size_q);
+        buf_d = malloc(size_d);
+
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
+        CL_CHECK(clFinish(queue));
+    } else {
+        // Read out the tensor from GPU memory.
+        ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+        GGML_ASSERT(extra);
+
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,
+            extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));
+        CL_CHECK(clFinish(queue));
+    }
+#else
+    // Read out the tensor from GPU memory.
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+    GGML_ASSERT(extra);
+
+    CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,
+        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));
+    CL_CHECK(clFinish(queue));
+#endif // GGML_OPENCL_SOA_Q
+#endif // GGML_USE_OPENCL
+
+    // Open file and dump.
+ char fname[512]; + sprintf(fname, "./tensor-dumps/%s.txt", tensor->name); + FILE * f = fopen(fname, "w"); + if (!f) { + printf("Failed to open %s\n", fname); + return; + } + + if (tensor->type == GGML_TYPE_F32) { + float * data = (float *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_I32) { + int * data = (int *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%d\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_F16) { +#ifdef __cplusplus + half_float::half * data = (half_float::half *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (std::isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", float(data[i])); + } +#endif + } else if (tensor->type == GGML_TYPE_Q4_0) { +#ifdef GGML_OPENCL_SOA_Q + ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d; + unsigned char * data_q = (unsigned char *)buf_q; + + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data_d[i]); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data_q[k]); + } + fprintf(f, "\n"); + data_q += QK4_0/2; + } + free(buf_d); + free(buf_q); +#else + block_q4_0 * data = (block_q4_0 *) buf; + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data[i].d); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data[i].qs[k]); + } + fprintf(f, "\n"); + } +#endif // GGML_OPENCL_SOA_Q + } + free(buf); + fflush(f); + fclose(f); +} +#else +#define dump_tensor(tensor) +#endif + +//------------------------------------------------------------------------------ +// Profiling utility +//------------------------------------------------------------------------------ +#ifdef GGML_OPENCL_PROFILING +void populateProfilingInfo( + ProfilingInfo& info, cl_event evt, cl_kernel kernel, + size_t global_size[3], size_t local_size[3], + const ggml_tensor * tensor) { + cl_ulong start; + cl_ulong end; + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clGetEventProfilingInfo( + evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL)); + CL_CHECK(clGetEventProfilingInfo( + evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL)); + + char kernel_name[512]; + CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME, + sizeof(kernel_name), kernel_name, NULL)); + + info.duration_ns = end - start; + info.op_name = tensor->name; + info.kernel_name = kernel_name; + info.local_size[0] = local_size[0]; + info.local_size[1] = local_size[1]; + info.local_size[2] = local_size[2]; + info.global_size[0] = global_size[0]; + info.global_size[1] = global_size[1]; + info.global_size[2] = global_size[2]; + info.output_size[0] = tensor->ne[0]; + info.output_size[1] = tensor->ne[1]; + info.output_size[2] = tensor->ne[2]; + info.output_size[3] = tensor->ne[3]; +} +#endif + +//------------------------------------------------------------------------------ +// Ops +//------------------------------------------------------------------------------ + +static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + return (src0->type == GGML_TYPE_F32 || src0->type == 
GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); +} + +static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + UNUSED(backend); + UNUSED(src0); + UNUSED(src1); + UNUSED(dst); +} + +static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const int ne10 = src1 ? src1->ne[0] : 0; + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_get_rows_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_get_rows_f16; + break; + case GGML_TYPE_Q4_0: + kernel = backend_ctx->kernel_get_rows_q4_0; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + + size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; + size_t local_work_size[] = {1, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + 
GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? dst->nb[3] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_add_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_add; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, 
sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? 
dst->nb[3] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_mul_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_mul; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + 
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); +#endif +} + +static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_silu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_silu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + 
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_relu; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + cl_kernel kernel = backend_ctx->kernel_clamp; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &min)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &max)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t 
local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + const int nth = MIN(64, ne00); + + cl_kernel kernel = backend_ctx->kernel_norm; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL)); + + const int64_t nrows = ggml_nrows(src0); + + size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_backend_opencl_device_context * dev_ctx = + (ggml_backend_opencl_device_context *)backend->device->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? 
src0->nb[1] : 0; + + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + const int nth = MIN(64, ne00); + + const int64_t nrows = ggml_nrows(src0); + + size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_rms_norm; + + // Note, this kernel declares local memory in kernel args and the size + // depends on subgroup size. + // Retrieve subgroup size. + // Note, this requires OpenCL 2.1 and above + size_t sgs; + CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + sizeof(local_work_size), local_work_size, + sizeof(size_t), &sgs, NULL)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + // This is local memory - the size depends on subgroup size. + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL)); + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + +#ifdef GGML_OPENCL_SOA_Q + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; +#endif + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? 
src1->nb[3] : 0; + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + + int r2 = ne12/ne02; + int r3 = ne13/ne03; + + GGML_ASSERT(ne00 == ne10); + + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + // The number of values produced by each subgroup + int ndst = 4; + + cl_kernel kernel; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_context context = backend_ctx->context; + + if (ne01 && ne1 && use_adreno_kernels(src0)) { + + // init CL objects + // <--------------------------------------------> // + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d; + cl_mem B_image1d; + cl_mem B_sub_buffer; + cl_mem C_d; + // for B transpose + cl_mem B_d; + cl_mem B_d_input_image; + // <--------------------------------------------> // + + // define matrix dimensions + // <--------------------------------------------> // + int M = ne01; + int N = ne1; + int K = ne00; + int padding; + // <--------------------------------------------> // + + // q4_0 x fp32 + if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { + // TODO: remove duplicate definitions of image description + format -- move to top + + // create an image for A + // <--------------------------------------------> // + if (N == 1) { + img_fmt_1d = { CL_R, CL_UNSIGNED_INT32}; + } else { + img_fmt_1d = { CL_R, CL_FLOAT}; + } + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 2 / 4; // Divide by 4 for char -> float + img_desc_1d.buffer = extra0_q4_0->q; + A_image1d = clCreateImage( + context, + CL_MEM_READ_ONLY, + &img_fmt_1d, + &img_desc_1d, + NULL, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + + + // create a sub_buffer for B + // <--------------------------------------------> // + region.origin = (extra1->offset); + region.size = K * N * sizeof(float); + B_sub_buffer = clCreateSubBuffer( + extra1->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + &region, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + + // transpose activation for Skyler's gemm + if (N != 1) { + //how many extra elements beyond multiple of 8 + int extra_elements = N % 8; + + //how much padding to add + padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // Specify the starting offset (in bytes) + region.origin = 0; + // Specify the size of the sub-buffer (divide by 2 for FP16) + region.size = K * (N + padding) * sizeof(float)/2; + B_d = clCreateSubBuffer( + backend_ctx->B_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + &region, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT }; + cl_image_desc image_desc_B_d_input = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(K * N / 4), + 0, 0, 0, 0, 0, 0, 0, { B_sub_buffer } + }; + B_d_input_image = clCreateImage( + context, + 0, + &image_format_B_d_input, + &image_desc_B_d_input, + NULL, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) + cl_image_desc image_desc_B_d_output = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(K * (N + padding)/4), + 0, 0, 0, 0, 0, 0, 0, { B_d } + }; + B_image1d = clCreateImage( + context, + 0, + &image_format_B_d_output, + &image_desc_B_d_output, + NULL, + &status); + CL_CHECK(status); + + int height_B = N/4; + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel =
backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + //WGS tuning + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=4; + local_size_t[1]=8; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=2; + local_size_t[1]=8; + } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_size_t[0]=1; + local_size_t[1]=8; + } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=2; + local_size_t[1]=8; + } + + size_t global_size_t[2] = { + static_cast<size_t>(width_B), + static_cast<size_t>(padded_height_B) + }; + + #ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_size_t, local_size_t, dst); + #else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, NULL)); + #endif + } else { + // no need to transpose B in other cases + // create an image for B from sub_buffer + // <--------------------------------------------> // + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_width = K * N / 4; + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.buffer = B_sub_buffer; + B_image1d = clCreateImage( + context, + CL_MEM_READ_ONLY, + &img_fmt_1d, + &img_desc_1d, + NULL, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + } + + // choose gemm or gemv kernel + // <--------------------------------------------> // + if (N == 1) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; + if (M == 4096 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; + } else if (M == 4096 && K == 11008) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; + } else if (M == 11008 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; + } else if (M == 32000 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + } + } else { + kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4; + } + // <--------------------------------------------> // + + // set kernel args + // <--------------------------------------------> // + cl_uint k_arg = 0; + + if (N == 1) { + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); +
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + } else { + region.origin = extrad->offset; // Specify the starting offset (in bytes) + region.size = M * N * sizeof(float); // Specify the size of the sub-buffer + C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, &region, &status); + CL_CHECK(status); + + int padded_N = ne1 + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_d + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); //M + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &padded_N)); //N with padding + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); //K + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne1)); //N without padding + } + // <--------------------------------------------> // + + // choose workgroup size + // <--------------------------------------------> // + size_t global_work_size[3] = { + 64, static_cast<size_t>((M+63)/64), static_cast<size_t>((N+31)/32)}; + size_t local_work_size[3] = {64, 2, 4}; + + global_work_size[0] = (size_t)(ceil((float)ne1/8)); + global_work_size[1] = (size_t)(ne01/4); + global_work_size[2] = (size_t)(1); + + local_work_size[0] = (size_t)(1); //4x32 for FP32 + local_work_size[1] = (size_t)(128); + local_work_size[2] = (size_t)(1); + + //WGS tuning + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 1; + local_work_size[1] = 128; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } + + if (N == 1) { + local_work_size[0] = backend_ctx->adreno_wave_size; // localsize + local_work_size[1] = 4; // reduce factor + local_work_size[2] = 1; + + global_work_size[0] = M / 2; + global_work_size[1] = 4; // reduce factor + global_work_size[2] = 1; + } + // <--------------------------------------------> // + + // enqueue kernel with profiling + // <--------------------------------------------> // + #ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + // enqueue kernel without profiling + #else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + #endif + // <--------------------------------------------> // + + // deallocate sub buffers and images + // <--------------------------------------------> // + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_image1d)); + + if (N != 1) { + CL_CHECK(clReleaseMemObject(B_d)); + CL_CHECK(clReleaseMemObject(B_d_input_image)); + CL_CHECK(clReleaseMemObject(C_d)); + } + // <--------------------------------------------> // + + return; + } + } // if (ne01 && 
ne1) +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00%32 == 0 && + ne11 > 2) { +#ifdef GGML_OPENCL_SOA_Q + // Set up kernel. + switch(src0t) { + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; + default: + break; + } + + // Launch kernel. + if (src0t == GGML_TYPE_Q4_0) { + size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + if (backend_ctx->gpu_family == INTEL) { + // Set global size for Intel. It uses 16x output values. + global_work_size[0] = (size_t)(ne01 + 15)/16*nth0; + global_work_size[1] = (size_t)ne11*nth1; + global_work_size[2] = (size_t)ne12*ne13; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + return; + } +#else // GGML_OPENCL_SOA_Q + // TODO: add block_q4_0 variant. 
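+        // For reference (illustrative note, not part of the kernels above): without
+        // GGML_OPENCL_SOA_Q the q4_0 weights stay in the interleaved block layout
+        // defined in ggml-opencl.cl,
+        //     struct block_q4_0 { half d; uint8_t qs[QK4_0 / 2]; };
+        // i.e. one fp16 scale followed by 16 bytes of packed 4-bit quants per 32 weights,
+        // while the SoA path above reads the flattened buffers extra0_q4_0->q (all quants)
+        // and extra0_q4_0->d (all scales) that the *_flat kernels expect.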
+#endif // GGML_OPENCL_SOA_Q + } + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + //GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(src1t == GGML_TYPE_F32); + kernel = backend_ctx->kernel_mul_mat_f32_f32; + nrows = 4; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + nth1 = 1; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_F16: + //GGML_ASSERT(ne02 == ne12); + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + nth1 = 1; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_1row; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_l4; + nrows = ne11; + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f32; + nrows = 4; + } + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f16; + nrows = 4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, 
sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst =8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else // GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + // Use 1D local size. Each workgroup is a SIMD group. Each SIMD + // group produces N_DST (4 for Q4_0 kernel) values in the result. + // The number of workgroups on dim 0 (the leading dimension) is + // the nearest multiple of 4 that covers ne0 (equals ne01). 
+ nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + kernel = backend_ctx->kernel_mul_mv_q6_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 2; + nth1 = 16; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 2; + nth1 = 64; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + if (src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_Q2_K) { + // Each SIMD group produces N_DST values in the result. Assuming each + // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will + // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size + // (number of workgroups) will be a nearest multiple of + // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is + // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul). 
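+        // Worked example (illustrative): with the non-SoA Q4_0 path on Intel, ndst = 4 and
+        // nth0 = 16, so ne01 = 4096 rows give (4096 + 4 - 1)/4 = 1024 SIMD groups along dim 0,
+        // i.e. global_work_size[0] = 1024*16 = 16384 work-items, each 16-wide group writing
+        // 4 consecutive rows of the result.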
+ size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else if (src0t == GGML_TYPE_Q4_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q3_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q5_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q6_K) { + size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + int64_t ny = (ne11 + nrows - 1)/nrows; + + size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_scale; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); + + int n = ggml_nelements(dst)/4; + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + 
g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + + // GGML_OP_CPY happens between src0 and src1. + // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. + UNUSED(dst); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + + cl_kernel kernel; + + switch (src0t) { + case GGML_TYPE_F32: + switch (src1t) { + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f32_f16; + break; + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f32_f32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; + case GGML_TYPE_F16: + switch (src1t) { + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f16_f16; + break; + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f16_f32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + 
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, src1); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cl_cpy(backend, src0, dst, nullptr); + UNUSED(src1); +} + +static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + int n_past = ((int32_t *)(dst->op_params))[0]; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + if (ne00%8 == 0) { + kernel = backend_ctx->kernel_diag_mask_inf_8; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past)); + + size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + kernel = backend_ctx->kernel_diag_mask_inf; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past)); + + size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, 
(size_t)ne02}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + // Softmax can now fuse KQ mask and KQ scale, which used to be two additional + // ops before softmax. It now also fuses alibi if `max_bias > 0`. For llama, + // alibi is not used; however, for some other models, it is used. + // KQ_mask + if (src1) { + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + } + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + float scale, max_bias; + memcpy(&scale, dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + + const int nrows_x = ggml_nrows(src0); + const int nrows_y = src0->ne[1]; + + const int n_head = nrows_x/nrows_y; + const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // Local size must be wave size. Each workgroup is a wave, working on a row, + // where a row corresponds to leading dimension. + int nth = MIN(32, ne00); + + if (backend_ctx->gpu_family == INTEL) { + // This is the same as the initial value. + nth = MIN(32, ne00); + } + else if (backend_ctx->gpu_family == ADRENO) { + nth = 64; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + cl_kernel kernel; + + if (ne00%4 == 0) { + kernel = backend_ctx->kernel_soft_max_4; + } else { + kernel = backend_ctx->kernel_soft_max; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? 
&extra1->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_tensor * src2 = dst->src[2]; + ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr; + + cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const int nb00 = src0 ? src0->nb[0] : 0; + const int nb01 = src0 ? src0->nb[1] : 0; + const int nb02 = src0 ? src0->nb[2] : 0; + const int nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11); + const int ne12 = src1 ? src1->ne[2] : 0; UNUSED(ne12); + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const int nb0 = dst ? dst->nb[0] : 0; + const int nb1 = dst ? dst->nb[1] : 0; + const int nb2 = dst ? dst->nb[2] : 0; + const int nb3 = dst ? 
dst->nb[3] : 0; + + GGML_ASSERT(ne10 == ne02); + + int nth = MIN(64, ne00); + + const int n_past = ((int *) dst->op_params)[0]; + const int n_dims = ((int *) dst->op_params)[1]; + const int mode = ((int *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & 2; + + cl_kernel kernel; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_norm_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_norm_f16; + break; + default: + GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_neox_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_neox_f16; + break; + default: + GGML_ASSERT(false); + }; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_past)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &n_dims)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &n_ctx_orig)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &freq_base)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(float), &freq_scale)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &ext_factor)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + + 
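+    // Launch geometry (clarifying note): one workgroup of nth work-items per src0 row,
+    // i.e. ne01 groups along dim 0, with ne02 and ne03 mapped directly to dims 1 and 2.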
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +//------------------------------------------------------------------------------ +// Op offloading +//------------------------------------------------------------------------------ + +typedef void (*ggml_cl_func_t)(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor) { + ggml_cl_func_t func = nullptr; + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + + const bool any_on_device = tensor->extra + || (src0 != nullptr && src0->extra) + || (src1 != nullptr && src1->extra); + + switch (tensor->op) { + case GGML_OP_GET_ROWS: + if (!any_on_device) { + return false; + } + func = ggml_cl_get_rows; + break; + case GGML_OP_CPY: + if (!any_on_device) { + return false; + } + func = ggml_cl_cpy; + break; + case GGML_OP_DUP: + case GGML_OP_CONT: + if (!any_on_device) { + return false; + } + func = ggml_cl_dup; + break; + case GGML_OP_ADD: + if (!any_on_device) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + func = ggml_cl_add; + break; + case GGML_OP_MUL: + if (!any_on_device) { + return false; + } + func = ggml_cl_mul; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu; + break; + case GGML_UNARY_OP_SILU: + if (!any_on_device) { + return false; + } + func = ggml_cl_silu; + break; + case GGML_UNARY_OP_RELU: + if (!any_on_device) { + return false; + } + func = ggml_cl_relu; + break; + default: + return false; + } break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return false; + } + func = ggml_cl_clamp; + break; + case GGML_OP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_norm; + break; + case GGML_OP_RMS_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_rms_norm; + break; + case GGML_OP_MUL_MAT: + if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { + return false; + } + func = ggml_cl_mul_mat; + break; + case GGML_OP_SCALE: + if (!any_on_device) { + return false; + } + func = ggml_cl_scale; + break; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + if (!any_on_device) { + return false; + } + func = ggml_cl_nop; + break; + case GGML_OP_DIAG_MASK_INF: + if (!any_on_device) { + return false; + } + func = ggml_cl_diag_mask_inf; + break; + case GGML_OP_SOFT_MAX: + if (!any_on_device) { + return false; + } + func = ggml_cl_soft_max; + break; + case GGML_OP_ROPE: + if (!any_on_device) { + return false; + } + func = ggml_cl_rope; + break; + default: + return false; + } + + func(backend, tensor->src[0], tensor->src[1], tensor); + return true; +} diff --git a/ggml/src/ggml-opencl/kernels/embed_kernel.py b/ggml/src/ggml-opencl/kernels/embed_kernel.py new file mode 100644 index 0000000000000..b5d1d7242b624 --- /dev/null +++ 
b/ggml/src/ggml-opencl/kernels/embed_kernel.py @@ -0,0 +1,26 @@ +# + +import sys +import logging +logger = logging.getLogger("opencl-embed-kernel") + + +def main(): + logging.basicConfig(level=logging.INFO) + + if len(sys.argv) != 3: + logger.info("Usage: python embed_kernel.py <input_file> <output_file>") + sys.exit(1) + + ifile = open(sys.argv[1], "r") + ofile = open(sys.argv[2], "w") + + for i in ifile: + ofile.write('R"({})"\n'.format(i)) + + ifile.close() + ofile.close() + + +if __name__ == "__main__": + main() diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl new file mode 100644 index 0000000000000..d1cdf709babc5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -0,0 +1,2683 @@ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#elif defined(cl_amd_fp16) +#pragma OPENCL EXTENSION cl_amd_fp16 : enable +#else +#error "Half precision floating point not supported by OpenCL implementation on your device." +#endif + +#ifdef cl_khr_subgroups +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#elif defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#error "Subgroup not supported on your device." +#endif + +#ifdef cl_intel_required_subgroup_size +// Always use subgroup size of 32 on Intel. +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +// Always use subgroup size of 64 on Adreno. +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +// TODO: do not know how to choose subgroup size on other GPUs. +#error "Selecting subgroup size is not supported on your device."
+#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 +{ + half d; + half m; + uint8_t qs[QK4_1 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q5_0 +//------------------------------------------------------------------------------ +struct block_q5_0 +{ + half d; + uint32_t qh; + uint8_t qs[QK5_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q5_1 +//------------------------------------------------------------------------------ +struct block_q5_1 +{ + half d; + half m; + uint32_t qh; + uint8_t qs[QK5_1 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q8_0 +//------------------------------------------------------------------------------ +struct block_q8_0 +{ + half d; + int8_t qs[QK8_0]; +}; + +//------------------------------------------------------------------------------ +// block_q2_K +//------------------------------------------------------------------------------ +struct block_q2_K +{ + uint8_t scales[16]; + uint8_t qs[64]; + half d; + half dmin; +}; + +//------------------------------------------------------------------------------ +// block_q3_K +//------------------------------------------------------------------------------ +struct block_q3_K +{ + uint8_t hmask[32]; + uint8_t qs[64]; + uint8_t scales[12]; + half d; +}; + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +struct block_q4_K +{ + half d; + half dmin; + uint8_t scales[12]; + uint8_t qs[128]; +}; + +//------------------------------------------------------------------------------ +// block_q5_K +//------------------------------------------------------------------------------ +struct block_q5_K +{ + half d; + half dmin; + uint8_t scales[12]; + uint8_t qh[32]; + uint8_t qs[128]; +}; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +struct block_q6_K +{ + uint8_t ql[128]; + uint8_t qh[64]; + int8_t scales[16]; + half d; +}; + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 = il ? (xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 
0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + +void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + half d1 = il ? (xb->d / 16.h) : xb->d; + half d2 = d1 / 256.h; + half md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. 
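+    // e.g. with ne = 1024 and gid = 2051: gid/ne = 2 and 2*1024 = 2048, so idx1 = 3,
+    // which matches 2051 % 1024.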
+ uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. 
+ uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? 
max : src0[get_global_id(0)]); +} + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global void*)((global char*)dst + offsetd); + + global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + get_group_id(0)*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. 
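+// A rough sketch of the reduction below (assuming the local work size is a
+// multiple of the subgroup size):
+//   1. every work-item accumulates a partial sum of x[i]*x[i] over a strided
+//      range of the row;
+//   2. sub_group_reduce_add() folds the partials within each subgroup and the
+//      first lane of each subgroup stores its result in the local array sum[];
+//   3. a tree reduction over sum[] (plus the scalar tail handled by work-item
+//      0) produces sum(x^2), which divided by ne00 gives the mean square used
+//      for the 1/sqrt(mean + eps) scale.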
+kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} + +//------------------------------------------------------------------------------ +// softmax +//------------------------------------------------------------------------------ +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global 
char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + 
offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + 
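+// The quantized variant below follows the same pattern as the f32/f16 ones
+// above: src1 supplies the row index r and each work-group gathers one row,
+// except that the row is dequantized on the fly. Roughly, each loop iteration
+// converts one half of a q4_0 block (16 values) with dequantize_q4_0_f32 and
+// stores the resulting float16 into the destination row.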
+kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f32_f32 +//------------------------------------------------------------------------------ +#define N_F32_F32 4 + +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f16 +//------------------------------------------------------------------------------ +#define N_F16_F16 4 + +kernel void kernel_mul_mat_f16_f16( + global char * src0, + 
ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f32_1row +//------------------------------------------------------------------------------ +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + 
global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f32 +//------------------------------------------------------------------------------ +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f32_l4 +//------------------------------------------------------------------------------ +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + 
ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. 
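+    // The pre-scaling of yl in the loop below matches the masks used in
+    // block_q_4_0_dot_y: a nibble extracted as (qs & 0x0F00) is the quant
+    // value times 256, so it is paired with y/256; likewise 0x00F0 pairs with
+    // y/16 and 0xF000 with y/4096, avoiding any bit shifts on the quants.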
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
+        }
+
+        for (int row = 0; row < N_DST; row++) {
+            sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);
+        }
+
+        // One thread in a SIMD group (i.e., subgroup) handles a half block,
+        // hence the entire SIMD group handles SIMDWIDTH/2 blocks.
+        // y points to the activation matrix (of type float). Therefore for
+        // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
+        // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
+        // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    // The above does not work for Adreno - it produces incorrect results for
+    // row = 1, 2, 3 and only row = 0 gives the correct result.
+    // If N_DST is changed, the below array must be initialized accordingly.
+    // This also seems to perform better on Intel.
+    float tot[N_DST] = {
+        sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
+        sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
+    for (int row = 0; row < N_DST; ++row) {
+        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//
+// This variant unrolls the loops and uses vector types instead of pointers.
+// It improves performance on Adreno but not so much on Intel.
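+// As in the scalar version, the q4_0 zero-point of 8 is factored out of the
+// dot product: sum_i d*(q_i - 8)*y_i = d*(sum_i q_i*y_i - 8*sum_i y_i),
+// which is why the result is computed as d * (sumy * -8.f + acc).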
+// +inline float block_q_4_0_dot_y_v( + global struct block_q4_0 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_v( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; // src1 vector cache + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). 
Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_v( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32_flat +// +// This variation uses flat arrays (struct of arrays, SOA) representation for +// quant tensors. +//------------------------------------------------------------------------------ + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. 
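+// Concretely, byte qs[i] of a block (i = 0..15) holds q[i] in its low nibble
+// and q[i+16] in its high nibble, e.g. q[0] = qs[0] & 0x0F and
+// q[16] = qs[0] >> 4. The masks below pick these nibbles out of 16-bit pairs
+// of bytes without unshuffling them first.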
+inline float block_q_4_0_dot_y_flat(
+        global uchar * x,
+        global half * dh,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float d = *dh;
+    global ushort * qs = ((global ushort *)x + il/2);
+    float acc = 0.f;
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_flat(
+        global uchar * src0_q,
+        global half * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
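+    // (With QK4_0 = 32 this is 16 bytes of quants per block, so the byte
+    // offset into src0_q is the block offset scaled by 16.)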
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +// +// This variant outputs 8 values. 
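+// Compared to the 4-row variant above, each subgroup accumulates into a
+// float8 and writes up to 8 output rows, so the same yl activation cache is
+// reused for twice as many dot products per loop iteration.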
+// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), 
sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_8x_flat(
+    global uchar * src0_q,
+    global half * src0_d,
+    global float * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne10,
+    int ne12,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl
new file mode 100644
index 0000000000000..e2024332f81b9
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl
@@ -0,0 +1,106 @@
+//------------------------------------------------------------------------------
+// This file contains additional kernels for data conversion.
+// These kernels are used when loading the model, so their performance is less
+// important.
+//------------------------------------------------------------------------------
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supported by OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroup size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32_flat_noshuffle +// +// This variation uses flat arrays (struct of arrays, SOA) representation for +// quant tensors. It also uses non shuffled bit order for weights. +// +// The shuffled version is kept in the original file because moving it here +// seems to result in worse performance for adreno. +//------------------------------------------------------------------------------ + +kernel void kernel_convert_block_q4_0_noshuffle( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + // Workaround for adreno - must have the following printf statement for + // the kernel to work properly. Otherwise it produces incorrect result. + // convert_uchar above also seems necessary. + // Compare against a large number so that it does not print anything. + // get_sub_group_local_id() also works. 
+ if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl new file mode 100644 index 0000000000000..5e195411d690e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +// assume +#define QK4_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 
0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = 
sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * 
shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +__attribute__((qcom_reqd_sub_group_size("full"))) +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C (0) + uint K, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + __private uint4 regA; + __private half2 regS; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef 
VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + __local float2 reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl new file mode 100644 index 0000000000000..5bdd4d067639a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl @@ -0,0 +1,271 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +// assume +#define QK4_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = 
sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += 
(((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + 
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +__attribute__((qcom_reqd_sub_group_size("full"))) +__kernel void 
kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C (0) + int ne00, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = N_SIMDGROUP * M; + + __private uint4 regA; + __private half2 regS; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + __local float2 reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl new file mode 100644 index 0000000000000..e19e9a2f43629 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl @@ -0,0 +1,1225 @@ +//------------------------------------------------------------------------------ +// This file is contains additional mulmat kernels +// (and potentially other kernels). 
+//------------------------------------------------------------------------------
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supported by OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroup size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0 32
+#define QR4_0 2
+#define QK4_1 32
+#define QR4_1 2
+#define QK5_0 32
+#define QR5_0 2
+#define QK5_1 32
+#define QR5_1 2
+#define QK8_0 32
+#define QR8_0 1
+#define QK_K 256
+#define K_QUANTS_PER_ITERATION 2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2]; // quants, lower 4 bits
+    uint8_t qh[QK_K/4]; // quants, upper 2 bits
+    int8_t scales[QK_K/16]; // scales, quantized with 8 bits
+    half d; // super-block scale
+} block_q6_K;
+
+//------------------------------------------------------------------------------
+// These are the variants for mat-mat multiplication (matmul), based on the
+// mat-vec multiplication kernel with flattened block_q4_0.
+//------------------------------------------------------------------------------
+
+// Common dot prod.
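+// The ushort masks below pick out the four nibbles of each 16-bit pair without
+// shifting them down; the caller pre-scales the activations in `yl` by 1/256,
+// 1/16 and 1/4096 to compensate, and the `sumy * -8.f` term removes the
+// implicit offset of 8 from every 4-bit weight.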
+inline float mm_block_q_4_0_dot_y_flat(
+    global uchar * x,
+    global half * dh,
+    float sumy,
+    float16 yl,
+    int il
+) {
+    float d = *dh;
+    global ushort * qs = ((global ushort *)x + il/2);
+    float acc = 0.f;
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 8
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+//
+// This variant performs 1d blocking with 8x output.
+// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix).
+//
+inline void mul_mat_q_n_f32_1d_8x_flat(
+    global uchar * src0_q,
+    global half * src0_d,
+    global float * src1,
+    global float * dst,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne10,
+    int ne12,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3
+) {
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global 
float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += 
mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + 
mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +//------------------------------------------------------------------------------ +// kernel_mul_mat_q4_0_f32_flat_v0 +//------------------------------------------------------------------------------ +inline float block_q_4_0_dot_y_flat_v2( + half x, + half d, + float sumy, + float4 yl +) { + uchar2 q = as_uchar2(x); + float acc = 0.0f; + + acc += (q.s0 & 0x0F) * yl.s0; + acc += (q.s1 & 0x0F) * yl.s1; + + acc += (q.s0 & 0xF0) * yl.s2; + acc += (q.s1 & 0xF0) * yl.s3; + + return d * (sumy * -8.f + acc);; +} + +inline float block_q_4_0_dot_y_flat_v4( + float x, + half d, + float sumy, + float8 yl +) { + uchar4 q = as_uchar4(x); + float acc = 0.0f; + + acc += (q.s0 & 0x0F) * yl.s0; + acc += (q.s1 & 0x0F) * yl.s1; + acc += (q.s2 & 0x0F) * yl.s2; + acc += (q.s3 & 0x0F) * yl.s3; + + acc += (q.s0 & 0xF0) * yl.s4; + acc += (q.s1 & 0xF0) * yl.s5; + acc += (q.s2 & 0xF0) * yl.s6; + acc += (q.s3 & 0xF0) * yl.s7; + + return d * (sumy * -8.f + acc);; +} + +inline float block_q_4_0_dot_y_flat_v8( + float2 x, + half d, + float sumy, + float16 yl +) { + uchar8 q = as_uchar8(x); + float acc = 0.0f; + + acc += (q.s0 & 0x0F) * yl.s0; + acc += (q.s1 & 0x0F) * yl.s1; + acc += (q.s2 & 0x0F) * yl.s2; + acc += (q.s3 & 0x0F) * yl.s3; + acc += (q.s4 & 0x0F) * yl.s4; + acc += (q.s5 & 0x0F) * yl.s5; + acc += (q.s6 & 0x0F) * yl.s6; + acc += (q.s7 & 0x0F) * yl.s7; + + acc += (q.s0 & 0xF0) * yl.s8; + acc += (q.s1 & 0xF0) * yl.s9; + acc += (q.s2 & 0xF0) * yl.sa; + acc += (q.s3 & 0xF0) * yl.sb; + acc += (q.s4 & 0xF0) * yl.sc; + acc += (q.s5 & 0xF0) * yl.sd; + acc += (q.s6 & 0xF0) * yl.se; + acc += (q.s7 & 0xF0) * yl.sf; + + return d * (sumy * -8.f + acc);; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define THREADS_PER_BLK 4 +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block +# define ACT_TY float16 +# define Q_BLK_LD_TY float2 +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block +# define ACT_TY float8 +# define Q_BLK_LD_TY float +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block +# define ACT_TY float4 +# define Q_BLK_LD_TY half +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 +#endif + +#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) + +#if N_DST == 2 +# define SUM_TY float2 +#elif N_DST == 4 +# define SUM_TY float4 +#elif N_DST == 8 +# define SUM_TY float8 +#elif N_DST == 16 +# define SUM_TY float16 +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_flat_v0( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * 
N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int ix = get_sub_group_local_id()/THREADS_PER_BLK; + int il = get_sub_group_local_id()%THREADS_PER_BLK; + + global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; + + // Registers for caching activation + ACT_TY yl = 0.f; + + // Registers for caching quants + Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; +#endif +#if N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; +#endif + + // Partial sum + SUM_TY sumf = 0.f; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { + float sumy = 0.f; + + q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2); + q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2); + q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2); +#endif +#if N_DST == 8 || N_DST == 16 + q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2)); + q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2)); + q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2)); + q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2)); +#endif + + // Load activation +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block + yl.s01234567 = *(global float8 *)(yb); + yl.s89abcdef = *(global float8 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; + sumy += yl.s5; + sumy += yl.s6; + sumy += yl.s7; + sumy += yl.s8; yl.s8 /= 16.f; + sumy += yl.s9; yl.s9 /= 16.f; + sumy += yl.sa; yl.sa /= 16.f; + sumy += yl.sb; yl.sb /= 16.f; + sumy += yl.sc; yl.sc /= 16.f; + sumy += yl.sd; yl.sd /= 16.f; + sumy += yl.se; yl.se /= 16.f; + sumy += yl.sf; yl.sf /= 16.f; +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block + yl.s0123 = *(global float4 *)(yb); + yl.s4567 = *(global float4 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; yl.s4 /= 16.f; + sumy += yl.s5; yl.s5 /= 16.f; + sumy += yl.s6; yl.s6 /= 16.f; + sumy += yl.s7; yl.s7 /= 16.f; +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block + yl.s01 = *(global float2 *)(yb); + yl.s23 = *(global float2 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; yl.s2 /= 16.f; + sumy += yl.s3; yl.s3 /= 16.f; +#endif + + sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl); + sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl); + sumf.s3 += 
block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl); +#endif +#if N_DST == 8 || N_DST == 16 + sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl); + sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl); + sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl); + sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl); +#endif + + yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); + } + + SUM_TY tot = (SUM_TY)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) +#endif +#if N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) + , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) +#endif + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } +#endif +#if N_DST == 8 || N_DST == 16 + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } +#endif + } +} + +//------------------------------------------------------------------------------ +// Using image1d_buffer_t + +#if defined(cl_qcom_subgroup_shuffle) +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +float qcom_sub_group_reduce_add(float sum) { + sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + return sum; +} +#define sub_group_reduce_add qcom_sub_group_reduce_add +#else +#define sub_group_reduce_add sub_group_reduce_add +#endif + +#undef THREADS_PER_BLK +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define THREADS_PER_BLK 4 +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block +# define ACT_TY float16 +# define Q_BLK_LD_TY float2 +# define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part) +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block +# define ACT_TY float8 +# define Q_BLK_LD_TY float +# define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part) +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 +#elif THREADS_PER_BLK == 8 // Each 
thread processes 1/8 block +# define ACT_TY float4 +# define Q_BLK_LD_TY half +# define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part) +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 +#endif + +#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) + +#if N_DST == 2 +# define SUM_TY float2 +#elif N_DST == 4 +# define SUM_TY float4 +#elif N_DST == 8 +# define SUM_TY float8 +#elif N_DST == 16 +# define SUM_TY float16 +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_flat_img_v0( + read_only image1d_buffer_t src0_q, + read_only image1d_buffer_t src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int ix = get_sub_group_local_id()/THREADS_PER_BLK; + int il = get_sub_group_local_id()%THREADS_PER_BLK; + + global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; + + // Registers for caching activation + ACT_TY yl = 0.f; + + // Registers for caching quants + Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; +#endif +#if N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; +#endif + + // Partial sum + SUM_TY sumf = 0.f; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { + float sumy = 0.f;; + + float4 tmp; + tmp = read_imagef(src0_q, offset0_q + ib + 0*nb); + q_blk_0 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 1*nb); + q_blk_1 = EXTRACT_BLK_DATA(tmp, il); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + tmp = read_imagef(src0_q, offset0_q + ib + 2*nb); + q_blk_2 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 3*nb); + q_blk_3 = EXTRACT_BLK_DATA(tmp, il); +#endif +#if N_DST == 8 || N_DST == 16 + tmp = read_imagef(src0_q, offset0_q + ib + 4*nb); + q_blk_4 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 5*nb); + q_blk_5 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 6*nb); + q_blk_6 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 7*nb); + q_blk_7 = EXTRACT_BLK_DATA(tmp, il); +#endif + + // Load activation +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block + yl.s01234567 = *(global float8 *)(yb); + yl.s89abcdef = *(global float8 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; + sumy += yl.s5; + sumy += yl.s6; + sumy += yl.s7; + sumy += yl.s8; yl.s8 /= 16.f; + sumy += yl.s9; yl.s9 /= 16.f; + sumy += yl.sa; yl.sa /= 16.f; + sumy += yl.sb; yl.sb /= 16.f; + sumy += yl.sc; yl.sc /= 16.f; + sumy += yl.sd; yl.sd /= 16.f; + sumy 
+= yl.se; yl.se /= 16.f; + sumy += yl.sf; yl.sf /= 16.f; +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block + yl.s0123 = *(global float4 *)(yb); + yl.s4567 = *(global float4 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; yl.s4 /= 16.f; + sumy += yl.s5; yl.s5 /= 16.f; + sumy += yl.s6; yl.s6 /= 16.f; + sumy += yl.s7; yl.s7 /= 16.f; +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block + yl.s01 = *(global float2 *)(yb); + yl.s23 = *(global float2 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; yl.s2 /= 16.f; + sumy += yl.s3; yl.s3 /= 16.f; +#endif + + sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl); + sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl); + sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl); +#endif +#if N_DST == 8 || N_DST == 16 + sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl); + sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl); + sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl); + sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl); +#endif + + yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); + } + + SUM_TY tot = (SUM_TY)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) +#endif +#if N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) + , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) +#endif + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } +#endif +#if N_DST == 8 || N_DST == 16 + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } +#endif + } +} + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32 +//------------------------------------------------------------------------------ + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 1 // number of rows each SIMD group works on +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 1 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 
+#endif +kernel void kernel_mul_mv_q6_K_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + uchar kmask1 = 0x03; + uchar kmask2 = 0x0C; + uchar kmask3 = 0x30; + uchar kmask4 = 0xC0; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int row = N_SIMDGROUP * r0 + get_sub_group_id(); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a + // block. Values in a subblock shares a scale that is quantized with 8 bits; + // the entire block shares a single floating point scale. + // For work distribution, each thread processes a subblock (16 weights), hence + // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 + // (super) blocks -- this is the block stride. + // The 16 threads that process a (super) block are split into 2 portions, each has + // 8 threads; each portion works on 8 subblocks. + // For subgroup of 16 threads, the entire subgroup works on a single (super) block + // before moving to the next (super) block. Thread0 - thread7 work on the + // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. + // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on + // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but + // works on a total of 16 weight values. 
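+ // Illustrative example with N_SIMDWIDTH = 64, i.e. BLOCK_STRIDE = 4: subgroup lane 37 gets
+ // ix = 37%4 = 1 (which block of the stride), tid = 37/4 = 9, ip = 9/8 = 1 (second half),
+ // il = 9%8 = 1, l0 = 4*1 = 4, is = 8*1 + 4/16 = 8,
+ // y_offset = 128*1 + 4 = 132, q_offset_l = 64*1 + 4 = 68, q_offset_h = 32*1 + 4 = 36.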
+ int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl new file mode 100644 index 0000000000000..57768c80334eb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl @@ -0,0 +1,130 @@ +// src0_q, src0_d, src1 are transposed as a preprocessing step +// 4-bit weights are transposed in groups of 4 (unsigned short int) +// consider weights originally "next to each other", now "on top of each other" +// each fiber computes a 8x4 tile of output elements +// using unshuffled weights + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +__attribute__((qcom_reqd_sub_group_size("full"))) +kernel void kernel_mul_mat_Ab_Bi_8x4( + global const ushort * src0_q, // quantized A + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx 
= get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements + half8 B; // registers for activations + half4 dequantized_weights; // registers for dequantized weights + __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights + __global const half* scale_ptr = src0_d + gx_2; // pointer for scales + + for(int i=0; i<k; i+=4){ + + // load activations for j=0 (two half4 reads cover 8 output columns) + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + // load 4 consecutive groups of 4 quantized weights and the 4 matching scales + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*m); + half4 scale = vload4(0, scale_ptr + (i/32)*m); + + // j=0 + dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements + + // conditional check if store is to a valid location.
Required when N is not a multiple of 8 + // if statements allow registers to be reused for each store + // provides a performance boost due to reduced register footprint, which increases number of concurrent waves + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl new file mode 100644 index 0000000000000..d59a0c05ddfd0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl @@ -0,0 +1,32 @@ +// 16-bit transpose, loading/storing an 8x8 tile of elements + +kernel void kernel_transpose_16( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_3 = i<<3; + const int j_3 = j<<3; + + ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i)); + ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i)); + ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i)); + ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i)); + ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i)); + ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i)); + ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i)); + ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i)); + + write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0))); + write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1))); + write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2))); + write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3))); + write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4))); + write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5))); + write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6))); + write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7))); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl new file mode 100644 index 0000000000000..914ec0193e7e9 --- 
/dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl @@ -0,0 +1,25 @@ +// 32-bit transpose, loading/storing a 4x4 tile of elements + +kernel void kernel_transpose_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + float4 temp0 = read_imagef(input, (j_2+0)*cols+i); + float4 temp1 = read_imagef(input, (j_2+1)*cols+i); + float4 temp2 = read_imagef(input, (j_2+2)*cols+i); + float4 temp3 = read_imagef(input, (j_2+3)*cols+i); + + write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); + +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl new file mode 100644 index 0000000000000..d3bd1fabb76c3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl @@ -0,0 +1,35 @@ +// 32-bit transpose, loading/storing a 4x4 tile of elements +// Only used for activations +// converts to FP16 +// also adds zero padding for non multiple of 8 prompt lengths +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + half4 temp0 = {0,0,0,0}; // initialize outputs to 0 + half4 temp1 = {0,0,0,0}; + half4 temp2 = {0,0,0,0}; + half4 temp3 = {0,0,0,0}; + + if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 + temp0 = read_imageh(input, (j_2+0)*cols+i); + } + if((j_2+1)*cols+i*4+3 < rows*cols*16){ + temp1 = read_imageh(input, (j_2+1)*cols+i); + } + if((j_2+2)*cols+i*4+3 < rows*cols*16){ + temp2 = read_imageh(input, (j_2+2)*cols+i); + } + if((j_2+3)*cols+i*4+3 < rows*cols*16){ + temp3 = read_imageh(input, (j_2+3)*cols+i); + } + + write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding + write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +}
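
The mat-mul and transpose kernels in this patch sample their inputs through image1d_buffer_t objects instead of plain buffers, which can route loads through the GPU's texture path. The host side of the patch (ggml-opencl.cpp) is responsible for creating those images; purely as orientation, a minimal sketch of wrapping an existing cl_mem buffer as a 1D image buffer with float4 texels (as read by kernel_transpose_32 via read_imagef) might look like the code below. The function and variable names are placeholders rather than identifiers from this patch, and half4-based kernels would use CL_HALF_FLOAT instead of CL_FLOAT.

// Minimal sketch (assumed helper, not part of this patch): wrap an existing
// OpenCL buffer as a 1D image buffer with float4 texels so a kernel such as
// kernel_transpose_32 can read it with read_imagef.
#include <string.h>
#include <CL/cl.h>

static cl_mem wrap_as_image1d_buffer(cl_context ctx, cl_mem backing_buf, size_t n_texels, cl_int * err) {
    cl_image_format fmt = { CL_RGBA, CL_FLOAT };   // one texel = float4; half4 kernels would use CL_HALF_FLOAT
    cl_image_desc desc;
    memset(&desc, 0, sizeof(desc));
    desc.image_type  = CL_MEM_OBJECT_IMAGE1D_BUFFER;
    desc.image_width = n_texels;                   // image length in texels
    desc.buffer      = backing_buf;                // cl_mem previously created with clCreateBuffer
    return clCreateImage(ctx, CL_MEM_READ_ONLY, &fmt, &desc, NULL, err);
}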