Skip to content

Commit

Permalink
Use uint64_t instead of unsigned long for clarity (opendatahub-io#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
mawong-amd authored Jun 21, 2024
1 parent b02fcb2 commit 3e9dac6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions csrc/custom/custom_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,

__shared__ half s[1024 * 32];

unsigned long n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;
uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;

for (uint32_t k = 0; k < min(K * M, 32 * 1024);
k += THRDS * WvPrGrp * A_CHUNK) {
Expand Down Expand Up @@ -484,7 +484,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
unsigned long n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;
uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;

// Check whether there will be fragmenation!
// This will happen only for the last wave!
Expand Down Expand Up @@ -836,7 +836,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
unsigned long n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;
uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;

// Check whether there will be fragmenation!
// This will happen only for the last wave!
Expand Down Expand Up @@ -1188,7 +1188,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
unsigned long n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;
uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;

// Check whether there will be fragmenation!
// This will happen only for the last wave!
Expand Down Expand Up @@ -1540,7 +1540,7 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
unsigned long n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;
uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE;

// Check whether there will be fragmenation!
// This will happen only for the last wave!
Expand Down

0 comments on commit 3e9dac6

Please sign in to comment.