fix __CUDA_ARCH__ if >=800 or (>=530 && <800)
haricot committed Jan 16, 2025
1 parent 9b1022e commit 94d26bd
Showing 5 changed files with 21 additions and 50 deletions.
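Note on the pattern shared by all five files below: kernels touching __nv_bfloat16 were previously compiled only for __CUDA_ARCH__ >= 800, and the new guard also admits the 530-799 range, so bf16 kernels are emitted for any sm_53+ target. Because the two ranges [530, 800) and [800, up) together cover everything from 530 upward, the combined condition accepts exactly the same architectures as a plain >= 530 check. The snippet below is an illustrative sketch only, not part of this commit; it assumes CUDA 11 or newer so that cuda_bf16.h and its conversion intrinsics are available.

// Illustrative sketch, not part of the commit. Assumes CUDA 11+ for cuda_bf16.h.
#include <cuda_bf16.h>

// For any value A, "A >= 800 || (A >= 530 && A < 800)" holds exactly when
// "A >= 530" holds, so this guard is equivalent to #if __CUDA_ARCH__ >= 530.
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
// Hypothetical example kernel (not in the repository): adds 1.0 to each bf16
// element, going through float so it also works on pre-Ampere parts.
extern "C" __global__ void add_one_bf16_sketch(__nv_bfloat16 *buf, const size_t numel) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        buf[i] = __float2bfloat16(__bfloat162float(buf[i]) + 1.0f);
    }
}
#endif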
3 changes: 1 addition & 2 deletions candle-kernels/src/affine.cu
@@ -28,12 +28,11 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
AFFINE_OP(__nv_bfloat16, affine_bf16)
#endif

#if __CUDA_ARCH__ >= 530
AFFINE_OP(__nv_bfloat16, affine_bf16)
AFFINE_OP(__half, affine_f16)
#endif

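The affine.cu hunk above removes the bf16 instantiation from the >= 530 block and keeps a single AFFINE_OP(__nv_bfloat16, affine_bf16) under the widened guard. The AFFINE_OP macro itself is not shown in this hunk, so the following is only a plausible sketch of what such a bf16 affine kernel could look like (contiguous buffers only, hypothetical name), not the repository's exact expansion.

// Sketch under stated assumptions; the real AFFINE_OP also handles strided layouts.
#include <cuda_bf16.h>

#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
// y = mul * x + add, computed in float and stored back as bf16.
extern "C" __global__ void affine_bf16_sketch(
    const size_t numel, const __nv_bfloat16 *inp, __nv_bfloat16 *out,
    const __nv_bfloat16 mul, const __nv_bfloat16 add) {
    const float m = __bfloat162float(mul);
    const float b = __bfloat162float(add);
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        out[i] = __float2bfloat16(m * __bfloat162float(inp[i]) + b);
    }
}
#endif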
28 changes: 11 additions & 17 deletions candle-kernels/src/cast.cu
@@ -71,34 +71,28 @@ extern "C" __global__ void FN_NAME( \
} \

#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16 )
#elif __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16 )

CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
#else
#include <cuda.h>
#if CUDA_VERSION >= 11000
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
#elif __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800 // needed CUDA_VERSION >= 11000
CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16)
CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16)
#endif
#endif

#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)

CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
CAST_OP(__half, uint32_t, cast_f16_u32)
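For cast.cu, the direct CAST_OP instantiations and the float-mediated CAST_THROUGH_OP instantiations for bf16 now sit under the widened guard instead of being split between a >= 800 block and a CUDA_VERSION fallback. The macros themselves are defined earlier in the file and are not shown here; the sketch below is a simplified stand-in (contiguous buffers only, hypothetical *_sketch names) that illustrates the difference between a direct cast and a cast routed through an intermediate type.

// Simplified sketch, not the repository's macros. Assumes the usual converting
// constructors/operators from cuda_fp16.h and cuda_bf16.h (CUDA 11+).
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Direct element-wise cast SRC -> DST.
#define CAST_OP_SKETCH(SRC, DST, NAME)                                      \
extern "C" __global__ void NAME(const size_t n, const SRC *inp, DST *out) { \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;           \
         i += (size_t)blockDim.x * gridDim.x)                               \
        out[i] = (DST)inp[i];                                               \
}

// Cast SRC -> VIA -> DST, used when no direct SRC -> DST conversion exists.
#define CAST_THROUGH_OP_SKETCH(SRC, DST, VIA, NAME)                         \
extern "C" __global__ void NAME(const size_t n, const SRC *inp, DST *out) { \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;           \
         i += (size_t)blockDim.x * gridDim.x)                               \
        out[i] = (DST)((VIA)inp[i]);                                        \
}

#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
CAST_OP_SKETCH(float, __nv_bfloat16, cast_f32_bf16_sketch)
CAST_THROUGH_OP_SKETCH(__half, __nv_bfloat16, float, cast_f16_bf16_sketch)
#endif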
29 changes: 5 additions & 24 deletions candle-kernels/src/cuda_utils.cuh
@@ -158,8 +158,8 @@ __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"

#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); }
@@ -178,7 +178,9 @@ __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); }
__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); }
#endif

#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
@@ -197,25 +199,4 @@ __device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif

#if __CUDA_ARCH__ >= 800
__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); }
__device__ __forceinline__ __nv_bfloat16 cosg(__nv_bfloat16 a) { return hcos(a); }
__device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); }
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ceilg(__nv_bfloat16 a) { return __float2bfloat16(ceilf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 floorg(__nv_bfloat16 a) { return __float2bfloat16(floorf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 roundg(__nv_bfloat16 a) { return __float2bfloat16(roundf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); }
__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); }
#endif
#endif
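In cuda_utils.cuh the bf16 math helpers, which previously existed twice (once under >= 530 and again under >= 800), are collapsed into a single block under the combined guard, while the __half helpers keep their own >= 530 block. As a hedged illustration of why these overloads exist, the sketch below pairs the tanhg bf16 helper from the diff with a hypothetical templated kernel that works for any element type providing a tanhg overload; the kernel name is not part of the repository.

// Sketch only; unary_tanh_sketch is hypothetical. The bf16 tanhg body matches
// the float-round-trip style shown in the diff above.
#include <cuda_bf16.h>

__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }

#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) {
    return __float2bfloat16(tanhf(__bfloat162float(a)));
}
#endif

// One templated kernel covers every dtype that has a tanhg overload.
template <typename T>
__global__ void unary_tanh_sketch(const size_t numel, const T *inp, T *out) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x)
        out[i] = tanhg(inp[i]);
}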
7 changes: 3 additions & 4 deletions candle-kernels/src/fill.cu
@@ -35,16 +35,15 @@ COPY2D_OP(double, copy2d_f64)
COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
#include <cuda_bf16.h>
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif
#endif
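fill.cu follows the same move: the bf16 fill and copy2d kernels, together with the cuda_bf16.h include, now live under the combined guard rather than being mixed into the f16 block or restricted to >= 800. The fill_with helper and COPY2D_OP macro are defined earlier in the file and are not shown in this hunk, so the sketch below uses a stand-in grid-stride fill; only the guard pattern is taken from the diff.

// Sketch; fill_with_sketch is a plausible stand-in for the file's fill_with helper.
#include <cuda_bf16.h>

template <typename T>
__device__ void fill_with_sketch(T *buf, T value, const size_t numel) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x)
        buf[i] = value;
}

#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
extern "C" __global__ void fill_bf16_sketch(__nv_bfloat16 *buf, __nv_bfloat16 value,
                                            const size_t numel) {
    fill_with_sketch(buf, value, numel);
}
#endif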
4 changes: 1 addition & 3 deletions candle-kernels/src/indexing.cu
@@ -146,7 +146,7 @@ extern "C" __global__ void FN_NAME( \
) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \


#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
IS_OP(__nv_bfloat16, int64_t, is_i64_bf16)
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
@@ -162,8 +162,6 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__half, int64_t, is_i64_f16)
IS_OP(__half, uint32_t, is_u32_f16)
IS_OP(__half, uint8_t, is_u8_f16)
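indexing.cu gets the same treatment: the bf16 IS_OP and SA_OP (scatter-add, per the scatter_add call in the hunk header) instantiations move under the combined guard, and the stray cuda_bf16.h include plus the duplicated is_u32_bf16 are dropped from the >= 530 block. The macros are not shown in this hunk; the sketch below is a simplified, contiguous-only picture of what an IS_OP-style gather kernel can look like, with a hypothetical name.

// Simplified sketch of an index-select style gather; not the repository's IS_OP.
#include <cuda_bf16.h>
#include <stdint.h>

#if __CUDA_ARCH__ >= 800 || (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ < 800)
// out holds numel elements laid out as rows of row_size; ids picks the source row.
extern "C" __global__ void is_u32_bf16_sketch(
    const size_t numel, const uint32_t *ids, const __nv_bfloat16 *inp,
    __nv_bfloat16 *out, const size_t row_size) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        const size_t row = i / row_size;
        const size_t col = i % row_size;
        out[i] = inp[(size_t)ids[row] * row_size + col];
    }
}
#endif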
