cuda: skip fallback sigmoid (>=530 && <800)
haricot committed Jan 17, 2025
1 parent f8c25e8 commit c0a7838
Showing 1 changed file with 1 addition and 65 deletions.
66 changes: 1 addition & 65 deletions candle-kernels/src/unary.cu
@@ -62,7 +62,7 @@ __device__ __forceinline__ T silu_fwd(T x) {

template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x) {
- return x / (static_cast<T>(1) + expg(-x));
+ return recipg(static_cast<T>(1) + expg(-x));
}

#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
@@ -119,73 +119,9 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif

- #if __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 530
- UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
- #elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800
-
- template<typename T>
- __device__ __forceinline__ T sigmoid_fwd(T x);
-
- __device__ __forceinline__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) {
- // Convert to double for maximum mantissa precision
- double x_double = static_cast<double>(__bfloat162float(x));
-
- // Compute exp in double precision to preserve mantissa bits
- double exp_result = exp(x_double);
-
- // Careful conversion back to preserve significant bits
- return __float2bfloat16(static_cast<float>(exp_result));
- }
-
-
- __device__ __forceinline__ __half exp_halft(__half x) {
- // Convert to double for maximum mantissa precision
- double x_double = static_cast<double>(__half2float(x));
-
- // Compute exp in double precision to preserve mantissa bits
- double exp_result = exp(x_double);
-
- // Careful conversion back to half
- return __float2half(static_cast<float>(exp_result));
- }
-
- template<>
- __device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) {
- __nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x));
- __nv_bfloat16 one = __float2bfloat16(1.0f);
- return recipg(one + exp_neg_x);
- }
-
- template<>
- __device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) {
- #if __CUDA_ARCH__ >= 530
- __half exp_neg_x = exp_halft(__hneg(x));
- __half one = __float2half(1.0f);
- return recipg(one + exp_neg_x);
- #else
- // Fallback using float computation
- float x_float = __half2float(x);
- float result = 1.0f / (1.0f + expf(-x_float));
- return __float2half(result);
- #endif
- }
-
- template<>
- __device__ __forceinline__ float sigmoid_fwd<float>(float x) {
- return 1.0f / (1.0f + expf(-x));
- }
-
- template<>
- __device__ __forceinline__ double sigmoid_fwd<double>(double x) {
- return 1.0 / (1.0 + exp(-x));
- }
-
- UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
- #endif


#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
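
For context, the change keeps a single reciprocal-form sigmoid, sigmoid(x) = 1 / (1 + exp(-x)), for every dtype. Below is a minimal self-contained CUDA sketch of that form using plain CUDA math (expf and a widen-to-float path for __half) instead of candle's expg/recipg helpers; the names sigmoid_f32, sigmoid_f16, and sigmoid_kernel_f32 are hypothetical and not part of this commit.

#include <cuda_fp16.h>
#include <math.h>

// sigmoid(x) = 1 / (1 + exp(-x)), written in the reciprocal form used above.
__device__ __forceinline__ float sigmoid_f32(float x) {
    return 1.0f / (1.0f + expf(-x));
}

// __half input: widen to float, evaluate, round back to half once at the end.
__device__ __forceinline__ __half sigmoid_f16(__half x) {
    return __float2half(sigmoid_f32(__half2float(x)));
}

// Elementwise kernel applying the float version.
__global__ void sigmoid_kernel_f32(const float *in, float *out, size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = sigmoid_f32(in[i]);
    }
}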

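The deleted block routed __half and __nv_bfloat16 through a double-precision exp before rounding back to 16 bits. A small host-side harness (hypothetical, not part of the commit) is one way to eyeball how much that extra precision can matter before the final 16-bit rounding:

#include <cmath>
#include <cstdio>

int main() {
    const float samples[] = {-8.0f, -2.5f, -0.1f, 0.0f, 0.1f, 2.5f, 8.0f};
    for (float x : samples) {
        float  f = 1.0f / (1.0f + std::exp(-x));                      // float path
        double d = 1.0  / (1.0  + std::exp(-static_cast<double>(x))); // double reference
        std::printf("x=% .2f  float=%.9g  double=%.17g  |diff|=%.3g\n",
                    x, static_cast<double>(f), d, std::fabs(d - static_cast<double>(f)));
    }
    return 0;
}

Since bf16 carries 8 significand bits and fp16 carries 11 while float carries 24, float-level discrepancies on the order of 1e-7 are generally absorbed by the final rounding back to 16 bits.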