Skip to content

Commit

Permalink
Define and optimize RDNA1 (ggerganov#8085)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniandtheweb authored Jul 3, 2024
1 parent 5f2d4e6 commit d23287f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ typedef float2 dfloat2;
#define RDNA2
#endif

#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif

#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
Expand Down
10 changes: 7 additions & 3 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,16 @@ static constexpr __device__ int get_mmq_x_max_device() {
}

static constexpr int get_mmq_y_host(const int cc) {
return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64;
return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64);
}

static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA1)
return 64;
#else
return 128;
#endif // defined RDNA1
#else
#if __CUDA_ARCH__ >= CC_VOLTA
return 128;
Expand Down Expand Up @@ -2259,9 +2263,9 @@ static __device__ void mul_mat_q_process_tile(

template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
#if defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
Expand Down

0 comments on commit d23287f

Please sign in to comment.