Skip to content

Commit

Permalink
CUDA: rename macros to avoid conflicts with WinAPI (ggerganov#10736)
Browse files Browse the repository at this point in the history
* Renames NVIDIA GPU-architecture flags to avoid name clashes with WinAPI. (e.g. CC_PASCAL, GPU architecture or WinAPI pascal compiler flag?)

* Reverts erroneous rename in SYCL-code.

* Renames GGML_CUDA_MIN_CC_DP4A to GGML_CUDA_CC_DP4A.

* Renames the rest of the compute capability macros for consistency.
  • Loading branch information
aendk authored Dec 10, 2024
1 parent a86ad84 commit 750cb3e
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 71 deletions.
2 changes: 1 addition & 1 deletion ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
GGML_TABLE_END()

//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
Expand Down
70 changes: 35 additions & 35 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,28 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_TURING 750
#define CC_AMPERE 800
#define CC_OFFSET_AMD 1000000
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 1000000

// GCN/CNDA, wave size is 64
#define CC_GCN4 (CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define CC_VEGA (CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
#define CC_VEGA20 (CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
#define CC_CDNA (CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
#define CC_CDNA2 (CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
#define CC_CDNA3 (CC_OFFSET_AMD + 942) // MI300
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300

// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define CC_RDNA1 (CC_OFFSET_AMD + 1010) // RX 5000
#define CC_RDNA2 (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
#define CC_RDNA3 (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA

#define CC_QY1 210
#define CC_QY2 220
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

Expand Down Expand Up @@ -131,36 +131,36 @@ typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16

#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

static constexpr bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}

static constexpr bool fp16_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
}

static constexpr bool int8_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
}

[[noreturn]]
Expand All @@ -187,15 +187,15 @@ static __device__ void no_device_code(
#endif // __CUDA_ARCH__

static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
Expand Down Expand Up @@ -284,7 +284,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
}

static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
Expand All @@ -293,7 +293,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
}

#if CUDART_VERSION < CUDART_HMASK
Expand Down Expand Up @@ -333,13 +333,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i

#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= MIN_CC_DP4A
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= MIN_CC_DP4A
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A

#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __

template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
Expand Down Expand Up @@ -64,7 +64,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
GGML_UNUSED(y);
GGML_UNUSED(k);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= CC_PASCAL
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
}

template<typename dst_t>
Expand Down Expand Up @@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
Expand Down
12 changes: 6 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].smpb = prop.sharedMemPerBlock;
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock;
info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
Expand Down Expand Up @@ -1081,7 +1081,7 @@ static void ggml_cuda_op_mul_mat_cublas(

const int compute_capability = ggml_cuda_info().devices[id].cc;

if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
Expand All @@ -1108,7 +1108,7 @@ static void ggml_cuda_op_mul_mat_cublas(
const half beta_f16 = 0.0f;

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}

Expand Down Expand Up @@ -1612,7 +1612,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}

Expand Down Expand Up @@ -2357,7 +2357,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
std::vector<void *> ggml_cuda_cpy_fn_ptrs;

if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
Expand Down Expand Up @@ -3028,7 +3028,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return true;
}
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct mma_int_C_I16J8 {

__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
Expand All @@ -183,7 +183,7 @@ struct mma_int_C_I16J8 {
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
Expand All @@ -193,7 +193,7 @@ struct mma_int_C_I16J8 {

__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
Expand All @@ -211,7 +211,7 @@ struct mma_int_C_I16J8 {
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = compute_capability >= CC_VOLTA && compute_capability < CC_OFFSET_AMD && src1_ncols == ne11;
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};

switch (src0->type) {
Expand Down Expand Up @@ -136,17 +136,17 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return true;
}

if (cc < MIN_CC_DP4A) {
if (cc < GGML_CUDA_CC_DP4A) {
return false;
}

#ifdef GGML_CUDA_FORCE_MMQ
return true;
#endif //GGML_CUDA_FORCE_MMQ

if (cc < CC_OFFSET_AMD) {
return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
Loading

0 comments on commit 750cb3e

Please sign in to comment.