[ GPU/OpenCL ] change swiglu_cl to inherit LayerImplCl
- This commit updates swiglu_cl.cpp/h to inherit LayerImplCl.
- This commit implements registerClKernels() of swiglu_cl layer.
- This commit updates cl_context.cpp (applying swiglu_cl's update).

Self evaluation:
Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <[email protected]>
Showing 3 changed files with 69 additions and 29 deletions.
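Only the swiglu_cl.cpp and swiglu_cl.h diffs are reproduced below; the cl_context.cpp hunk is not shown. Going by the commit message, that update presumably switches the SwiGLU kernel setup over to the new per-layer hook. A minimal sketch of what such a call site could look like follows; everything except SwiGLULayerCl::registerClKernels() is an assumption, not the actual cl_context.cpp code.

// Hypothetical call site only -- the real cl_context.cpp hunk is not shown in
// this view. Assumption: context setup now asks the layer to register its own
// kernels instead of creating raw opencl::Kernel objects for it.
#include <cl_context.h>
#include <nntrainer_log.h>
#include <swiglu_cl.h>

namespace nntrainer {

static void registerSwigluClKernels() { /* hypothetical helper name */
  // SwiGLULayerCl now owns kernel creation: it compiles "swiglu_cl" (and
  // "swiglu_cl_fp16" when ENABLE_FP16 is set) and caches the handles in its
  // static layer_kernel_ptrs vector.
  if (!SwiGLULayerCl::registerClKernels())
    ml_loge("Failed to register swiglu_cl kernels");
}

} // namespace nntrainer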
swiglu_cl.cpp
@@ -6,6 +6,7 @@
  * @brief Implementation of SwiGLU activation function
  * @see https://github.com/nnstreamer/nntrainer
  * @author Niket Agarwal <[email protected]>
+ * @author Eunju Yang <[email protected]>
  * @bug No known bugs except for NYI items
  *
  */
@@ -63,9 +64,6 @@ void SwiGLULayerCl::incremental_forwarding(RunLayerContext &context,
   swigluProcess(in1, in2, out);
 }
 
-opencl::Kernel SwiGLULayerCl::kernel_swiglu;
-opencl::Kernel SwiGLULayerCl::kernel_swiglu_fp16;
-
 void SwiGLULayerCl::swigluProcess(Tensor const &in1, Tensor const &in2,
                                   Tensor &result) {
 
@@ -90,18 +88,54 @@ void SwiGLULayerCl::swigluProcess(Tensor const &in1, Tensor const &in2,
   }
 }
 
-void SwiGLULayerCl::swiglu_cl(const float *matAdata, const float *vecXdata,
-                              float *vecYdata, unsigned int dim1,
-                              unsigned int dim2) {
+bool SwiGLULayerCl::registerClKernels() {
 
-  bool result = false;
+  // check if the kernels are already registered.
+  if (!layer_kernel_ptrs.empty()) {
+    ml_loge("kernels for swiglu_cl are already registered.");
+    return false;
+  }
 
   do {
-    ClContext::SharedPtrClKernel kernel_swiglu_ptr =
+
+    ClContext::SharedPtrClKernel kernel_swiglu_ptr = nullptr;
+
+    kernel_swiglu_ptr =
       cl_context_ref.registerClKernel(swiglu_cl_kernel_, "swiglu_cl");
     if (!kernel_swiglu_ptr) {
+      ml_loge("OpenCL Error: Fail to register swiglu_cl kernel");
       break;
     }
+    layer_kernel_ptrs.emplace_back(kernel_swiglu_ptr);
+    kernel_swiglu_ptr =
+      cl_context_ref.registerClKernel(swiglu_cl_kernel_fp16_, "swiglu_cl_fp16");
+
+#ifdef ENABLE_FP16
+    if (!kernel_swiglu_ptr) {
+      ml_loge("OpenCL Error: Fail to register swiglu_cl_fp16 kernel");
+      break;
+    }
+    layer_kernel_ptrs.emplace_back(kernel_swiglu_ptr);
+#endif
+
+    return true;
+  } while (false);
+
+  // clear all registered kernels if any error occurs during registration
+  layer_kernel_ptrs.clear();
+
+  return false;
+}
+
+void SwiGLULayerCl::swiglu_cl(const float *matAdata, const float *vecXdata,
+                              float *vecYdata, unsigned int dim1,
+                              unsigned int dim2) {
+
+  bool result = false;
+
+  do {
+
+    auto kernel_swiglu_ptr = layer_kernel_ptrs[Kernels::SWIGLU_CL];
 
     int dim = int(dim1 * dim2);
     opencl::Buffer inputA(cl_context_ref.context_inst_,
@@ -160,18 +194,16 @@ void SwiGLULayerCl::swiglu_cl(const float *matAdata, const float *vecXdata,
   } while (false);
 }
 
+#ifdef ENABLE_FP16
 void SwiGLULayerCl::swiglu_cl_fp16(const __fp16 *matAdata,
                                    const __fp16 *vecXdata, __fp16 *vecYdata,
                                    unsigned int dim1, unsigned int dim2) {
 
   bool result = false;
 
   do {
-    ClContext::SharedPtrClKernel kernel_swiglu_ptr =
-      cl_context_ref.registerClKernel(swiglu_cl_kernel_fp16_, "swiglu_cl_fp16");
-    if (!kernel_swiglu_ptr) {
-      break;
-    }
+
+    auto kernel_swiglu_ptr = layer_kernel_ptrs[Kernels::SWIGLU_CL_FP16];
 
     int dim = int(dim1 * dim2);
     opencl::Buffer inputA(cl_context_ref.context_inst_,
@@ -229,6 +261,7 @@ void SwiGLULayerCl::swiglu_cl_fp16(const __fp16 *matAdata,
 
   } while (false);
 }
+#endif
 
 void SwiGLULayerCl::calcDerivative(nntrainer::RunLayerContext &context) {
   std::throw_with_nested(std::runtime_error("Training is not supported yet."));
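The registration code above follows a register-once pattern: the kernels are compiled a single time into the static layer_kernel_ptrs vector, a partial failure rolls the vector back to empty, and the compute paths later index it with the Kernels enum, so the emplace_back order has to match the enumerator order. A small standalone illustration of that idiom (not nntrainer code; std::string stands in for ClContext::SharedPtrClKernel):

// Standalone sketch of the register-once, enum-indexed kernel cache idiom
// used by registerClKernels() above. Not nntrainer code.
#include <iostream>
#include <string>
#include <vector>

enum Kernels { SWIGLU_CL, SWIGLU_CL_FP16 };

static std::vector<std::string> layer_kernel_ptrs;

// Stand-in for ClContext::registerClKernel(): returns an empty handle on failure.
static std::string register_kernel(const std::string &name, bool ok = true) {
  return ok ? name : std::string{};
}

static bool registerClKernels() {
  if (!layer_kernel_ptrs.empty())
    return false; // already registered

  do {
    // Push kernels in the same order as the Kernels enumerators so that
    // layer_kernel_ptrs[SWIGLU_CL] / [SWIGLU_CL_FP16] are valid lookups later.
    std::string k = register_kernel("swiglu_cl");
    if (k.empty())
      break;
    layer_kernel_ptrs.emplace_back(k);

    k = register_kernel("swiglu_cl_fp16");
    if (k.empty())
      break;
    layer_kernel_ptrs.emplace_back(k);

    return true;
  } while (false);

  // Only reached when a step above breaks out: roll back so nothing is
  // left half-registered.
  layer_kernel_ptrs.clear();
  return false;
}

int main() {
  if (registerClKernels())
    std::cout << layer_kernel_ptrs[Kernels::SWIGLU_CL] << '\n'; // prints "swiglu_cl"
}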
swiglu_cl.h
@@ -7,6 +7,7 @@
  * @brief Implementation of SwiGLU activation function
  * @see https://github.com/nnstreamer/nntrainer
  * @author Niket Agarwal <[email protected]>
+ * @author Eunju Yang <[email protected]>
  * @bug No known bugs except for NYI items
  *
  */
@@ -18,7 +19,7 @@
 #include <common_properties.h>
 #include <layer_context.h>
 #include <layer_devel.h>
-#include <layer_impl.h>
+#include <layer_impl_cl.h>
 #include <node_exporter.h>
 #include <opencl_buffer.h>
 #include <opencl_kernel.h>
@@ -30,17 +31,14 @@ namespace nntrainer {
  * @brief A SwiGLU layer
  *
  */
-class SwiGLULayerCl final : public Layer {
-
-private:
-  inline static ClContext cl_context_ref;
+class SwiGLULayerCl final : public LayerImplCl {
 
 public:
   /**
    * @brief Construct a new SwiGLU layer object
    *
    */
-  SwiGLULayerCl() : Layer(), swiglu_props(props::Print()) {}
+  SwiGLULayerCl() : LayerImplCl(), swiglu_props(props::Print()) {}
 
   /**
    * @brief Destroy the SwiGLU layer
@@ -79,7 +77,7 @@ class SwiGLULayerCl final : public Layer {
    * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
    */
   void exportTo(Exporter &exporter,
-                const ml::train::ExportMethods &method) const override {};
+                const ml::train::ExportMethods &method) const override{};
 
   /**
    * @copydoc Layer::getType()
@@ -93,12 +91,6 @@ class SwiGLULayerCl final : public Layer {
 
   inline static const std::string type = "swiglu";
 
-  static opencl::Kernel kernel_swiglu;
-  static opencl::Kernel kernel_swiglu_fp16;
-
-  std::tuple<props::Print> swiglu_props; /**< swiglu layer properties : unit -
-                                            number of output neurons */
-
   /**
    * @brief Process data and dimensions for swiglu operation
    * @param[in] input1 Tensor
@@ -130,6 +122,20 @@ class SwiGLULayerCl final : public Layer {
   void swiglu_cl_fp16(const __fp16 *matAdata, const __fp16 *vecXdata,
                       __fp16 *vecYdata, unsigned int dim1, unsigned int dim2);
 #endif
+
+  /**
+   * @brief Register OpenCL kernels for SwiGLU layer. This should be called
+   */
+  static bool registerClKernels();
+
+private:
+  std::tuple<props::Print> swiglu_props; /**< swiglu layer properties : unit -
+                                            number of output neurons */
+
+  inline static std::vector<ClContext::SharedPtrClKernel>
+    layer_kernel_ptrs; /** kernel list relevant with this layer */
+
+  enum Kernels { SWIGLU_CL, SWIGLU_CL_FP16 }; /** kernels enum */
 };
 
 } // namespace nntrainer
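With this change the header no longer declares its own static ClContext; cl_context_ref is expected to come from the LayerImplCl base brought in via layer_impl_cl.h, while the layer itself contributes the static registerClKernels() hook and its private kernel cache. layer_impl_cl.h is not part of this diff, so the following is only a hedged sketch of the base-class contract the header now assumes, not its actual definition.

// Hedged sketch of the contract swiglu_cl.h now relies on. The real
// layer_impl_cl.h is not shown in this commit; the base class of LayerImplCl
// and everything except cl_context_ref are assumptions.
#include <cl_context.h>
#include <layer_impl.h>

namespace nntrainer {

class LayerImplCl : public LayerImpl { /* assumed base class */
protected:
  /** Shared OpenCL context for GPU layers. SwiGLULayerCl previously kept its
   *  own `inline static ClContext cl_context_ref;` (removed above) and now
   *  reaches this one through inheritance. Each derived layer still provides
   *  its own `static bool registerClKernels();`, as the header above does. */
  inline static ClContext cl_context_ref;
};

} // namespace nntrainer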