From 4073423f786af5c1416a3774104460d99e28225c Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Tue, 9 Jul 2024 14:29:58 -0700 Subject: [PATCH] [Tutorial] Fix a pointer not advanced issue in the persistent kernels. (#4218) It looks like the operand pointers are not advanced in the baseline and persistent kernels. --- python/tutorials/09-persistent-matmul.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 3c78435b1724..fdbdbfecfb86 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -69,11 +69,10 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, # start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - - offs_am = tl.where(offs_am < M - start_m, offs_am, 0) - offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0) + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) @@ -186,10 +185,10 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - offs_am = tl.where(offs_am < M - start_m, offs_am, 0) - offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0) + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)