Skip to content

Commit

Permalink
Dummy functions for linking
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jan 8, 2025
1 parent 9949182 commit 2b9082a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mistralrs-core/src/cuda/nonzero_bitwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,24 @@ void count_nonzero(const T *d_in, const uint32_t N, uint32_t *h_out) {
count_nonzero(d_in, N, &result); \
return result; \
}
// Generates a C-linkage `count_nonzero_<RUST_NAME>` whose input pointer is
// typed as `const uint16_t *` instead of a half-precision type. The wrapper
// forwards to `count_nonzero` and returns the value written to its output
// parameter. This file expands it in the `#else` branches of its
// `__CUDA_ARCH__` checks so the symbol still exists for linking.
// NOTE(review): presumably the 16-bit words are a reinterpreted f16/bf16
// buffer — confirm against the caller.
#define COUNT_NONZERO_OP_DUMMY(RUST_NAME)                                      \
  extern "C" uint32_t count_nonzero_##RUST_NAME(const uint16_t *d_in,          \
                                                uint32_t N) {                  \
    uint32_t nonzero_count;                                                    \
    count_nonzero(d_in, N, &nonzero_count);                                    \
    return nonzero_count;                                                      \
  }

// Emit the C-linkage `count_nonzero_*` entry points, one per element type.
// bf16: the `__nv_bfloat16` wrapper is generated only when `__CUDA_ARCH__` is
// at least 800; otherwise the uint16_t placeholder provides the symbol.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass falls through to the `#else` branches — confirm that this
// is the intended selection for these host-callable wrappers.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
COUNT_NONZERO_OP(__nv_bfloat16, bf16)
#else
COUNT_NONZERO_OP_DUMMY(bf16)
#endif

// f16: same scheme with `__half` and a threshold of 530.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
COUNT_NONZERO_OP(__half, f16)
#else
COUNT_NONZERO_OP_DUMMY(f16)
#endif

// f32 is generated unconditionally.
COUNT_NONZERO_OP(float, f32)
Expand Down Expand Up @@ -115,12 +126,23 @@ void nonzero(const T *d_in, const uint32_t N, const uint32_t num_nonzero,
nonzero(d_in, N, num_nonzero, dims, num_dims, d_out); \
}

// Generates a C-linkage `nonzero_<RUST_NAME>` whose input pointer is typed as
// `const uint16_t *` instead of a half-precision type. Every argument is
// forwarded unchanged to `nonzero`. This file expands it in the `#else`
// branches of its `__CUDA_ARCH__` checks so the symbol still exists for
// linking.
// NOTE(review): presumably the 16-bit words are a reinterpreted f16/bf16
// buffer — confirm against the caller.
#define NONZERO_OP_DUMMY(RUST_NAME)                                            \
  extern "C" void nonzero_##RUST_NAME(const uint16_t *d_in, uint32_t N,        \
                                      uint32_t num_nonzero,                    \
                                      const uint32_t *dims,                    \
                                      uint32_t num_dims, uint32_t *d_out) {    \
    nonzero(d_in, N, num_nonzero, dims, num_dims, d_out);                      \
  }

// Emit the C-linkage `nonzero_*` entry points, one per element type.
// bf16: the `__nv_bfloat16` wrapper is generated only when `__CUDA_ARCH__` is
// at least 800; otherwise the uint16_t placeholder provides the symbol.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass falls through to the `#else` branches — confirm that this
// is the intended selection for these host-callable wrappers.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
NONZERO_OP(__nv_bfloat16, bf16)
#else
NONZERO_OP_DUMMY(bf16)
#endif

// f16: same scheme with `__half` and a threshold of 530.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
NONZERO_OP(__half, f16)
#else
NONZERO_OP_DUMMY(f16)
#endif

// f32 is generated unconditionally.
NONZERO_OP(float, f32)
Expand Down
40 changes: 40 additions & 0 deletions mistralrs-quant/kernels/hqq/hqq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,21 @@ extern "C" void dequantize_8bit_u8_kernel_f16(unsigned char* Wq_packed, __half*
int blocks = cdiv(h*w, BLOCK_SIZE);
dequantize_8bit_u8_kernel<<<blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero, W_r, h, w);
}
#else
// Placeholder definition of `dequantize_8bit_u8_kernel_f16`, compiled in the
// `#else` branch of the enclosing preprocessor conditional. It keeps the
// symbol available, with 16-bit integer pointers standing in for the
// half-precision ones, and launches nothing.
// NOTE(review): `assert(false)` is compiled out when NDEBUG is defined, in
// which case this returns without touching `W_r` — confirm that is acceptable.
extern "C" void dequantize_8bit_u8_kernel_f16(unsigned char* Wq_packed,
                                              uint16_t* scale,
                                              uint16_t* zero,
                                              uint16_t* W_r,
                                              int h, int w) {
  assert(false);  // reaching the placeholder at runtime is treated as a bug
}
#endif

// C-linkage `dequantize_8bit_u8_kernel_bf16`, defined one of two ways
// depending on `__CUDA_ARCH__`.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass selects the placeholder below — confirm this is intended.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Launches `dequantize_8bit_u8_kernel` with BLOCK_SIZE threads per block and
// cdiv(h*w, BLOCK_SIZE) blocks, forwarding every argument unchanged.
extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               __nv_bfloat16* scale,
                                               __nv_bfloat16* zero,
                                               __nv_bfloat16* W_r,
                                               int h, int w) {
  const int num_blocks = cdiv(h * w, BLOCK_SIZE);
  dequantize_8bit_u8_kernel<<<num_blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero,
                                                        W_r, h, w);
}
#else
// Placeholder with 16-bit integer pointers in place of `__nv_bfloat16`; it
// launches nothing. `assert(false)` is compiled out when NDEBUG is defined.
extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               uint16_t* scale,
                                               uint16_t* zero,
                                               uint16_t* W_r,
                                               int h, int w) {
  assert(false);
}
#endif


Expand Down Expand Up @@ -86,13 +94,21 @@ extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char* Wq_packed, __half*
int blocks = cdiv(h*w, BLOCK_SIZE);
dequantize_4bit_u8_kernel<<<blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero, W_r, h, w);
}
#else
// Placeholder definition of `dequantize_4bit_u8_kernel_f16`, compiled in the
// `#else` branch of the enclosing preprocessor conditional. It keeps the
// symbol available, with 16-bit integer pointers standing in for the
// half-precision ones, and launches nothing.
// NOTE(review): `assert(false)` is compiled out when NDEBUG is defined, in
// which case this returns without touching `W_r` — confirm that is acceptable.
extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char* Wq_packed,
                                              uint16_t* scale,
                                              uint16_t* zero,
                                              uint16_t* W_r,
                                              int h, int w) {
  assert(false);  // reaching the placeholder at runtime is treated as a bug
}
#endif

// C-linkage `dequantize_4bit_u8_kernel_bf16`, defined one of two ways
// depending on `__CUDA_ARCH__`.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass selects the placeholder below — confirm this is intended.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Launches `dequantize_4bit_u8_kernel` with BLOCK_SIZE threads per block and
// cdiv(h*w, BLOCK_SIZE) blocks, forwarding every argument unchanged.
extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               __nv_bfloat16* scale,
                                               __nv_bfloat16* zero,
                                               __nv_bfloat16* W_r,
                                               int h, int w) {
  const int num_blocks = cdiv(h * w, BLOCK_SIZE);
  dequantize_4bit_u8_kernel<<<num_blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero,
                                                        W_r, h, w);
}
#else
// Placeholder with 16-bit integer pointers in place of `__nv_bfloat16`; it
// launches nothing. `assert(false)` is compiled out when NDEBUG is defined.
extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               uint16_t* scale,
                                               uint16_t* zero,
                                               uint16_t* W_r,
                                               int h, int w) {
  assert(false);
}
#endif

/*******************************************************************************************************************************************/
Expand Down Expand Up @@ -135,13 +151,21 @@ extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char* Wq_packed, __half*
int blocks = cdiv(h*w, BLOCK_SIZE);
dequantize_2bit_u8_kernel<<<blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero, W_r, h, w);
}
#else
// Placeholder definition of `dequantize_2bit_u8_kernel_f16`, compiled in the
// `#else` branch of the enclosing preprocessor conditional. It keeps the
// symbol available, with 16-bit integer pointers standing in for the
// half-precision ones, and launches nothing.
// NOTE(review): `assert(false)` is compiled out when NDEBUG is defined, in
// which case this returns without touching `W_r` — confirm that is acceptable.
extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char* Wq_packed,
                                              uint16_t* scale,
                                              uint16_t* zero,
                                              uint16_t* W_r,
                                              int h, int w) {
  assert(false);  // reaching the placeholder at runtime is treated as a bug
}
#endif

// C-linkage `dequantize_2bit_u8_kernel_bf16`, defined one of two ways
// depending on `__CUDA_ARCH__`.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass selects the placeholder below — confirm this is intended.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Launches `dequantize_2bit_u8_kernel` with BLOCK_SIZE threads per block and
// cdiv(h*w, BLOCK_SIZE) blocks, forwarding every argument unchanged.
extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               __nv_bfloat16* scale,
                                               __nv_bfloat16* zero,
                                               __nv_bfloat16* W_r,
                                               int h, int w) {
  const int num_blocks = cdiv(h * w, BLOCK_SIZE);
  dequantize_2bit_u8_kernel<<<num_blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero,
                                                        W_r, h, w);
}
#else
// Placeholder with 16-bit integer pointers in place of `__nv_bfloat16`; it
// launches nothing. `assert(false)` is compiled out when NDEBUG is defined.
extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               uint16_t* scale,
                                               uint16_t* zero,
                                               uint16_t* W_r,
                                               int h, int w) {
  assert(false);
}
#endif


Expand Down Expand Up @@ -219,13 +243,21 @@ extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char* Wq_packed, __half*
int blocks = cdiv(h*w, BLOCK_SIZE);
dequantize_1bit_u8_kernel<<<blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero, W_r, h, w);
}
#else
// Placeholder definition of `dequantize_1bit_u8_kernel_f16`, compiled in the
// `#else` branch of the enclosing preprocessor conditional. It keeps the
// symbol available, with 16-bit integer pointers standing in for the
// half-precision ones, and launches nothing.
// NOTE(review): `assert(false)` is compiled out when NDEBUG is defined, in
// which case this returns without touching `W_r` — confirm that is acceptable.
extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char* Wq_packed,
                                              uint16_t* scale,
                                              uint16_t* zero,
                                              uint16_t* W_r,
                                              int h, int w) {
  assert(false);  // reaching the placeholder at runtime is treated as a bug
}
#endif

// C-linkage `dequantize_1bit_u8_kernel_bf16`, defined one of two ways
// depending on `__CUDA_ARCH__`.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass selects the placeholder below — confirm this is intended.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Launches `dequantize_1bit_u8_kernel` with BLOCK_SIZE threads per block and
// cdiv(h*w, BLOCK_SIZE) blocks, forwarding every argument unchanged.
extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               __nv_bfloat16* scale,
                                               __nv_bfloat16* zero,
                                               __nv_bfloat16* W_r,
                                               int h, int w) {
  const int num_blocks = cdiv(h * w, BLOCK_SIZE);
  dequantize_1bit_u8_kernel<<<num_blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero,
                                                        W_r, h, w);
}
#else
// Placeholder with 16-bit integer pointers in place of `__nv_bfloat16`; it
// launches nothing. `assert(false)` is compiled out when NDEBUG is defined.
extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char* Wq_packed,
                                               uint16_t* scale,
                                               uint16_t* zero,
                                               uint16_t* W_r,
                                               int h, int w) {
  assert(false);
}
#endif

// //Shared
Expand Down Expand Up @@ -308,11 +340,19 @@ extern "C" void dequantize_3bit_32_kernel_f16(int32_t* Wq_packed, __half* scale,
int blocks = cdiv(h*w, BLOCK_SIZE);
dequantize_3bit_32_kernel<<<blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero, W_r, h, w);
}
#else
// Placeholder definition of `dequantize_3bit_32_kernel_f16`, compiled in the
// `#else` branch of the enclosing preprocessor conditional. It keeps the
// symbol available, with 16-bit integer pointers standing in for the
// half-precision ones, and launches nothing.
//
// `Wq_packed` is declared as `int32_t*` so that both definitions of this
// C-linkage symbol agree on the packed-weight pointer type; the real
// definition takes `int32_t*`, not `unsigned char*`.
//
// NOTE(review): `assert(false)` is compiled out when NDEBUG is defined, in
// which case this returns without touching `W_r` — confirm that is acceptable.
extern "C" void dequantize_3bit_32_kernel_f16(int32_t* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) {
  assert(false);  // reaching the placeholder at runtime is treated as a bug
}
#endif

// C-linkage `dequantize_3bit_32_kernel_bf16`, defined one of two ways
// depending on `__CUDA_ARCH__`. Both definitions take the packed weights as
// `int32_t*` so the symbol has one consistent packed-weight pointer type.
// NOTE(review): nvcc defines `__CUDA_ARCH__` only while compiling device code,
// so the host pass selects the placeholder below — confirm this is intended.
#if __CUDA_ARCH__ >= 800
// Launches `dequantize_3bit_32_kernel` with BLOCK_SIZE threads per block and
// cdiv(h*w, BLOCK_SIZE) blocks, forwarding every argument unchanged.
extern "C" void dequantize_3bit_32_kernel_bf16(int32_t* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) {
  int blocks = cdiv(h*w, BLOCK_SIZE);
  dequantize_3bit_32_kernel<<<blocks, BLOCK_SIZE>>>(Wq_packed, scale, zero, W_r, h, w);
}
#else
// Placeholder with 16-bit integer pointers in place of `__nv_bfloat16`; it
// launches nothing. `assert(false)` is compiled out when NDEBUG is defined, in
// which case this returns without touching `W_r`.
extern "C" void dequantize_3bit_32_kernel_bf16(int32_t* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) {
  assert(false);  // reaching the placeholder at runtime is treated as a bug
}
#endif

0 comments on commit 2b9082a

Please sign in to comment.