Skip to content

Commit

Permalink
refactor sigmoid & ctx if >=800 or (>=530 && <800)
Browse files Browse the repository at this point in the history
  • Loading branch information
haricot committed Jan 16, 2025
1 parent 94d26bd commit f8c25e8
Showing 1 changed file with 68 additions and 73 deletions.
141 changes: 68 additions & 73 deletions candle-kernels/src/unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,68 +61,8 @@ __device__ __forceinline__ T silu_fwd(T x) {
}

template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x);

__device__ __forceinline__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) {
// Convert to double for maximum mantissa precision
double x_double = static_cast<double>(__bfloat162float(x));

// Compute exp in double precision to preserve mantissa bits
double exp_result = exp(x_double);

// Careful conversion back to preserve significant bits
return __float2bfloat16(static_cast<float>(exp_result));
}

__device__ __forceinline__ __half exp_halft(__half x) {
// Convert to double for maximum mantissa precision
double x_double = static_cast<double>(__half2float(x));

// Compute exp in double precision to preserve mantissa bits
double exp_result = exp(x_double);

// Careful conversion back to half
return __float2half(static_cast<float>(exp_result));
}

template<>
__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) {
#if __CUDA_ARCH__ >= 800
__device__ __forceinline__ T sigmoid_fwd(T x) {
return x / (static_cast<T>(1) + expg(-x));
#elif __CUDA_ARCH__ >= 530
__nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x));
__nv_bfloat16 one = __float2bfloat16(1.0f);
return recipg(one + exp_neg_x);
#else
// Fallback using float computation
float x_float = __bfloat162float(x);
float result = 1.0f / (1.0f + expf(-x_float));
return __float2bfloat16(result);
#endif
}

template<>
__device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) {
#if __CUDA_ARCH__ >= 530
__half exp_neg_x = exp_halft(__hneg(x));
__half one = __float2half(1.0f);
return recipg(one + exp_neg_x);
#else
// Fallback using float computation
float x_float = __half2float(x);
float result = 1.0f / (1.0f + expf(-x_float));
return __float2half(result);
#endif
}

template<>
__device__ __forceinline__ float sigmoid_fwd<float>(float x) {
return 1.0f / (1.0f + expf(-x));
}

template<>
__device__ __forceinline__ double sigmoid_fwd<double>(double x) {
return 1.0 / (1.0 + exp(-x));
}

#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
Expand Down Expand Up @@ -156,18 +96,7 @@ __device__ T sign_(T t) {
return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}


#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))


#elif __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
Expand All @@ -190,8 +119,74 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif

// Sigmoid dispatch by compute capability.
//
// - SM80+ (__CUDA_ARCH__ >= 800): bf16 has full native arithmetic support,
//   so usigmoid_bf16 uses the generic sigmoid_fwd directly.
//   NOTE(review): the original guard read
//   `__CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 530`, which is always false,
//   so usigmoid_bf16 was silently never emitted on SM80+. Fixed to `>= 800`
//   (matching the stated intent ">=800 or (>=530 && <800)").
//   Assumes a generic `sigmoid_fwd(T)` definition is visible above on
//   SM80+ — TODO confirm against the earlier part of this file.
// - SM53..SM79: bf16 arithmetic is only partially supported, so explicit
//   specializations below route the math through wider float/double types.
#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800

// Primary template: declaration only; every type used below has an
// explicit specialization.
template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x);

// exp() for bf16: widen to double, exponentiate, round back down.
// (Widening past float is likely overkill for bf16's 8-bit mantissa, but
// the double path is kept to preserve the existing numerics exactly.)
__device__ __forceinline__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) {
  // Convert to double for maximum mantissa precision.
  double x_double = static_cast<double>(__bfloat162float(x));
  // Compute exp in double precision to preserve mantissa bits.
  double exp_result = exp(x_double);
  // Careful conversion back to preserve significant bits.
  return __float2bfloat16(static_cast<float>(exp_result));
}

// exp() for half: same double-widening scheme as exp_bf16.
__device__ __forceinline__ __half exp_halft(__half x) {
  double x_double = static_cast<double>(__half2float(x));
  double exp_result = exp(x_double);
  return __float2half(static_cast<float>(exp_result));
}

// sigmoid(x) = 1 / (1 + exp(-x)) for bf16, via recipg on the bf16 sum.
template<>
__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) {
  __nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x));
  __nv_bfloat16 one = __float2bfloat16(1.0f);
  return recipg(one + exp_neg_x);
}

// sigmoid(x) for half. The original carried an inner
// `#if __CUDA_ARCH__ >= 530` float fallback here; that branch was dead code
// because this whole section is already guarded by
// `__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800`, so it has been removed.
template<>
__device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) {
  __half exp_neg_x = exp_halft(__hneg(x));
  __half one = __float2half(1.0f);
  return recipg(one + exp_neg_x);
}

// sigmoid(x) for float.
template<>
__device__ __forceinline__ float sigmoid_fwd<float>(float x) {
  return 1.0f / (1.0f + expf(-x));
}

// sigmoid(x) for double.
template<>
__device__ __forceinline__ double sigmoid_fwd<double>(double x) {
  return 1.0 / (1.0 + exp(-x));
}

UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#endif


#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, urecip_f16, recipg(x))
Expand Down

0 comments on commit f8c25e8

Please sign in to comment.