You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Tensorize currently doesn't work when axis of a buffer has extent = 1. See the example.
import tvm
from tvm import te, tir
from tvm.script import ty
@tvm.script.tir
def intrin_mma_desc(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Tensor-intrinsic *description*: the compute pattern the scheduler must
    # match when tensorizing. Computes C[32,32] += A[32,1] * B[32,1] with the
    # reduction axis of extent 1 — the extent-1 axis is what triggers the bug.
    A = tir.match_buffer(a, (32, 1), "float32", scope="global", offset_factor=1)
    B = tir.match_buffer(b, (32, 1), "float32", scope="global", offset_factor=1)
    C = tir.match_buffer(c, (32, 32), "float32", scope="global", offset_factor=1)
    # Root block: all three iterators are bound to 0; declares the full
    # read/write regions covered by the intrinsic.
    with tir.block([32, 32, tir.reduce_axis(0, 1)], "root") as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        tir.reads([C[vi:vi+32, vj:vj+32], A[vi:vi+32,vk:vk+1], B[vj:vj+32,vk:vk+1]])
        tir.writes(C[vi:vi+32, vj:vj+32])
        # The loop nest to be pattern-matched; k has extent 1.
        for i, j, k in tir.grid(32, 32, 1):
            with tir.block([32, 32, tir.reduce_axis(0, 1)], "B") as [vii, vjj, vkk]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                tir.bind(vkk, vk)
                C[vii, vjj] = C[vii, vjj] + A[vii,vkk] * B[vjj,vkk]
@tvm.script.tir
def intrin_mma_impl(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Tensor-intrinsic *implementation*: the lowered body that replaces the
    # matched loop nest — a single tvm_mma_sync call over the same regions
    # declared in intrin_mma_desc.
    A = tir.match_buffer(a, (32, 1), "float32", scope="global", offset_factor=1)
    B = tir.match_buffer(b, (32, 1), "float32", scope="global", offset_factor=1)
    C = tir.match_buffer(c, (32, 32), "float32", scope="global", offset_factor=1)
    with tir.block([32, 32, tir.reduce_axis(0, 1)], "root") as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        tir.reads([C[vi:vi+32, vj:vj+32], A[vi:vi+32, vk:vk+1], B[vj:vj+32,vk:vk+1]])
        tir.writes(C[vi:vi+32, vj:vj+32])
        # elem_offset is divided by the tile size in elements to get a fragment
        # index: 1024 = 32*32 for C, 32 = 32*1 for A and B. The issue reports
        # that B.elem_offset lowers incorrectly when vk's binding is misdetected.
        tir.evaluate(tir.tvm_mma_sync(C.data, C.elem_offset // 1024, A.data, A.elem_offset // 32, B.data, B.elem_offset // 32, dtype='handle'))
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 workload to be tensorized. Note B is indexed as B[vj, vk],
    # so this computes C = A @ B^T, matching the access pattern in the intrinsic.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        with tir.init():
            # Zero-initialize the accumulator on the first reduction iteration.
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def main():
    # Reproduction: tile the matmul into 32x32x1 blocks, then tensorize the
    # tile with the MMA intrinsic defined above.
    mod = tvm.script.create_module({'main': matmul})
    s = tir.Schedule(mod)
    C = s.get_block('C')
    i, j, k = s.get_axes(C)
    i0, i1 = s.split(i, factor=32)
    j0, j1 = s.split(j, factor=32)
    # k1 has extent 1 — the problematic loop described in the issue.
    k0, k1 = s.split(k, factor=1)
    s.reorder(i0, j0, k0, i1, j1, k1)
    # Reported to fail: the extent-1 loop k1 is eliminated from the block
    # iter vars, so the loop nest no longer matches intrin_mma_desc
    # (NOTE(review): behavior as of the tvm-tensorir branch cited in the issue).
    s.tensorize(i1, tir.TensorIntrin(intrin_mma_desc, intrin_mma_impl))
    print(tvm.script.asscript(s.mod['main']))


main()
The above code doesn't work because of a mismatch between the loop nest and the tensor intrinsic description. The loop k1 is eliminated from the block iteration variables (this is caused by the extent-1 simplification linked in the original issue).
block C(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128)), iter_var(vk, range(min=0, ext=128)){
bind(vi, ((vio*32) + i0_inner))
bind(vj, ((vjo*32) + i1_inner))
bind(vk, vko) # the inner loop var of extent 1 is still eliminated.
reads([C[vi, vj], A[vi, vk], B[vj, vk]])
writes([C[vi, vj]])
C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
}
and as a result B.elem_offset is lowered to get_elem_offset(B[vjo * 32, 0]) instead of get_elem_offset(B[vjo * 32, vko]), because the detected binding of vk is incorrect.
The design question here is whether we should eliminate loops of extent 1 during blockize and tensorize.
The text was updated successfully, but these errors were encountered:
Tensorize currently doesn't work when axis of a buffer has extent = 1. See the example.
The above code doesn't work because mismatch between loop and tensor intrinsic description. The loop
k1
is eliminated from the block iteration variables (this is caused by the extent-1 simplification linked in the original issue). If we remove that part of the code, we still need to fix the pattern matcher here https://github.com/Hzfengsy/tvm-tensorir/blob/main/src/tir/schedule/primitives/blockize_tensorize.cc#L73, because the original loop after blockize will be
and as a result
B.elem_offset
is lowered to get_elem_offset(B[vjo * 32, 0])
instead of get_elem_offset(B[vjo * 32, vko]),
because the detected binding of vk
is incorrect. The design question here is whether we should eliminate loops of extent 1 during blockize and tensorize.
The text was updated successfully, but these errors were encountered: