// From f8c25e8a75e30503ffe0ad4aac32801b40791a3e Mon Sep 17 00:00:00 2001
// From: Nicolas <344493+haricot@users.noreply.github.com>
// Date: Thu, 16 Jan 2025 22:22:12 +0100
// Subject: [PATCH] refactor sigmoid & ctx if >=800 or (>=530 && <800)
//
// NOTE(review): this chunk arrived as a whitespace-mangled git format-patch
// for candle-kernels/src/unary.cu (newlines collapsed, template angle-bracket
// arguments stripped by the mangling). Reconstructed below is the post-patch
// state of the region this part of the patch covers, with the stripped tokens
// restored. The per-type sigmoid machinery deleted by this hunk (exp_bf16,
// exp_halft and the old bf16/half specializations) is intentionally not
// reproduced here — the patch moves it into an arch-guarded block further
// down the file.

// Generic sigmoid forward pass: sigmoid(x) = 1 / (1 + e^-x).
//
// BUG FIX: the patch's generic body was
//     return x / (static_cast<T>(1) + expg(-x));
// which is the SiLU formula (x * sigmoid(x)) — identical to silu_fwd shown in
// this hunk's context — not sigmoid. The per-type specializations this
// template replaces all compute recipg(1 + exp(-x)), i.e. true sigmoid, so
// the generic template must match them.
template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x) {
    return recipg(static_cast<T>(1) + expg(-x));
}
// Sign function: +1 for t > 0, -1 for t < 0, 0 for t == 0.
template<typename T>
__device__ T sign_(T t) {
    return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}

// NOTE(review): the UNARY_OP1 macro definition sits between sigmoid_fwd and
// sign_ in the real file, but only its first line appears as patch context in
// this chunk, so its body is not reproduced here — confirm against the full
// file before applying.

// bf16 unary kernels: require SM53+ for __nv_bfloat16 arithmetic.
// Simplified from the patch's ">= 800 || (>= 530 && < 800)", which is
// logically identical to ">= 530".
#if __CUDA_ARCH__ >= 530
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
// ... (the bf16 ops between urecip_bf16 and uelu_bf16 are elided from this
//      patch chunk's context; they are unchanged by the patch) ...
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif

// BUG FIX: the patch guarded the SM80+ sigmoid with
//     #if __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 530
// which can never be true (no arch is both >= 800 and < 530), so
// usigmoid_bf16 was never compiled for SM80+ devices. The intended condition
// — per the patch subject ">=800 or (>=530 && <800)" — is simply >= 800,
// where the generic sigmoid_fwd template is used directly.
#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800

// On SM53..SM79 the sigmoid is routed through explicit per-type
// specializations (defined below) instead of the generic template.
template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x);
// exp for bf16: round-trip through double so the intermediate keeps maximum
// mantissa precision before the final conversion back to bf16.
__device__ __forceinline__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) {
    double x_double = static_cast<double>(__bfloat162float(x));
    double exp_result = exp(x_double);
    return __float2bfloat16(static_cast<float>(exp_result));
}

// exp for half: same double-precision round-trip as exp_bf16.
__device__ __forceinline__ __half exp_halft(__half x) {
    double x_double = static_cast<double>(__half2float(x));
    double exp_result = exp(x_double);
    return __float2half(static_cast<float>(exp_result));
}

// sigmoid(x) = 1 / (1 + e^-x) for bf16 via the high-precision exp helper.
template<>
__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) {
    __nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x));
    __nv_bfloat16 one = __float2bfloat16(1.0f);
    return recipg(one + exp_neg_x);
}

// sigmoid for half.
// CLEANUP: this specialization lives inside the surrounding
// "#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800" region, so the nested
// "#if __CUDA_ARCH__ >= 530 ... #else float fallback #endif" the patch kept
// inside the body was dead code — the intrinsic path is always taken.
template<>
__device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) {
    __half exp_neg_x = exp_halft(__hneg(x));
    __half one = __float2half(1.0f);
    return recipg(one + exp_neg_x);
}

template<>
__device__ __forceinline__ float sigmoid_fwd(float x) {
    return 1.0f / (1.0f + expf(-x));
}

template<>
__device__ __forceinline__ double sigmoid_fwd(double x) {
    return 1.0 / (1.0 + exp(-x));
}

UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#endif

// fp16 unary kernels: require SM53+ for __half arithmetic.
#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, urecip_f16, recipg(x))
// (patch chunk is truncated here; the remaining fp16 ops are outside this view)