Skip to content

Commit

Permalink
[Tutorial] Fix a pointer not advanced issue in the persistent kernels. (#4218)
Browse files Browse the repository at this point in the history

It looks like the operand pointers are not advanced in the baseline and
persistent kernels.
  • Loading branch information
htyu authored Jul 9, 2024
1 parent 18996e7 commit 4073423
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions python/tutorials/09-persistent-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, #
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N

- offs_am = tl.arange(0, BLOCK_SIZE_M)
- offs_bn = tl.arange(0, BLOCK_SIZE_N)
-
- offs_am = tl.where(offs_am < M - start_m, offs_am, 0)
- offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0)
+ offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
+ offs_am = tl.where(offs_am < M, offs_am, 0)
+ offs_bn = tl.where(offs_bn < N, offs_bn, 0)

offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
Expand Down Expand Up @@ -186,10 +185,10 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #

start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
- offs_am = tl.arange(0, BLOCK_SIZE_M)
- offs_bn = tl.arange(0, BLOCK_SIZE_N)
- offs_am = tl.where(offs_am < M - start_m, offs_am, 0)
- offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0)
+ offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
+ offs_am = tl.where(offs_am < M, offs_am, 0)
+ offs_bn = tl.where(offs_bn < N, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
Expand Down

0 comments on commit 4073423

Please sign in to comment.