From 9f34dda3e40d889fac869c71c03fb89c135b0905 Mon Sep 17 00:00:00 2001
From: Bolun
Date: Fri, 5 Jan 2024 06:57:40 +0000
Subject: [PATCH] feat: integrate CNNL and add unary/binary/softmax/batchnorm/reduce/transpose/pooling operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/02hardware/CMakeLists.txt | 8 +-
 src/02hardware/src/device_manager.cpp | 2 +
 src/02hardware/src/devices/mlu/device.cc | 8 +-
 src/02hardware/src/devices/mlu/functions.cc | 2 +
 src/02hardware/src/devices/mlu/functions.hh | 5 +-
 src/02hardware/src/devices/mlu/memory.cc | 2 +
 src/02hardware/src/devices/nvidia/device.cc | 4 +
 .../src/collectors/batch_normalization.cc | 4 +
 src/04kernel/src/collectors/reduce.cc | 4 +
 src/04kernel/src/collectors/simple_binary.cc | 4 +
 src/04kernel/src/collectors/simple_unary.cc | 6 +
 src/04kernel/src/collectors/softmax.cc | 7 +
 src/04kernel/src/collectors/transpose.cc | 6 +
 .../batch_normalization/cnnl_kernel.cc | 158 ++++++++++++++
 .../batch_normalization/cnnl_kernel.hh | 32 +++
 src/04kernel/src/kernels/pool/cnnl_kernel.cc | 156 ++++++++++++++
 src/04kernel/src/kernels/pool/cnnl_kernel.hh | 45 ++++
 .../src/kernels/reduce/cnnl_kernel.cc | 128 ++++++++++++
 .../src/kernels/reduce/cnnl_kernel.hh | 32 +++
 .../src/kernels/simple_binary/binary_cnnl.cc | 195 ++++++++++++++++++
 .../src/kernels/simple_binary/binary_cnnl.hh | 28 +++
 .../simple_unary/cnnl_activation_kernel.cc | 91 ++++++++
 .../simple_unary/cnnl_activation_kernel.hh | 27 +++
 .../simple_unary/cnnl_simple_unary_kernel.cc | 94 +++++++++
 .../simple_unary/cnnl_simple_unary_kernel.hh | 27 +++
 .../src/kernels/softmax/cnnl_kernel.cc | 86 ++++++++
 .../src/kernels/softmax/cnnl_kernel.hh | 36 ++++
 .../src/kernels/transpose/cnnl_kernel.cc | 92 +++++++++
 .../src/kernels/transpose/cnnl_kernel.hh | 32 +++
 .../src/utilities/bang/cnnl_context.cc | 35 ++++
 .../src/utilities/bang/cnnl_context.hh | 29 +++
 .../src/utilities/bang/cnnl_functions.cpp | 38 ++++
 .../src/utilities/bang/cnnl_functions.h | 40 ++++
 .../kernels/batch_normalization/test_cnnl.cpp | 70 +++++++
 src/04kernel/test/kernels/pool/test_cnnl.cpp | 70 +++++++
 .../test/kernels/reduce/test_cnnl.cpp | 64 ++++++
 .../simple_binary/test_binary_cnnl.cpp | 90 ++++++++
 .../test/kernels/simple_unary/test_cnnl.cpp | 63 ++++++
 .../test/kernels/softmax/test_cnnl.cpp | 52 +++++
 .../test/kernels/transpose/test_cnnl.cpp | 55 +++++
 40 files changed, 1918 insertions(+), 9 deletions(-)
 create mode 100644 src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc
 create mode 100644 src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh
 create mode 100644 src/04kernel/src/kernels/pool/cnnl_kernel.cc
 create mode 100644 src/04kernel/src/kernels/pool/cnnl_kernel.hh
 create mode 100644 src/04kernel/src/kernels/reduce/cnnl_kernel.cc
 create mode 100644 src/04kernel/src/kernels/reduce/cnnl_kernel.hh
 create mode 100644 src/04kernel/src/kernels/simple_binary/binary_cnnl.cc
 create mode 100644 src/04kernel/src/kernels/simple_binary/binary_cnnl.hh
 create mode 100644 src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc
 create mode 100644 src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh
 create mode 100644 src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc
 create mode 100644 src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh
 create mode 100644 src/04kernel/src/kernels/softmax/cnnl_kernel.cc
 create mode 100644
src/04kernel/src/kernels/softmax/cnnl_kernel.hh create mode 100644 src/04kernel/src/kernels/transpose/cnnl_kernel.cc create mode 100644 src/04kernel/src/kernels/transpose/cnnl_kernel.hh create mode 100644 src/04kernel/src/utilities/bang/cnnl_context.cc create mode 100644 src/04kernel/src/utilities/bang/cnnl_context.hh create mode 100644 src/04kernel/src/utilities/bang/cnnl_functions.cpp create mode 100644 src/04kernel/src/utilities/bang/cnnl_functions.h create mode 100644 src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp create mode 100644 src/04kernel/test/kernels/pool/test_cnnl.cpp create mode 100644 src/04kernel/test/kernels/reduce/test_cnnl.cpp create mode 100644 src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp create mode 100644 src/04kernel/test/kernels/simple_unary/test_cnnl.cpp create mode 100644 src/04kernel/test/kernels/softmax/test_cnnl.cpp create mode 100644 src/04kernel/test/kernels/transpose/test_cnnl.cpp diff --git a/src/02hardware/CMakeLists.txt b/src/02hardware/CMakeLists.txt index 1e38c5e2e..b42ef6327 100644 --- a/src/02hardware/CMakeLists.txt +++ b/src/02hardware/CMakeLists.txt @@ -3,14 +3,10 @@ project(hardware VERSION 0.0.0 LANGUAGES CXX) message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) # Source files -file(GLOB HARDWARE_SRC src/*.cc src/*.cpp src/devices/cpu/*.cc) +file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp) if(USE_CUDA) - file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu src/devices/nvidia/*.cc) -endif() - -if(USE_BANG) - file(GLOB_RECURSE HARDWARE_BANG_SRC src/devices/mlu/*.cc) + file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu) endif() add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC} ${HARDWARE_BANG_SRC}) diff --git a/src/02hardware/src/device_manager.cpp b/src/02hardware/src/device_manager.cpp index bcfab3bb8..7e449da7f 100644 --- a/src/02hardware/src/device_manager.cpp +++ b/src/02hardware/src/device_manager.cpp @@ -1,6 +1,7 @@ #include "hardware/device_manager.h" #include "hardware/devices/cpu.h" #include "hardware/devices/nvidia.h" +#include "hardware/devices/mlu.h" namespace refactor::hardware::device { @@ -37,6 +38,7 @@ namespace refactor::hardware::device { using T = Device::Type; // clang-format off auto device = type == T::Nvidia ? std::make_shared(card) + : type == T::Mlu ? 
std::make_shared(card) : UNREACHABLEX(Arc, ""); // clang-format on auto [kind, ok] = DEVICES.try_emplace(static_cast(type)); diff --git a/src/02hardware/src/devices/mlu/device.cc b/src/02hardware/src/devices/mlu/device.cc index 87b6150db..ea1f6affd 100644 --- a/src/02hardware/src/devices/mlu/device.cc +++ b/src/02hardware/src/devices/mlu/device.cc @@ -1,4 +1,4 @@ -#include "functions.cc" +#include "functions.hh" #include "hardware/devices/mlu.h" #include "hardware/mem_pool.h" #include "memory.hh" @@ -6,16 +6,20 @@ namespace refactor::hardware { static Arc bangMemory(int32_t card) { +#ifdef USE_BANG ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card); setDevice(card); auto [free, total] = getMemInfo(); auto size = std::min(free, std::max(5ul << 30, total * 4 / 5)); - fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}", + fmt::println("initializing Cambricon MLU {}, memory {} / {}, alloc {}", card, free, total, size); return std::make_shared( std::make_shared(), size, 256ul); +#else + return nullptr; +#endif } Mlu::Mlu(int32_t card) : Device(card, bangMemory(card)) {} diff --git a/src/02hardware/src/devices/mlu/functions.cc b/src/02hardware/src/devices/mlu/functions.cc index d8f30d0fe..bedea0458 100644 --- a/src/02hardware/src/devices/mlu/functions.cc +++ b/src/02hardware/src/devices/mlu/functions.cc @@ -2,6 +2,7 @@ namespace refactor::hardware { +#ifdef USE_BANG int getDeviceCount() { unsigned deviceCount; BANG_ASSERT(cnrtGetDeviceCount(&deviceCount)); @@ -15,5 +16,6 @@ namespace refactor::hardware { BANG_ASSERT(cnrtMemGetInfo(&memInfo.free, &memInfo.total)); return memInfo; } +#endif }// namespace refactor::hardware diff --git a/src/02hardware/src/devices/mlu/functions.hh b/src/02hardware/src/devices/mlu/functions.hh index 0244e01f0..c664ea6f8 100644 --- a/src/02hardware/src/devices/mlu/functions.hh +++ b/src/02hardware/src/devices/mlu/functions.hh @@ -1,14 +1,17 @@ #ifndef HARDWARE_DEVICES_MLU_FUNCTIONS_CUH #define HARDWARE_DEVICES_MLU_FUNCTIONS_CUH -#include "cnrt.h" #include "common.h" +#ifdef USE_BANG +#include "cnrt.h" + #define BANG_ASSERT(STATUS) \ if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \ RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \ cnrtGetErrorStr(status), (int) status)); \ } +#endif namespace refactor::hardware { diff --git a/src/02hardware/src/devices/mlu/memory.cc b/src/02hardware/src/devices/mlu/memory.cc index 81b3c626a..55550314a 100644 --- a/src/02hardware/src/devices/mlu/memory.cc +++ b/src/02hardware/src/devices/mlu/memory.cc @@ -2,6 +2,7 @@ #include "functions.hh" namespace refactor::hardware { +#ifdef USE_BANG using M = MluMemory; void *M::malloc(size_t size) { @@ -27,5 +28,6 @@ namespace refactor::hardware { CNRT_MEM_TRANS_DIR_PEER2PEER)); return dst; } +#endif }// namespace refactor::hardware diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc index 403921cba..345d71772 100644 --- a/src/02hardware/src/devices/nvidia/device.cc +++ b/src/02hardware/src/devices/nvidia/device.cc @@ -6,6 +6,7 @@ namespace refactor::hardware { static Arc cudaMemory(int32_t card) { +#ifdef USE_CUDA ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card); setDevice(card); auto [free, total] = getMemInfo(); @@ -16,6 +17,9 @@ namespace refactor::hardware { std::make_shared(), size, 256ul); +#else + return nullptr; +#endif } Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {} diff --git 
a/src/04kernel/src/collectors/batch_normalization.cc b/src/04kernel/src/collectors/batch_normalization.cc index 93bcb240e..e944e37d7 100644 --- a/src/04kernel/src/collectors/batch_normalization.cc +++ b/src/04kernel/src/collectors/batch_normalization.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/batch_normalization.h" #include "../kernels/batch_normalization/cpu_kernel.hh" #include "../kernels/batch_normalization/cudnn_kernel.hh" +#include "../kernels/batch_normalization/cnnl_kernel.hh" namespace refactor::kernel { @@ -20,6 +21,9 @@ namespace refactor::kernel { case decltype(_target)::Nvidia: REGISTER(BatchNormalizationCudnn) break; + case decltype(_target)::Mlu: + REGISTER(BatchNormalizationCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/reduce.cc b/src/04kernel/src/collectors/reduce.cc index bec37731d..71fa194ba 100644 --- a/src/04kernel/src/collectors/reduce.cc +++ b/src/04kernel/src/collectors/reduce.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/reduce.h" #include "../kernels/reduce/cpu_kernel.hh" #include "../kernels/reduce/cudnn_kernel.hh" +#include "../kernels/reduce/cnnl_kernel.hh" namespace refactor::kernel { @@ -27,6 +28,9 @@ namespace refactor::kernel { case decltype(_target)::Nvidia: REGISTER(ReduceCudnn) break; + case decltype(_target)::Mlu: + REGISTER(ReduceCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/simple_binary.cc b/src/04kernel/src/collectors/simple_binary.cc index e2c001ff7..f9f5b8dc0 100644 --- a/src/04kernel/src/collectors/simple_binary.cc +++ b/src/04kernel/src/collectors/simple_binary.cc @@ -2,6 +2,7 @@ #include "../kernels/simple_binary/binary_cudnn.hh" #include "../kernels/simple_binary/cpu_kernel.hh" #include "../kernels/simple_binary/cuda_kernel.hh" +#include "../kernels/simple_binary/binary_cnnl.hh" namespace refactor::kernel { @@ -48,6 +49,9 @@ namespace refactor::kernel { REGISTER_BROCAST(BinaryCudnn) REGISTER(BinaryCuda) break; + case decltype(_target)::Mlu: + REGISTER_BROCAST(BinaryCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/simple_unary.cc b/src/04kernel/src/collectors/simple_unary.cc index de9e0bb07..b95169815 100644 --- a/src/04kernel/src/collectors/simple_unary.cc +++ b/src/04kernel/src/collectors/simple_unary.cc @@ -2,6 +2,8 @@ #include "../kernels/simple_unary/cpu_kernel.hh" #include "../kernels/simple_unary/cuda_kernel.hh" #include "../kernels/simple_unary/cudnn_activation_kernel.hh" +#include "../kernels/simple_unary/cnnl_activation_kernel.hh" +#include "../kernels/simple_unary/cnnl_simple_unary_kernel.hh" #include "common.h" namespace refactor::kernel { @@ -54,6 +56,10 @@ namespace refactor::kernel { REGISTER(ActivationCudnn) REGISTER(SimpleUnaryCuda) break; + case decltype(_target)::Mlu: + REGISTER(ActivationCnnl) + REGISTER(SimpleUnaryCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/softmax.cc b/src/04kernel/src/collectors/softmax.cc index 2ce442696..020bc6ded 100644 --- a/src/04kernel/src/collectors/softmax.cc +++ b/src/04kernel/src/collectors/softmax.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/softmax.h" +#include "../kernels/softmax/cnnl_kernel.hh" #include "../kernels/softmax/cpu_kernel.hh" #include "../kernels/softmax/cuda_kernel.hh" #include "../kernels/softmax/cudnn_kernel.hh" @@ -28,6 +29,12 @@ namespace refactor::kernel { } break; } + case decltype(_target)::Mlu: { + if (auto ptr = 
SoftmaxCnnl::build(cnnl::SoftmaxAlgo::ACCURATE, info); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + } default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/transpose.cc b/src/04kernel/src/collectors/transpose.cc index 7fe8b294a..58d5d7180 100644 --- a/src/04kernel/src/collectors/transpose.cc +++ b/src/04kernel/src/collectors/transpose.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/transpose.h" #include "../kernels/transpose/cpu_kernel.hh" #include "../kernels/transpose/cuda_kernel.hh" +#include "../kernels/transpose/cnnl_kernel.hh" namespace refactor::kernel { @@ -25,6 +26,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = TransposeCnnl::build(data.dataType, data.shape, perm); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc new file mode 100644 index 000000000..be06233cd --- /dev/null +++ b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc @@ -0,0 +1,158 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include +#endif + +namespace refactor::kernel { + using K = BatchNormalizationCnnl; + using DT = DataType; + + K::BatchNormalizationCnnl(decltype(info) info_) noexcept + : info(info_) {} + + auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + auto const &x = inputs[0].get(); + auto const &scale = inputs[1].get(); + auto const &mean = inputs[3].get(); + + if (x.rank() != 4) { + return nullptr; + } + + // see "Supported Configurations for `cnnlBatchNormalizationForwardInference`" + if (scale.dataType != mean.dataType) { + return nullptr; + } + if (x.dataType == DT::F64) { + if (scale.dataType != DT::F64) { + return nullptr; + } + } else { + if (scale.dataType != DT::F32) { + return nullptr; + } + } + return std::make_unique(decltype(info){ + epsilon, + x.dataType, + scale.dataType, + x.layout, + { + static_cast(x.shape[0]), + static_cast(x.shape[1]), + static_cast(x.shape[2]), + static_cast(x.shape[3]), + }}); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing batch normalization for non-training-mode using CNNL"; + } + +#ifdef USE_BANG + + auto K::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t inDesc, inDescTrans, p; + cnnlTransposeDescriptor_t NCHW2NHWC, NHWC2NCHW; + bool f32; + + explicit Descriptors(decltype(f32) f32_) + : inDesc(nullptr), inDescTrans(nullptr), p(nullptr), + NCHW2NHWC(nullptr), NHWC2NCHW(nullptr), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDescTrans)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&p)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NCHW2NHWC)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NHWC2NCHW)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDescTrans)); + 
CNNL_ASSERT(cnnlDestroyTensorDescriptor(p)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NCHW2NHWC)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NHWC2NCHW)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared(info.dtX != DT::F64); + int dimNCHW[4] = {info.dimAx[0], info.dimAx[1], info.dimAx[2], info.dimAx[3]}; + int dimNHWC[4] = {info.dimAx[0], info.dimAx[2], info.dimAx[3], info.dimAx[1]}; + int dimParam[]{info.dimAx[1]}; + setCnnlTensor(d->inDesc, info.dtX, slice(dimNCHW, 4)); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->inDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dtX), 4, dimNHWC)); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->p, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dtP), 1, dimParam)); + int permute[4] = {0, 2, 3, 1}; + int permuteOut[4] = {0, 3, 1, 2}; + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NCHW2NHWC, 4, permute)); + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut)); + + auto handle = res.fetchOrStore()->handle; + auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * sizeof(info.dtX); + size_t workspaceSize; + CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->inDesc, d->NCHW2NHWC, &workspaceSize)); + size_t totalWorkspaceSize = xTransSize + workspaceSize; + + res.fetchOrStore(); + auto routine = [d = std::move(d), + epsilon = info.epsilon, + xTransSize, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore()->handle; + + // name inputs and outputs + auto x = inputs[0], + scale = inputs[1], + bias = inputs[2], + mean = inputs[3], + var = inputs[4]; + auto y = outputs[0]; + + void *xTrans = workspace; + void *yTrans = xTrans + xTransSize; + void *cursor = yTrans + workspaceSize; + + // transpose NCHW input to NHWC + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->inDesc, x, + d->inDescTrans, xTrans, cursor, workspaceSize)); + + // build alpha/beta for double + auto a = d->f32 ? factor(1) : factor(1), + b = d->f32 ? factor(0) : factor(0); + CNNL_ASSERT(cnnlBatchNormForwardInference( + handle, &a, &b, + d->inDescTrans, xTrans, d->p, scale, bias, mean, var, + epsilon, d->inDescTrans, yTrans)); + + // transpose NHWC intermediates to NCHW + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NHWC2NCHW, d->inDescTrans, yTrans, + d->inDesc, y, cursor, workspaceSize)); + + BANG_ASSERT(cnrtQueueSync(res.fetchOrStore()->queue)); + }; + + return {std::move(routine), totalWorkspaceSize}; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh new file mode 100644 index 000000000..978b0dedc --- /dev/null +++ b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_BATCH_NORMALIZATION_CNNL_KERNEL_HH +#define KERNEL_BATCH_NORMALIZATION_CNNL_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + /// @brief Use `cnnlBatchNormalizationForwardInference`. + /// It only supports 4D and 5D tensors. 
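+    /// The lowered routine transposes the NCHW input to NHWC, runs CNNL's
+    /// batch-norm inference on that copy, then transposes the result back to
+    /// NCHW; the returned workspace covers the NHWC scratch tensor plus the
+    /// CNNL transpose workspace.
+    ///
+    /// A minimal call sequence, mirroring the unit test added in this patch:
+    ///
+    ///     auto [routine, workspaceSize] = kernel->lower(res);
+    ///     auto workspace = dev.malloc(workspaceSize);
+    ///     routine(res, *workspace, inputs, outputs);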
+ struct BatchNormalizationCnnl final : public Kernel { + struct { + float epsilon; + DataType dtX, dtP; + LayoutType layout; + int dimAx[4];// dimA for x + } info; + + explicit BatchNormalizationCnnl(decltype(info)) noexcept; + + static KernelBox build(float, TensorRefs) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_BATCH_NORMALIZATION_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/pool/cnnl_kernel.cc b/src/04kernel/src/kernels/pool/cnnl_kernel.cc new file mode 100644 index 000000000..083125b1f --- /dev/null +++ b/src/04kernel/src/kernels/pool/cnnl_kernel.cc @@ -0,0 +1,156 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = PoolCnnl; + + K::PoolCnnl(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(PoolType poolType, + bool ceil, + KernelShape const &kernelShape, + PoolAttributes const &poolAttributes, + Tensor const &x, + Tensor const &y) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + // TODO check data type + auto p = poolAttributes.pads(), + d = poolAttributes.dilations(), + s = poolAttributes.strides(); + if (x.rank() != 4 || + poolType == PoolType::Lp || + d[0] != 1 || d[1] != 1) { + return nullptr; + } + return std::make_unique(decltype(info){ + poolType, + x.dataType, + { + static_cast(x.shape[0]), + static_cast(x.shape[1]), + static_cast(x.shape[2]), + static_cast(x.shape[3]), + }, + { + static_cast(y.shape[0]), + static_cast(y.shape[1]), + static_cast(y.shape[2]), + static_cast(y.shape[3]), + }, + { + static_cast(kernelShape[0]), + static_cast(kernelShape[1]), + }, + {p[0], p[1], p[2], p[3]}, + {s[0], s[1]}, + {d[0], d[1]}, + ceil + }); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing pool using CNNL"; + } + +#ifdef USE_BANG + + auto PoolCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using Ty = PoolType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t x, y; + cnnlPoolingDescriptor_t pooling; + bool f32; + + Descriptors(decltype(f32) f32_) : f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreatePoolingDescriptor(&pooling)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyPoolingDescriptor(pooling)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared(info.dt != DataType::F64); + int const + xs[]{ + info.xShape[0], + info.xShape[1], + info.xShape[2] + std::abs(info.pads[0] - info.pads[2]), + info.xShape[3] + std::abs(info.pads[1] - info.pads[3]), + }, + *ys = info.yShape; + setCnnlTensor(d->x, info.dt, slice(xs, 4)); + setCnnlTensor(d->y, info.dt, slice(ys, 4)); + + // clang-format off + auto mode = info.poolType == Ty::Average ? 
CNNL_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : info.poolType == Ty::Max ? CNNL_POOLING_MAX + : UNREACHABLEX(cnnlPoolingMode_t, ""); + // clang-format on + auto pp = info.pads; + auto ss = info.strides; + auto kk = info.kernelShape; + auto dd = info.dilations; + CNNL_ASSERT(cnnlSetPooling2dDescriptor_v2( + d->pooling, mode, CNNL_NOT_PROPAGATE_NAN, + kk[0], kk[1], pp[0], pp[2], pp[1], pp[3], + ss[0], ss[1], dd[0], dd[1], ceil)); + + auto handle = res.fetchOrStore()->handle; + size_t extraInputSize, workspaceSize; + CNNL_ASSERT(cnnlGetPoolingWorkspaceSize(handle, mode, ys[3], ys[2], &workspaceSize)); + CNNL_ASSERT(cnnlGetPoolingExtraInputSize(handle, mode, ys[3], ys[2], &extraInputSize)); + + res.fetchOrStore(); + auto routine = [d, workspaceSize, + extraInputSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto handle = res.fetchOrStore()->handle; + + void *extraInputDev = workspace; + void *poolWorkSpace = workspace + extraInputSize; + + void *extraInputHost = malloc(extraInputSize); + CNNL_ASSERT(cnnlInitPoolingExtraInput(handle, d->pooling, d->x, d->y, extraInputHost)); + BANG_ASSERT(cnrtMemcpy(extraInputDev, extraInputHost, extraInputSize, CNRT_MEM_TRANS_DIR_HOST2DEV)); + + // build alpha/beta for double + auto a = d->f32 ? factor(1) : factor(1), + b = d->f32 ? factor(0) : factor(0); + CNNL_ASSERT(cnnlPoolingForward_v2( + handle, d->pooling, + &a, d->x, inputs[0], + &b, extraInputDev, d->y, outputs[0], + poolWorkSpace, workspaceSize)); + + BANG_ASSERT(cnrtQueueSync(res.fetchOrStore()->queue)); + + free(extraInputHost); + }; + return {std::move(routine), workspaceSize + extraInputSize}; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pool/cnnl_kernel.hh b/src/04kernel/src/kernels/pool/cnnl_kernel.hh new file mode 100644 index 000000000..0a0298ede --- /dev/null +++ b/src/04kernel/src/kernels/pool/cnnl_kernel.hh @@ -0,0 +1,45 @@ +#ifndef KERNEL_POOL_CNNL_KERNEL_HH +#define KERNEL_POOL_CNNL_KERNEL_HH + +#include "kernel/attributes/pool_attributes.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + /// @brief Use `cnnlPoolingForward`. + /// It only supports 4D tensors. 
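+    /// The lowered routine materializes CNNL's pooling extra input on the
+    /// host (cnnlInitPoolingExtraInput), copies it into the front of the
+    /// workspace, then calls cnnlPoolingForward_v2 with the remaining bytes
+    /// of the workspace as scratch space.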
+ struct PoolCnnl final : public Kernel { + struct + { + PoolType poolType; + DataType dt; + int xShape[4], + yShape[4], + kernelShape[2], + pads[4], + strides[2], + dilations[2]; + bool ceil; + } info; + + explicit PoolCnnl(decltype(info)) noexcept; + + static KernelBox build(PoolType, + bool, + KernelShape const &, + PoolAttributes const &, + Tensor const &, + Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_POOL_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/reduce/cnnl_kernel.cc b/src/04kernel/src/kernels/reduce/cnnl_kernel.cc new file mode 100644 index 000000000..752bee690 --- /dev/null +++ b/src/04kernel/src/kernels/reduce/cnnl_kernel.cc @@ -0,0 +1,128 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include "hardware/functions.h" +#endif + +namespace refactor::kernel { + using K = ReduceCnnl; + + K::ReduceCnnl( + decltype(dataType) dataType_, + decltype(reduceType) reduceType_, + decltype(axes) axes_, + decltype(shape) shape_) noexcept + : Kernel(), + dataType(dataType_), + reduceType(reduceType_), + axes(std::move(axes_)), + shape(std::move(shape_)) {} + + auto K::build(decltype(axes) axes_, ReduceType reduceType_, TensorRefs inputs_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + auto const &x = inputs_[0].get(); + return x.dataType.isFloat() + ? std::make_unique(x.dataType, reduceType_, std::move(axes_), x.shape) + : nullptr; + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing reduce operation using CNNL"; + } + +#ifdef USE_BANG + + auto ReduceCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t x, y; + cnnlReduceDescriptor_t reduce; + bool f32; + + explicit Descriptors(decltype(f32) f32_) : f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreateReduceDescriptor(&reduce)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyReduceDescriptor(reduce)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared(dataType != DataType::F64); + + std::vector + dimsI(shape.begin(), shape.end()), + dimsO(shape.begin(), shape.end()); + for (auto axis : axes) { + dimsO[axis] = 1; + } + setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size())); + setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size())); + + // clang-format off + auto reduceOp = reduceType == ReduceType::Mean ? CNNL_REDUCE_AVG + : reduceType == ReduceType::Sum ? CNNL_REDUCE_ADD + : reduceType == ReduceType::Min ? CNNL_REDUCE_MIN + : reduceType == ReduceType::Max ? CNNL_REDUCE_MAX + : reduceType == ReduceType::L1 ? CNNL_REDUCE_NORM1 + : reduceType == ReduceType::L2 ? CNNL_REDUCE_NORM2 + : reduceType == ReduceType::Prod ? 
CNNL_REDUCE_MUL
+                        : UNREACHABLEX(cnnlReduceOp_t, "");
+        // clang-format on
+        CNNL_ASSERT(cnnlSetReduceDescriptor_v2(
+            d->reduce, (int *) (axes.data()), axes.size(), reduceOp,
+            cnnlDataTypeConvert(d->f32 ? DataType::F32 : DataType::F64),
+            CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0));
+
+        auto handler = res.fetchOrStore<CnnlContext>()->handle;
+        size_t idxWorkspaceSize = axes.size() * sizeof(int);
+        // idxWorkspaceSize = hardware::alignBytes(idxWorkspaceSize, 256);
+        size_t workspaceSize;
+        // get workspace
+        CNNL_ASSERT(cnnlGetReduceOpWorkspaceSize(handler, d->x, d->y, d->reduce, &workspaceSize));
+
+        res.fetchOrStore<CnnlContext>();
+        auto routine = [d = std::move(d),
+                        idxWorkspaceSize,
+                        workspaceSize](Resources &res,
+                                       void *workspace,
+                                       void const *const *inputs,
+                                       void *const *outputs) {
+            void *idxWorkspace = workspace,
+                 *dataWorkspace = reinterpret_cast<uint8_t *>(workspace) + idxWorkspaceSize;
+            // build alpha/beta for double
+            auto a = d->f32 ? factor<float>(1) : factor<double>(1),
+                 b = d->f32 ? factor<float>(0) : factor<double>(0);
+            CNNL_ASSERT(cnnlReduce(
+                res.fetchOrStore<CnnlContext>()->handle,
+                d->reduce,
+                dataWorkspace, workspaceSize,
+                &a, d->x, inputs[0],
+                idxWorkspaceSize, idxWorkspace,
+                &b, d->y, outputs[0]));
+        };
+        return RoutineWorkspace(std::move(routine), idxWorkspaceSize + workspaceSize);
+    }
+
+#endif
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/reduce/cnnl_kernel.hh b/src/04kernel/src/kernels/reduce/cnnl_kernel.hh
new file mode 100644
index 000000000..6ffaf7387
--- /dev/null
+++ b/src/04kernel/src/kernels/reduce/cnnl_kernel.hh
@@ -0,0 +1,32 @@
+#ifndef KERNEL_REDUCE_MEAN_CNNL_KERNEL_HH
+#define KERNEL_REDUCE_MEAN_CNNL_KERNEL_HH
+
+#include "kernel/collectors/reduce.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+
+    struct ReduceCnnl final : public Kernel {
+        DataType dataType;
+        ReduceType reduceType;
+        Axes axes;
+        Shape shape;
+
+        ReduceCnnl(decltype(dataType),
+                   decltype(reduceType),
+                   decltype(axes),
+                   decltype(shape)) noexcept;
+
+        static KernelBox build(decltype(axes), ReduceType, TensorRefs) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_BANG
+        RoutineWorkspace lower(Resources &) const final;
+#endif
+    };
+}// namespace refactor::kernel
+
+#endif// KERNEL_REDUCE_MEAN_CNNL_KERNEL_HH
diff --git a/src/04kernel/src/kernels/simple_binary/binary_cnnl.cc b/src/04kernel/src/kernels/simple_binary/binary_cnnl.cc
new file mode 100644
index 000000000..cfe2ce0fa
--- /dev/null
+++ b/src/04kernel/src/kernels/simple_binary/binary_cnnl.cc
@@ -0,0 +1,195 @@
+#include "binary_cnnl.hh"
+#include <unordered_set>
+
+#ifdef USE_BANG
+#include "../../utilities/bang/cnnl_context.hh"
+#include "../../utilities/bang/cnnl_functions.h"
+#endif
+
+namespace refactor::kernel {
+    using K = BinaryCnnl;
+    using Op = SimpleBinaryType;
+    using DT = DataType;
+
+    K::BinaryCnnl(Op opType_, DT dataType_, std::vector<int> aDims_, std::vector<int> bDims_, std::vector<int> cDims_) noexcept
+        : Kernel(), dataType(dataType_), opType(opType_), aDims(aDims_), bDims(bDims_), cDims(cDims_) {}
+
+    auto K::build(Op op, Tensor const &a, Tensor const &b, Tensor const &c) noexcept -> KernelBox {
+        static const std::unordered_set<Op>
+            ARITHMETIC{Op::Add, Op::Sub, Op::Mul, Op::Div, Op::And, Op::Or, Op::Xor, Op::Pow};
+
+#ifndef USE_BANG
+        return nullptr;
+#endif
+
+        if (a.dataType != b.dataType ||
+            !a.dataType.isFloat() ||
+            !ARITHMETIC.contains(op) ||
+            // At least one of a, b should have the same shape as c
+            (a.shape != c.shape && b.shape != c.shape) ||
+            // Sub only supports broadcasting b
+            (a.shape != c.shape && op == Op::Sub) ||
+            // CNNL binary ops only support tensors up to 5D
+            !((a.rank() == 5 && b.rank() == 5) || (a.rank() <= 4 && b.rank() <= 4))) {
+            return nullptr;
+        }
+
+        auto shape2IntVec = [](Shape shape) -> std::vector<int> {
+            std::vector<int> intVector;
+            intVector.reserve(shape.size());
+            for (const uint32_t &element : shape) {
+                intVector.push_back(static_cast<int>(element));
+            }
+            return intVector;
+        };
+
+        return std::make_unique<BinaryCnnl>(op, a.dataType, shape2IntVec(a.shape), shape2IntVec(b.shape), shape2IntVec(c.shape));
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing element-wise op of 2 tensors with CNNL";
+    }
+
+#ifdef USE_BANG
+
+    auto BinaryCnnl::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace cnnl;
+        using namespace runtime;
+
+        struct Descriptors {
+            cnnlOpTensorDescriptor_t opDesc;
+            cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
+            bool f32, sub;
+
+            Descriptors(decltype(f32) f32_) : f32(f32_), sub(false) {
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&aDesc));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&bDesc));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&cDesc));
+                CNNL_ASSERT(cnnlCreateOpTensorDescriptor(&opDesc));
+            }
+            ~Descriptors() noexcept(false) {
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(aDesc));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(bDesc));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(cDesc));
+                CNNL_ASSERT(cnnlDestroyOpTensorDescriptor(opDesc));
+            }
+        };
+        auto d = std::make_shared<Descriptors>(dataType != DT::F64);
+        cnnlOpTensorDesc_t cnnlOP;
+        cnnlLogicOp_t cnnlLogicOP;
+        if (opType == SimpleBinaryType::Add) {
+            cnnlOP = CNNL_OP_TENSOR_ADD;
+        } else if (opType == SimpleBinaryType::Sub) {
+            cnnlOP = CNNL_OP_TENSOR_ADD;
+            d->sub = true;
+        } else if (opType == SimpleBinaryType::Mul) {
+            cnnlOP = CNNL_OP_TENSOR_MUL;
+        } else if (opType == SimpleBinaryType::And) {
+            cnnlLogicOP = CNNL_LOGIC_OP_AND;
+        } else if (opType == SimpleBinaryType::Or) {
+            cnnlLogicOP = CNNL_LOGIC_OP_OR;
+        } else if (opType == SimpleBinaryType::Xor) {
+            cnnlLogicOP = CNNL_LOGIC_OP_XOR;
+        }
+
+        setCnnlTensor(d->aDesc, dataType, slice(aDims.data(), aDims.size()));
+        setCnnlTensor(d->bDesc, dataType, slice(bDims.data(), bDims.size()));
+        setCnnlTensor(d->cDesc, dataType, slice(cDims.data(), cDims.size()));
+        CNNL_ASSERT(cnnlSetOpTensorDescriptor(
+            d->opDesc, cnnlOP,
+            cnnlDataTypeConvert(d->f32 ? DT::F32 : DT::F64),
+            CNNL_NOT_PROPAGATE_NAN));
+
+        auto cnnlGetBinaryWorkspaceSize =
+            (opType == SimpleBinaryType::Add || opType == SimpleBinaryType::Sub || opType == SimpleBinaryType::Mul) ? cnnlGetOpTensorWorkspaceSize
+            : (opType == SimpleBinaryType::Div)                                                                     ? cnnlGetDivWorkspaceSize
+            : (opType == SimpleBinaryType::And || opType == SimpleBinaryType::Or || opType == SimpleBinaryType::Xor) ? cnnlGetLogicOpWorkspaceSize
+            : (opType == SimpleBinaryType::Pow)                                                                     ? cnnlGetPowWorkspaceSize
+                                                                                                                    : nullptr;
+
+        if (cnnlGetBinaryWorkspaceSize == nullptr) {
+            UNREACHABLE();
+        }
+
+        auto handle = res.fetchOrStore<CnnlContext>()->handle;
+        size_t workspaceSize;
+        if (aDims != cDims) {
+            CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->bDesc,
+                                                   d->aDesc, d->cDesc,
+                                                   &workspaceSize));
+        } else {
+            CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
+                                                   d->bDesc, d->cDesc,
+                                                   &workspaceSize));
+        }
+
+        res.fetchOrStore<CnnlContext>();
+        auto routine = [swap = aDims != cDims, d,
+                        workspaceSize, cnnlLogicOP,
+                        op = this->opType](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto handle = res.fetchOrStore<CnnlContext>()->handle;
+            // name inputs and outputs
+            auto a = inputs[0],
+                 b = inputs[1];
+            auto c = outputs[0];
+            if (op == SimpleBinaryType::Add || op == SimpleBinaryType::Sub || op == SimpleBinaryType::Mul) {
+                auto alphaA = d->f32
+                                  ? factor<float>(1)
+                                  : factor<double>(1),
+                     alphaB = d->f32
+                                  ? factor<float>(d->sub ? -1 : 1)
+                                  : factor<double>(d->sub ? -1 : 1),
+                     beta = d->f32
+                                ? factor<float>(0)
+                                : factor<double>(0);
+
+                if (swap) {
+                    CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
+                                             &alphaB, d->bDesc, b,
+                                             &alphaA, d->aDesc, a,
+                                             workspace, workspaceSize,
+                                             &beta, d->cDesc, c));
+                } else {
+                    CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
+                                             &alphaA, d->aDesc, a,
+                                             &alphaB, d->bDesc, b,
+                                             workspace, workspaceSize,
+                                             &beta, d->cDesc, c));
+                }
+            } else if (op == SimpleBinaryType::Div) {
+                CNNL_ASSERT(cnnlDiv_v2(handle,
+                                       CNNL_COMPUTATION_HIGH_PRECISION,
+                                       d->aDesc, a,
+                                       d->bDesc, b,
+                                       workspace, workspaceSize,
+                                       d->cDesc, c));
+            } else if (op == SimpleBinaryType::And || op == SimpleBinaryType::Or || op == SimpleBinaryType::Xor) {
+                CNNL_ASSERT(cnnlLogicOp(handle, cnnlLogicOP,
+                                        d->aDesc, a,
+                                        d->bDesc, b,
+                                        workspace, workspaceSize,
+                                        d->cDesc, c));
+            } else if (op == SimpleBinaryType::Pow) {
+                CNNL_ASSERT(cnnlPow(handle,
+                                    CNNL_COMPUTATION_HIGH_PRECISION,
+                                    d->aDesc, a,
+                                    d->bDesc, b,
+                                    workspace, workspaceSize,
+                                    d->cDesc, c));
+            }
+        };
+
+        return {std::move(routine), workspaceSize};
+    }
+
+#endif
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/simple_binary/binary_cnnl.hh b/src/04kernel/src/kernels/simple_binary/binary_cnnl.hh
new file mode 100644
index 000000000..2d5c7cfaa
--- /dev/null
+++ b/src/04kernel/src/kernels/simple_binary/binary_cnnl.hh
@@ -0,0 +1,28 @@
+#ifndef KERNEL_BINARY_CNNL_HH
+#define KERNEL_BINARY_CNNL_HH
+
+#include "kernel/collectors/simple_binary.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+
+    struct BinaryCnnl final : public Kernel {
+        DataType dataType;
+        SimpleBinaryType opType;
+        std::vector<int> aDims, bDims, cDims;
+
+        BinaryCnnl(SimpleBinaryType, DataType, std::vector<int> aDims_, std::vector<int> bDims_, std::vector<int> cDims_) noexcept;
+
+        static KernelBox build(SimpleBinaryType, Tensor const &, Tensor const &, Tensor const &) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_BANG
+        RoutineWorkspace lower(Resources &) const final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_BINARY_CNNL_HH
diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc
new file mode 100644
index 000000000..d35535948
--- /dev/null
+++ b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc
@@ -0,0 +1,91 @@
+#include "cnnl_activation_kernel.hh"
+#include "kernel/collectors/simple_unary.h"
+#include <unordered_set>
+
+#ifdef USE_BANG
+#include "../../utilities/bang/cnnl_context.hh"
+#include "../../utilities/bang/cnnl_functions.h"
+#include <cnnl.h>
+#endif
+
+namespace refactor::kernel {
+    using K = ActivationCnnl;
+    using DT = DataType;
+    using Op = SimpleUnaryType;
+
+    K::ActivationCnnl(Op type_, DT dataType_, int size_) noexcept
+        : Kernel(), type(type_), dataType(dataType_), size(size_) {}
+
+    auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
+        static const std::unordered_set<Op> ACTIVATIONS{Op::Sigmoid, Op::Relu, Op::Tanh};
+
+#ifndef USE_BANG
+        return nullptr;
+#endif
+
+        return ACTIVATIONS.contains(op)
+                   ? std::make_unique<ActivationCnnl>(op, a.dataType, static_cast<int>(a.elementsSize()))
+                   : nullptr;
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing activation using CNNL";
+    }
+
+#ifdef USE_BANG
+
+    auto ActivationCnnl::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace cnnl;
+        using namespace runtime;
+        using Ty = SimpleUnaryType;
+
+        // RAII for closure
+        struct Descriptors {
+            cnnlActivationDescriptor_t activation;
+            cnnlTensorDescriptor_t tensor;
+
+            Descriptors() : activation(nullptr), tensor(nullptr) {
+                CNNL_ASSERT(cnnlCreateActivationDescriptor(&activation));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&tensor));
+            }
+            ~Descriptors() noexcept(false) {
+                CNNL_ASSERT(cnnlDestroyActivationDescriptor(activation));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(tensor));
+            }
+
+            Descriptors(const Descriptors &) = delete;
+            Descriptors(Descriptors &&) = delete;
+        };
+        auto d = std::make_shared<Descriptors>();
+
+        // clang-format off
+        auto mode = type == Ty::Relu    ? CNNL_ACTIVATION_RELU
+                  : type == Ty::Sigmoid ? CNNL_ACTIVATION_SIGMOID
+                  : type == Ty::Tanh    ?
CNNL_ACTIVATION_TANH + : UNREACHABLEX(cnnlActivationMode_t, ""); + // clang-format on + + setCnnlTensor(d->tensor, dataType, slice(&size, 1)); + CNNL_ASSERT(cnnlSetActivationDescriptor_v2(d->activation, mode, CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, 0.0)); + + res.fetchOrStore(); + return [d = std::move(d)]// + (Resources & res, void *, void const *const *inputs, void *const *outputs) { + float alpha = 1, beta = 0; + CNNL_ASSERT(cnnlActivationForward( + res.fetchOrStore()->handle, + d->activation, + &alpha, d->tensor, inputs[0], + &beta, d->tensor, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh new file mode 100644 index 000000000..a5d7ad65c --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_ACTIVATION_CNNL_KERNEL_HH +#define KERNEL_ACTIVATION_CNNL_KERNEL_HH + +#include "kernel/collectors/simple_unary.h" + +namespace refactor::kernel { + + struct ActivationCnnl final : public Kernel { + SimpleUnaryType type; + DataType dataType; + int size; + + ActivationCnnl(SimpleUnaryType, DataType, int) noexcept; + + static KernelBox build(SimpleUnaryType, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_ACTIVATION_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc new file mode 100644 index 000000000..f6e32159d --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc @@ -0,0 +1,94 @@ +#include "cnnl_simple_unary_kernel.hh" +#include "kernel/collectors/simple_unary.h" +#include + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include +#endif + +namespace refactor::kernel { + using K = SimpleUnaryCnnl; + using DT = DataType; + using Op = SimpleUnaryType; + + K::SimpleUnaryCnnl(Op type_, DT dataType_, int size_) noexcept + : Kernel(), type(type_), dataType(dataType_), size(size_) {} + + auto K::build(Op op, Tensor const &a) noexcept -> KernelBox { + static const std::unordered_set supportedOp{Op::Abs, Op::Sqrt, Op::Neg}; + +#ifndef USE_BANG + return nullptr; +#endif + + return supportedOp.contains(op) + ? 
std::make_unique(op, a.dataType, static_cast(a.elementsSize())) + : nullptr; + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing simple unary using CNNL"; + } + +#ifdef USE_BANG + + auto SimpleUnaryCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using Ty = SimpleUnaryType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t tensor; + + Descriptors() : tensor(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&tensor)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(tensor)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared(); + + setCnnlTensor(d->tensor, dataType, slice(&size, 1)); + + auto cnnlUnaryForward = [this](cnnlHandle_t handle, + const cnnlTensorDescriptor_t x_desc, + const void *x, + const cnnlTensorDescriptor_t y_desc, + void *y) -> cnnlStatus_t { + switch (this->type) { + case Ty::Abs: + return cnnlAbs(handle, x_desc, x, y_desc, y); + case Ty::Neg: + return cnnlNegTensor(handle, x_desc, x, y_desc, y); + case Ty::Sqrt: + return cnnlSqrt_v2(handle, CNNL_COMPUTATION_HIGH_PRECISION, x_desc, x, y_desc, y); + default: + UNREACHABLE(); + } + }; + + res.fetchOrStore(); + return [d = std::move(d), cnnlUnaryForward]// + (Resources & res, void *, void const *const *inputs, void *const *outputs) { + CNNL_ASSERT(cnnlUnaryForward( + res.fetchOrStore()->handle, + d->tensor, inputs[0], + d->tensor, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh new file mode 100644 index 000000000..b69902f7b --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_SIMPLE_UNARY_CNNL_KERNEL_HH +#define KERNEL_SIMPLE_UNARY_CNNL_KERNEL_HH + +#include "kernel/collectors/simple_unary.h" + +namespace refactor::kernel { + + struct SimpleUnaryCnnl final : public Kernel { + SimpleUnaryType type; + DataType dataType; + int size; + + SimpleUnaryCnnl(SimpleUnaryType, DataType, int) noexcept; + + static KernelBox build(SimpleUnaryType, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SIMPLE_UNARY_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/softmax/cnnl_kernel.cc b/src/04kernel/src/kernels/softmax/cnnl_kernel.cc new file mode 100644 index 000000000..865e452e1 --- /dev/null +++ b/src/04kernel/src/kernels/softmax/cnnl_kernel.cc @@ -0,0 +1,86 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = SoftmaxCnnl; + + K::SoftmaxCnnl(cnnl::SoftmaxAlgo algo_, DataType type_, + int pre_, int mid_, int post_) noexcept + : Kernel(), algo(algo_), dataType(type_), + pre(pre_), mid(mid_), post(post_) {} + + auto K::build(cnnl::SoftmaxAlgo algo, SoftmaxInfo info) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif 
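+        // SoftmaxInfo flattens the input to (pre, mid, post) around the softmax
+        // axis; lower() maps that 3-D view onto CNNL's softmax dimension modes.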
+ + return std::make_unique(algo, info.type, info.pre, info.mid, info.post); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing softmax forward with CNNL"; + } + +#ifdef USE_BANG + + auto SoftmaxCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t t; + cnnlSoftmaxAlgorithm_t algo; + bool f32; + + Descriptors(decltype(algo) algo_, decltype(f32) f32_) + : algo(algo_), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&t)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(t)); + } + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + + auto d = std::make_shared( + static_cast(algo), + dataType != DataType::F64); + int dims[]{pre, mid, post}; + cnnlSoftmaxMode_t mode = (post == 1) ? CNNL_SOFTMAX_MODE_HIGH_DIMENSION + : (pre == 1) ? CNNL_SOFTMAX_MODE_LOW_DIMENSION + : CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION; + + // cnnlSoftmaxForward_v2 is applied to a 3D input tensor only + CNNL_ASSERT(cnnlSetTensorDescriptor(d->t, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), 3, dims)); + + res.fetchOrStore(); + return [d = std::move(d), mode](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // build alpha/beta for double + auto a = d->f32 ? factor(1) : factor(1), + b = d->f32 ? factor(0) : factor(0); + CNNL_ASSERT(cnnlSoftmaxForward_v2( + res.fetchOrStore()->handle, + d->algo, + mode, + CNNL_COMPUTATION_ULTRAHIGH_PRECISION, + &a, d->t, inputs[0], + &b, d->t, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/softmax/cnnl_kernel.hh b/src/04kernel/src/kernels/softmax/cnnl_kernel.hh new file mode 100644 index 000000000..b9bedb5a4 --- /dev/null +++ b/src/04kernel/src/kernels/softmax/cnnl_kernel.hh @@ -0,0 +1,36 @@ +#ifndef KERNEL_SOFTMAX_CNNL_HH +#define KERNEL_SOFTMAX_CNNL_HH + +#include "kernel/attributes/softmax_info.h" +#include "kernel/collectors/softmax.h" + +namespace refactor::kernel { + + namespace cnnl { + enum class SoftmaxAlgo { + FAST = 0, + ACCURATE = 1, + LOG = 2, + }; + }// namespace cnnl + + struct SoftmaxCnnl final : public Kernel { + cnnl::SoftmaxAlgo algo; + DataType dataType; + int pre, mid, post; + + SoftmaxCnnl(cnnl::SoftmaxAlgo, DataType, int, int, int) noexcept; + + static KernelBox build(cnnl::SoftmaxAlgo, SoftmaxInfo) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SOFTMAX_CNNL_HH diff --git a/src/04kernel/src/kernels/transpose/cnnl_kernel.cc b/src/04kernel/src/kernels/transpose/cnnl_kernel.cc new file mode 100644 index 000000000..26a0bef1b --- /dev/null +++ b/src/04kernel/src/kernels/transpose/cnnl_kernel.cc @@ -0,0 +1,92 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include +#endif + +namespace refactor::kernel { + using K = TransposeCnnl; + using Info = TransposeInfo; + + K::TransposeCnnl(DataType dataType_, Shape dimIn_, Shape dimOut_, Permutation perm_) 
noexcept + : Kernel(), dataType(dataType_), dimIn(std::move(dimIn_)), + dimOut(std::move(dimOut_)), perm(std::move(perm_)) {} + + auto K::build(DataType dataType, Shape shape_, Permutation perm_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + Shape dimOut_; + for (uint32_t i = 0; i < shape_.size(); i++) { + dimOut_.push_back(shape_[perm_[i]]); + } + return std::make_unique(dataType, std::move(shape_), std::move(dimOut_), std::move(perm_)); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing transpose operation using CNNL"; + } + +#ifdef USE_BANG + auto TransposeCnnl::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + struct Descriptors { + cnnlTensorDescriptor_t x, y; + cnnlTransposeDescriptor_t trans; + bool f32; + + explicit Descriptors(decltype(f32) f32_) + : x(nullptr), y(nullptr), trans(nullptr), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&trans)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(trans)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + + auto d = std::make_shared(dataType != DT::F64); + setCnnlTensor(d->x, dataType, slice((int *)(dimIn.data()), dimIn.size())); + setCnnlTensor(d->y, dataType, slice((int *)(dimOut.data()), dimOut.size())); + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->trans, perm.size(), (int *)perm.data())); + + auto handle = res.fetchOrStore()->handle; + size_t workspaceSize; + CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->x, d->trans, &workspaceSize)); + + res.fetchOrStore(); + auto routine = [d = std::move(d), workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore()->handle; + + // name inputs and outputs + auto x = inputs[0]; + auto y = outputs[0]; + + CNNL_ASSERT(cnnlTranspose_v2(handle, d->trans, d->x, x, + d->y, y, workspace, workspaceSize)); + }; + + return {std::move(routine), workspaceSize}; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/transpose/cnnl_kernel.hh b/src/04kernel/src/kernels/transpose/cnnl_kernel.hh new file mode 100644 index 000000000..37bb7e088 --- /dev/null +++ b/src/04kernel/src/kernels/transpose/cnnl_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_TRANSPOSE_CNNL_KERNEL_HH +#define KERNEL_TRANSPOSE_CNNL_KERNEL_HH + +#include "kernel/collectors/transpose.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + using Shape = absl::InlinedVector; + using Permutation = Shape; + + struct TransposeCnnl final : public Kernel { + DataType dataType; + Shape dimIn; + Shape dimOut; + Permutation perm; + + TransposeCnnl(DataType, Shape, Shape, Permutation) noexcept; + + static KernelBox build(DataType, Shape, Permutation) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace 
refactor::kernel + +#endif// KERNEL_TRANSPOSE_CNNL_KERNEL_HH diff --git a/src/04kernel/src/utilities/bang/cnnl_context.cc b/src/04kernel/src/utilities/bang/cnnl_context.cc new file mode 100644 index 000000000..15cc13829 --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_context.cc @@ -0,0 +1,35 @@ +#ifdef USE_BANG + +#include "cnnl_context.hh" +#include "cnnl_functions.h" + +namespace refactor::kernel::cnnl { + + CnnlContext::CnnlContext() : runtime::Resource() { + BANG_ASSERT(cnrtQueueCreate(&queue)); + CNNL_ASSERT(cnnlCreate(&handle)); + CNNL_ASSERT(cnnlSetQueue(handle, queue)); + } + CnnlContext::~CnnlContext() { + BANG_ASSERT(cnrtQueueDestroy(queue)); + CNNL_ASSERT(cnnlDestroy(handle)); + } + + auto CnnlContext::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto CnnlContext::build() -> runtime::ResourceBox { + return std::make_unique(); + } + + auto CnnlContext::resourceTypeId() const noexcept -> size_t { + return typeId(); + } + auto CnnlContext::description() const noexcept -> std::string_view { + return "CnnlContext"; + } + +}// namespace refactor::kernel::cnnl + +#endif diff --git a/src/04kernel/src/utilities/bang/cnnl_context.hh b/src/04kernel/src/utilities/bang/cnnl_context.hh new file mode 100644 index 000000000..7db40d3d3 --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_context.hh @@ -0,0 +1,29 @@ +#ifndef KERNEL_CNNL_CONTEXT_HH +#define KERNEL_CNNL_CONTEXT_HH + +#include "runtime/resource.h" +#include +#include + +namespace refactor::kernel::cnnl { + + struct CnnlContext final : public runtime::Resource { + cnnlHandle_t handle; + cnrtQueue_t queue; + + CnnlContext(); + ~CnnlContext(); + CnnlContext(CnnlContext const &) noexcept = delete; + CnnlContext(CnnlContext &&) noexcept = delete; + + static size_t typeId() noexcept; + static runtime::ResourceBox build(); + + size_t resourceTypeId() const noexcept final; + std::string_view description() const noexcept final; + + }; + +}// namespace refactor::kernel::cnnl + +#endif// KERNEL_CNNL_CONTEXT_HH diff --git a/src/04kernel/src/utilities/bang/cnnl_functions.cpp b/src/04kernel/src/utilities/bang/cnnl_functions.cpp new file mode 100644 index 000000000..8dfeb6457 --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_functions.cpp @@ -0,0 +1,38 @@ +#ifdef USE_BANG + +#include "cnnl_functions.h" + +namespace refactor::kernel::cnnl { + + cnnlDataType_t cnnlDataTypeConvert(DataType dataType) { + // clang-format off + switch (dataType) { + case DataType::F32 : return CNNL_DTYPE_FLOAT; break; + case DataType::F64 : return CNNL_DTYPE_DOUBLE; break; + case DataType::FP16: return CNNL_DTYPE_HALF; break; + case DataType::I8 : return CNNL_DTYPE_INT8; break; + case DataType::I32 : return CNNL_DTYPE_INT32; break; + case DataType::U8 : return CNNL_DTYPE_UINT8; break; + case DataType::BF16: return CNNL_DTYPE_BFLOAT16; break; + case DataType::I64 : return CNNL_DTYPE_INT64; break; + case DataType::Bool: return CNNL_DTYPE_BOOL; break; + default: UNREACHABLE(); + } + // clang-format on + } + + void setCnnlTensor(cnnlTensorDescriptor_t t, DataType dt, slice_t d) { + auto dt_ = cnnlDataTypeConvert(dt); + if (auto n = d.size(); n == 4) { + CNNL_ASSERT(cnnlSetTensorDescriptor(t, CNNL_LAYOUT_NCHW, dt_, d.size(), d.begin())); + } else if (n < 4) { + int d_[]{1, 1, 1, 1}; + std::copy_n(d.begin(), n, d_ + 4 - n); + CNNL_ASSERT(cnnlSetTensorDescriptor(t, CNNL_LAYOUT_NCHW, dt_, 4, std::move(d_))); + } else { + CNNL_ASSERT(cnnlSetTensorDescriptor(t, CNNL_LAYOUT_NCHW, dt_, d.size(), d.begin())); + } + 
+    }
+}// namespace refactor::kernel::cnnl
+
+#endif
diff --git a/src/04kernel/src/utilities/bang/cnnl_functions.h b/src/04kernel/src/utilities/bang/cnnl_functions.h
new file mode 100644
index 000000000..4ba2f89d7
--- /dev/null
+++ b/src/04kernel/src/utilities/bang/cnnl_functions.h
@@ -0,0 +1,40 @@
+#ifndef KERNEL_CNNL_FUNCTIONS_H
+#define KERNEL_CNNL_FUNCTIONS_H
+
+#include "common.h"
+#include <cnnl.h>
+
+#define BANG_ASSERT(STATUS)                                                          \
+    if (auto status = (STATUS); status != CNRT_RET_SUCCESS) {                        \
+        RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cnrtGetErrorStr(status), (int) status));           \
+    }
+
+#define CNNL_ASSERT(STATUS)                                      \
+    if (auto status = (STATUS); status != CNNL_STATUS_SUCCESS) { \
+        fmt::println("cnnl failed on \"" #STATUS "\" with {}",   \
+                     cnnlGetErrorString(status));                \
+        abort();                                                 \
+    }
+
+namespace refactor::kernel::cnnl {
+
+    cnnlDataType_t cnnlDataTypeConvert(DataType);
+
+    // A helper function that sets a CNNL tensor descriptor given the tensor's shape and data type
+    void setCnnlTensor(cnnlTensorDescriptor_t, DataType, slice_t<int>);
+
+    // Bit-casts a floating-point value to its unsigned integer representation
+    template<class T>
+    constexpr uint64_t factor(T x) noexcept {
+        static_assert(std::is_floating_point_v<T>);
+        static_assert(sizeof(T) <= sizeof(uint64_t));
+        union {
+            T f;
+            uint64_t i;
+        } u{x};
+        return u.i;
+    }
+
+}// namespace refactor::kernel::cnnl
+
+#endif// KERNEL_CNNL_FUNCTIONS_H
diff --git a/src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp b/src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp
new file mode 100644
index 000000000..14a1a07f4
--- /dev/null
+++ b/src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp
@@ -0,0 +1,70 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/batch_normalization/cnnl_kernel.hh"
+#include "../../../src/kernels/batch_normalization/cpu_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, BatchNormalizationCnnl) {
+    // build routine
+    auto xTensor = Tensor::share(DataType::F32, Shape{1, 2, 3, 2});
+    auto outTensor = Tensor::share(DataType::F32, Shape{1, 2, 3, 2});
+    auto scaleTensor = Tensor::share(DataType::F32, Shape{2});
+    auto biasTensor = Tensor::share(DataType::F32, Shape{2});
+    auto meanTensor = Tensor::share(DataType::F32, Shape{2});
+    auto varTensor = Tensor::share(DataType::F32, Shape{2});
+    float epsilon = 1e-5f;
+    TensorRefs inputs = TensorRefs{*xTensor, *scaleTensor, *biasTensor, *meanTensor, *varTensor};
+    auto kCpu = BatchNormalization::build(epsilon, inputs);
+    auto kCnnl = BatchNormalizationCnnl::build(epsilon, inputs);
+    ASSERT_TRUE(kCpu && kCnnl);
+    auto res = runtime::Resources();
+    auto rCpu = kCpu->lower(res).routine;
+    auto [rMlu, workspaceSize] = kCnnl->lower(res);
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto workspace = dev.malloc(workspaceSize),
+         mluIn = dev.malloc(xTensor->bytesSize()),
+         mluScale = dev.malloc(scaleTensor->bytesSize()),
+         mluBias = dev.malloc(biasTensor->bytesSize()),
+         mluMean = dev.malloc(meanTensor->bytesSize()),
+         mluVar = dev.malloc(varTensor->bytesSize()),
+         mluOut = dev.malloc(outTensor->bytesSize());
+    // put input data
+    std::vector<float>
+        data(xTensor->elementsSize(), 1.0f),
+        scale(scaleTensor->elementsSize(), 0.5f),
+        bias(biasTensor->elementsSize(), 1.0f),
+        mean(meanTensor->elementsSize(), 0.5f),
+        var(varTensor->elementsSize(), 1.0f),
+        cpuOut(outTensor->elementsSize());
+    mluIn->copyFromHost(data.data(), xTensor->bytesSize());
+    mluScale->copyFromHost(scale.data(), scaleTensor->bytesSize());
+    mluBias->copyFromHost(bias.data(), biasTensor->bytesSize());
+    mluMean->copyFromHost(mean.data(), meanTensor->bytesSize());
+    mluVar->copyFromHost(var.data(), varTensor->bytesSize());
+    // inference
+    {
+        void const *inputs[]{data.data(), scale.data(), bias.data(), mean.data(), var.data()};
+        void *outputs[]{cpuOut.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{*mluIn, *mluScale, *mluBias, *mluMean, *mluVar};
+        void *outputs[]{*mluOut};
+        rMlu(res, *workspace, inputs, outputs);
+    }
+    // take output data
+    std::vector<float> result(outTensor->elementsSize());
+    mluOut->copyToHost(result.data(), outTensor->bytesSize());
+    // check
+    for (auto i : range0_(result.size())) {
+        EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+    }
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/pool/test_cnnl.cpp b/src/04kernel/test/kernels/pool/test_cnnl.cpp
new file mode 100644
index 000000000..405bf3f8c
--- /dev/null
+++ b/src/04kernel/test/kernels/pool/test_cnnl.cpp
@@ -0,0 +1,70 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/pool/cnnl_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+void testPoolCnnl(PoolType poolType, int rank, const int64_t *pads, const int64_t *strides, KernelShape kernelShape, Shape xShape, Shape yShape, const std::vector<float> &ExpectData) {
+    auto dataTensor = Tensor::share(DataType::F32, xShape);
+    auto yTensor = Tensor::share(DataType::F32, yShape);
+    bool ceil = true;
+    int64_t const dilations[] = {1, 1};
+    PoolAttributes poolAttributes(rank, dilations, pads, strides);
+
+    auto kernel = PoolCnnl::build(poolType, ceil, kernelShape, poolAttributes, *dataTensor, *yTensor);
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto [routine, workspaceSize] = kernel->lower(res);
+    // bang malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto workspace = dev.malloc(workspaceSize),
+         mluMem = dev.malloc(dataTensor->bytesSize());
+    // put input data
+    std::vector<float> data(dataTensor->elementsSize());
+    for (auto i : range0_(data.size())) { data[i] = i * 0.1f; }
+    mluMem->copyFromHost(data.data(), dataTensor->bytesSize());
+    // inference; the (smaller) output is written in place over the input buffer
+    void const *inputs[]{*mluMem};
+    void *outputs[]{*mluMem};
+    routine(res, *workspace, inputs, outputs);
+    // take output data
+    std::vector<float> result(yTensor->elementsSize());
+    mluMem->copyToHost(result.data(), yTensor->bytesSize());
+    // check
+    for (auto i : range0_(ExpectData.size())) {
+        EXPECT_FLOAT_EQ(ExpectData[i], result[i]);
+    }
+}
+
+TEST(kernel, PoolCnnlMax) {
+    int rank = 2;
+    int64_t const
+        pads[]{0, 0, 0, 0},
+        strides[]{2, 2};
+    KernelShape kernelShape{2, 2};
+    Shape
+        xShape{1, 1, 4, 4},
+        yShape{1, 1, 2, 2};
+    const std::vector<float> ExpectData = {0.5, 0.7, 1.3, 1.5};
+    testPoolCnnl(PoolType::Max, rank, pads, strides, kernelShape, xShape, yShape, ExpectData);
+}
+
+TEST(kernel, PoolCnnlAvg) {
+    int rank = 2;
+    int64_t const
+        pads[]{0, 0, 0, 0},
+        strides[]{2, 2};
+    KernelShape kernelShape{2, 2};
+    Shape
+        xShape{1, 1, 4, 4},
+        yShape{1, 1, 2, 2};
+    const std::vector<float> ExpectData = {0.25, 0.45, 1.05, 1.25};
+    testPoolCnnl(PoolType::Average, rank, pads, strides, kernelShape, xShape, yShape, ExpectData);
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/reduce/test_cnnl.cpp b/src/04kernel/test/kernels/reduce/test_cnnl.cpp
new file mode 100644
index 000000000..32952fead
--- /dev/null
+++ b/src/04kernel/test/kernels/reduce/test_cnnl.cpp
@@ -0,0 +1,64 @@
+#ifdef USE_BANG
+
"../../../src/kernels/reduce/cnnl_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +static void testReducemean(const Shape &shape, const std::vector &data, + Axes axes, const std::vector ExpectData) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, shape); + auto kernel = ReduceCnnl::build(axes, ReduceType::Mean, {*dataTensor}); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // bang malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + mluMemIn = dev.malloc(dataTensor->bytesSize()), + mluMemOut = dev.malloc(dataTensor->bytesSize()); + // put input output data + mluMemIn->copyFromHost(data.data(), dataTensor->bytesSize()); + // inference + { + void const *inputs[]{*mluMemIn}; + void *outputs[]{*mluMemOut}; + routine(res, *workspace, inputs, outputs); + } + // take output data + Shape outDimArray; + std::unordered_set axesSet(axes.begin(), axes.end()); + for (size_t i = 0; i < shape.size(); ++i) { + if (axesSet.contains(i)) { + outDimArray.push_back(shape[i]); + } + } + auto outputTensor = Tensor::share(DataType::F32, outDimArray); + std::vector result(outDimArray.size()); + mluMemOut->copyToHost(result.data(), outputTensor->bytesSize()); + // check + for (auto i : range0_(ExpectData.size())) { + EXPECT_FLOAT_EQ(ExpectData[i], result[i]); + } +} + +TEST(kernel, ReduceMeanCnnl) { + testReducemean({2, 3, 2, 2}, + {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {5, 6, 17, 18}); + testReducemean({2, 3, 2, 2, 1}, + {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {5, 6, 17, 18}); +} + +#endif diff --git a/src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp b/src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp new file mode 100644 index 000000000..4ef7c6d23 --- /dev/null +++ b/src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp @@ -0,0 +1,90 @@ +#ifdef USE_BANG + +#include "../src/kernels/simple_binary/binary_cnnl.hh" +#include "../src/kernels/simple_binary/cpu_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +void testBinaryCnnl(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) { + // Create Tensor and build kernels + auto aTensor = Tensor::share(DataType::F32, dimA, LayoutType::NCHW); + auto bTensor = Tensor::share(DataType::F32, dimB, LayoutType::NCHW); + auto cTensor = Tensor::share(DataType::F32, dimC, LayoutType::NCHW); + auto kernel = BinaryCnnl::build(binaryOPT, *aTensor, *bTensor, *cTensor); + auto kCpu = BinaryCpu::build(binaryOPT, *aTensor, *bTensor); + ASSERT_TRUE(kCpu && kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + auto rCpu = kCpu->lower(res).routine; + // Init inputs and outputs + std::vector + a(aTensor->elementsSize(), 3.0f), + b(bTensor->elementsSize(), 2.0f), + c(cTensor->elementsSize()); + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + aMLU = dev.malloc(aTensor->bytesSize()), + bMLU = dev.malloc(bTensor->bytesSize()), + cMLU = dev.malloc(cTensor->bytesSize()); + aMLU->copyFromHost(a.data(), aTensor->bytesSize()); + bMLU->copyFromHost(b.data(), bTensor->bytesSize()); + // Compute + { + void const 
+        void const *inputs[]{*aMLU, *bMLU};
+        void *outputs[]{*cMLU};
+        routine(res, *workspace, inputs, outputs);
+    }
+    {
+        void const *inputs[]{a.data(), b.data()};
+        void *outputs[]{c.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // Compare
+    std::vector<float> result(cTensor->elementsSize());
+    cMLU->copyToHost(result.data(), cTensor->bytesSize());
+    for (auto i : range0_(result.size())) {
+        EXPECT_FLOAT_EQ(c[i], result[i]);
+    }
+}
+
+TEST(kernel, BinaryCnnlAdd) {
+    testBinaryCnnl(SimpleBinaryType::Add, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+}
+
+TEST(kernel, BinaryCnnlMul) {
+    testBinaryCnnl(SimpleBinaryType::Mul, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+}
+
+TEST(kernel, BinaryCnnlSub) {
+    testBinaryCnnl(SimpleBinaryType::Sub, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+}
+
+TEST(kernel, BinaryCnnlDiv) {
+    testBinaryCnnl(SimpleBinaryType::Div, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+}
+
+// TEST(kernel, BinaryCnnlAnd) {
+//     testBinaryCnnl(SimpleBinaryType::And, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+// }
+
+// TEST(kernel, BinaryCnnlOr) {
+//     testBinaryCnnl(SimpleBinaryType::Or, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+// }
+
+// TEST(kernel, BinaryCnnlXor) {
+//     testBinaryCnnl(SimpleBinaryType::Xor, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+// }
+
+TEST(kernel, BinaryCnnlPow) {
+    testBinaryCnnl(SimpleBinaryType::Pow, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
+}
+
+TEST(kernel, BinaryCnnlBroadcast) {
+    testBinaryCnnl(SimpleBinaryType::Add, Shape{3, 4, 5, 6}, Shape{}, Shape{3, 4, 5, 6});
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/simple_unary/test_cnnl.cpp b/src/04kernel/test/kernels/simple_unary/test_cnnl.cpp
new file mode 100644
index 000000000..2707e6274
--- /dev/null
+++ b/src/04kernel/test/kernels/simple_unary/test_cnnl.cpp
@@ -0,0 +1,63 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/simple_unary/cnnl_activation_kernel.hh"
+#include "../../../src/kernels/simple_unary/cnnl_simple_unary_kernel.hh"
+#include "../../../src/kernels/simple_unary/cpu_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+static void testOp(SimpleUnaryType opType, bool activation = true) {
+    // build routine
+    auto dataTensor = Tensor::share(DataType::F32, Shape{20, 30, 50});
+    auto kernel = activation ? ActivationCnnl::build(opType, *dataTensor)
+                             : SimpleUnaryCnnl::build(opType, *dataTensor);
+    auto kCpu = SimpleUnaryCpu::build(opType, *dataTensor);
+    ASSERT_TRUE(kernel && kCpu);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine,
+         rCpu = kCpu->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto mluMem = dev.malloc(dataTensor->bytesSize());
+    // put input data
+    std::vector<float> data(dataTensor->elementsSize());
+    for (auto i : range0_(data.size())) { data[i] = i * 1e-4f; }
+    mluMem->copyFromHost(data.data(), dataTensor->bytesSize());
+    // inference, in place on both devices
+    {
+        void const *inputs[]{*mluMem};
+        void *outputs[]{*mluMem};
+        routine(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{data.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // take output data
+    std::vector<float> result(dataTensor->elementsSize());
+    mluMem->copyToHost(result.data(), dataTensor->bytesSize());
+    // check
+    for (auto i : range0_(data.size())) {
+        EXPECT_NEAR(data[i], result[i], 1e-4);
+    }
+}
+
+TEST(kernel, SimpleUnaryCnnl) {
+    testOp(SimpleUnaryType::Abs, false);
+    testOp(SimpleUnaryType::Neg, false);
+    testOp(SimpleUnaryType::Sqrt, false);
+}
+
+TEST(kernel, ActivationCnnl) {
+    testOp(SimpleUnaryType::Relu);
+    testOp(SimpleUnaryType::Sigmoid);
+    testOp(SimpleUnaryType::Tanh);
+}
+
+#endif// USE_BANG
diff --git a/src/04kernel/test/kernels/softmax/test_cnnl.cpp b/src/04kernel/test/kernels/softmax/test_cnnl.cpp
new file mode 100644
index 000000000..a8c7fb283
--- /dev/null
+++ b/src/04kernel/test/kernels/softmax/test_cnnl.cpp
@@ -0,0 +1,52 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/softmax/cnnl_kernel.hh"
+#include "../../../src/kernels/softmax/cpu_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, SoftmaxCnnl) {
+    // build routine
+    auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
+    auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
+    dim_t axis = 2;
+    auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis));
+    auto kCnnl = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::FAST, SoftmaxInfo(*xTensor, axis));
+    ASSERT_TRUE(kCpu && kCnnl);
+    auto res = runtime::Resources();
+    auto rCpu = kCpu->lower(res).routine;
+    auto rCnnl = kCnnl->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto mluIn = dev.malloc(xTensor->bytesSize()),
+         mluOut = dev.malloc(outTensor->bytesSize());
+    // put input data
+    std::vector<float>
+        data(xTensor->elementsSize(), 0),
+        cpuOut(outTensor->elementsSize());
+    mluIn->copyFromHost(data.data(), xTensor->bytesSize());
+    // inference
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{cpuOut.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{*mluIn};
+        void *outputs[]{*mluOut};
+        rCnnl(res, nullptr, inputs, outputs);
+    }
+    // take output data
+    std::vector<float> result(outTensor->elementsSize());
+    mluOut->copyToHost(result.data(), outTensor->bytesSize());
+    // check
+    for (auto i : range0_(result.size())) {
+        EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+    }
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/transpose/test_cnnl.cpp b/src/04kernel/test/kernels/transpose/test_cnnl.cpp
new file mode 100644
index 000000000..4f4301d86
--- /dev/null
+++ b/src/04kernel/test/kernels/transpose/test_cnnl.cpp
@@ -0,0 +1,55 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/transpose/cnnl_kernel.hh"
+#include "../../../src/kernels/transpose/cpu_kernel.hh" +#include "hardware/device_manager.h" +#include +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, TransposeCnnl) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{1, 3, 2, 5}); + auto info = TransposeInfo(dataTensor->shape, Permutation{2, 3, 0, 1}); + auto kCpu = TransposeCpu::build(dataTensor->dataType, info); + auto kernel = TransposeCnnl::build(dataTensor->dataType, dataTensor->shape, Permutation{2, 3, 0, 1}); + ASSERT_TRUE(kCpu && kernel); + auto res = runtime::Resources(); + auto rCpu = kCpu->lower(res).routine; + auto [routine, workspaceSize] = kernel->lower(res); + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto bytes = dataTensor->bytesSize(); + auto workspace = dev.malloc(workspaceSize), + mluIn = dev.malloc(bytes), + mluOut = dev.malloc(bytes); + // put input data + std::vector + cpuIn(dataTensor->elementsSize()), + cpuOut(cpuIn.size()); + std::iota(cpuIn.begin(), cpuIn.end(), 0); + mluIn->copyFromHost(cpuIn.data(), bytes); + // inference + { + void const *inputs[]{cpuIn.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{*mluIn}; + void *outputs[]{*mluOut}; + routine(res, *workspace, inputs, outputs); + } + // take output data + std::vector result(dataTensor->elementsSize()); + mluOut->copyToHost(result.data(), bytes); + // check + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } +} + +#endif