# [Roadmap] Sparse TensorIR #466
Thanks @yzh119. Can you paste a TVMScript example mockup of the sparse TensorIR and the transformations allowed? This would help us greatly in understanding the relation between the current design rationale and the new one.

---
Yes, I'll elaborate more here.

---
## Sparse Formats

In Sparse TIR we have four kinds of axes. They can represent both sparse matrices and ragged tensors. For example, suppose we want to represent an irregular batched matrix multiplication:

```python
for i in range(b):
    for j in range(N[i]):
        for k in range(M[i]):
            for l in range(K[i]):
                C[i, j, k] = C[i, j, k] + A[i, j, l] * B[i, l, k]
```

Its dependency tree is:
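The semantics of this ragged loop nest can be sketched in plain NumPy. The function name `ragged_bmm` and the representation of each batch as a separate array are illustrative choices, not part of the proposal; `N`, `M`, `K` play the role of the 1-D length buffers from the text.

```python
import numpy as np

# Ragged batched matmul: batch i multiplies an (N[i], K[i]) matrix by a
# (K[i], M[i]) matrix, so every batch can have different shapes.
def ragged_bmm(A, B):
    # A[i] has shape (N[i], K[i]); B[i] has shape (K[i], M[i]).
    return [a @ b for a, b in zip(A, B)]

rng = np.random.default_rng(0)
N, M, K = [2, 3], [4, 1], [5, 2]
A = [rng.standard_normal((n, k)) for n, k in zip(N, K)]
B = [rng.standard_normal((k, m)) for k, m in zip(K, M)]
C = ragged_bmm(A, B)
print([c.shape for c in C])  # per-batch output shapes (N[i], M[i])
```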
The following syntax describes how we define such a structure in Sparse TIR (N, M, K are 1-D input buffers holding the variable extents):

```python
i = tir.sp.FixedDenseAxis(b)
j = tir.sp.VariableDenseAxis(N)
k = tir.sp.VariableDenseAxis(M)
l = tir.sp.VariableDenseAxis(K)
fmt = tir.sp.format([[i, j], [i, k], [i, l]], (i, j, k, l))
A = tir.sp.match_buffer(A_handle, fmt, (i, j, l))
B = tir.sp.match_buffer(B_handle, fmt, (i, l, k))
C = tir.sp.match_buffer(C_handle, fmt, (i, j, k))
```

## Sparse Blocks

A sparse block indicates the sparse axes it iterates over, their iteration types, and how they are grouped:

```python
with tir.sp.block([i, j, k], [spatial, spatial, reduce], [[0, 1], [2]]) as [vi, vj, vk]:
    pass
```

---
After discussion with @MasterJH5574 and @tqchen, we decided to update the syntax as follows:

```python
def sddmm(a: ty.handle, b: ty.handle, c: ty.handle, fmt: ty.handle):
    N = tir.var('n')
    M = tir.var('m')
    B = tir.var('b')  # block size
    K = tir.var('k')
    i = tir.match_axis(fmt, 'i', N)
    j = tir.match_axis(fmt, 'j', M)
    k = tir.match_axis(fmt, 'k', K)
    bi = tir.match_axis(fmt, 'bi', B)
    bj = tir.match_axis(fmt, 'bj', B)
    A = tir.match_buffer(a, (i, bi, k), 'float32')
    B = tir.match_buffer(b, (tir.to_dense(j), bj, k), 'float32')
    C = tir.match_buffer(c, (i, j, bi, bj), 'float32')
    for vi, vj in tir.fuse(tir.cord(i), tir.cord(j)):
        for vk in tir.cord(k):
            for vbi in tir.cord(bi):
                for vbj in tir.cord(bj):
                    with tir.block([], 'sddmm'):
                        tir.block_attr({'sparse': True})
                        with tir.init():
                            C[vi, vj, vbi, vbj] = 0.
                        C[vi, vj, vbi, vbj] = C[vi, vj, vbi, vbj] + \
                            A[vi, vbi, vk] * B[vj, vbj, vk]
```

Below we describe the detailed syntax of our new design.

## Format Definition

We write format definitions in Python, outside of TIR scripts:

```python
fmt = tir.format(
    {
        "i": (tir.kDenseFixed, None),
        "j": (tir.kSparseVariable, "i")
    }
)
```

We specify the format via a Python dictionary:
each axis has at most one parent.

## Sparse Tensor Declaration

Sparse tensors are declared in Python as well:

```python
indptr = [...]
indices = [...]
a = tir.sparse.tensor(
    data,
    indptr,
    indices,
)
```

## Sparse Support in TIR Scripts
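As an aside on the declaration above, here is a plain-Python sketch of what `indptr`/`indices` arrays encode when a dense-fixed axis "i" has a sparse-variable child "j" (the CSR pattern). The function name and concrete arrays are illustrative, not part of the proposal.

```python
import numpy as np

# Row i owns the nonzero positions indptr[i]..indptr[i+1]; the "j"
# coordinate of each position is stored in indices, its value in data.
indptr = np.array([0, 2, 2, 5])        # 3 rows, 5 nonzeros (row 1 empty)
indices = np.array([1, 3, 0, 2, 4])    # column coordinates per position
data = np.array([10., 20., 30., 40., 50.])

def iterate(indptr, indices, data):
    """Yield (i, j, value): axis i walked densely, axis j sparsely."""
    for i in range(len(indptr) - 1):
        for pos in range(indptr[i], indptr[i + 1]):
            yield i, indices[pos], data[pos]

print(list(iterate(indptr, indices, data)))
```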
---
@yzh119 Thanks for the update! I'd like to propose another design of formats as follows:

## Format proposal F2

```python
fmt = tir.format({
    "name1": tir.DenseFixedAxis(),
    "name2": tir.DenseVariableAxis(name),  # `name` is the name of the axis it depends on
    "name3": tir.SparseFixedAxis(n_col),
    "name4": tir.SparseVariableAxis(),
})
```

The main difference between this design and the one above is that here we treat the different axis kinds in different ways. This is reasonable because:
---
Could you elaborate more on why sparse-variable axes do not depend on other axes? I don't get the point.

---
@MasterJH5574 A second thought on our proposal: several blocks might share the same loop iterator, but one block might view it as a parallel axis while another views it as a reduction axis.

---
@yzh119 Here are more thoughts, which were sent to the Slack channel before 👀.

---
Yes, you're right. It's possible that different blocks view a loop var differently. But in the design above we use opaque blocks (blocks that have no block iters) for sparse blocks, so it's not very convenient to represent the reduction information in block signatures. One possible design is to add some block iters, where each block iter in a sparse block is required to be bound to exactly one sparse iterator. An example might look like:

```python
for vi, vj in T.fuse(T.cord(i), T.cord(j)):
    for vk in T.cord(k):
        for vbi in T.cord(bi):
            for vbj in T.cord(bj):
                with T.block('sddmm'):
                    T.block_attr({'sparse': True})
                    vi_ = T.sparse_axis(vi)
                    vj_ = T.sparse_axis(vj)
                    vk_ = T.sparse_axis(vk, reduction=True)
                    vbi_ = T.sparse_axis(vbi)
                    vbj_ = T.sparse_axis(vbj)
                    with T.init():
                        C[vi_, vj_, vbi_, vbj_] = 0.
                    C[vi_, vj_, vbi_, vbj_] = C[vi_, vj_, vbi_, vbj_] + A[vi_, vbi_, vk_] * B[vj_, vbj_, vk_]
```

(Just an example; the `sparse_axis` API is not final.)

Perhaps we should wait for the block iter/binding refactor in TVM script before converging on a detailed design.
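For reference, the computation the `sddmm` examples in this thread express can be sketched in NumPy at block size 1: a dot product of dense rows is evaluated only at the stored positions of a CSR sparsity pattern. All names below are illustrative, not part of the proposed API.

```python
import numpy as np

# SDDMM (sampled dense-dense matmul): for every stored (i, j) position
# of the CSR pattern, out[pos] = dot(A[i, :], B[j, :]).
def sddmm(indptr, indices, A, B):
    out = np.empty(indptr[-1])
    for i in range(len(indptr) - 1):
        for pos in range(indptr[i], indptr[i + 1]):
            j = indices[pos]
            out[pos] = A[i] @ B[j]   # reduction over the k axis
    return out

indptr = np.array([0, 1, 3])
indices = np.array([0, 0, 1])
A = np.array([[1., 2.], [3., 4.]])   # shape (n_rows, k)
B = np.array([[5., 6.], [7., 8.]])   # shape (n_cols, k)
print(sddmm(indptr, indices, A, B))
```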
---

Oh yes, I agree.

To be more specific:

I don't know whether my explanation convinces you. I'll post the detailed proposal for sparse iterators and the iterator dependency soon.

---
@MasterJH5574, I'm thinking of the sparse softmax example. If we write it with opaque blocks, the program looks like this (it uses the subtract-the-max trick to avoid overflow):

```python
for vi in tir.cord(i):
    for vj in tir.pos(j, reduction=True):
        with tir.block('A_max'):
            tir.block_attr({'sparse': True})
            with tir.init():
                A_max[vi] = tir.const(-float("inf"), "float32")
            A_max[vi] = tir.max(A_max[vi], A[vi, vj])
    for vj in tir.pos(j):
        with tir.block('A_minus_exp'):
            tir.block_attr({'sparse': True})
            A_exp[vi, vj] = tir.exp(A[vi, vj] - A_max[vi])
    for vj in tir.pos(j, reduction=True):
        with tir.block('Z'):
            tir.block_attr({'sparse': True})
            with tir.init():
                Z[vi] = 0.
            Z[vi] = Z[vi] + A_exp[vi, vj]
    for vj in tir.pos(j):
        with tir.block('out'):
            tir.block_attr({'sparse': True})
            out[vi, vj] = A_exp[vi, vj] / Z[vi]
```

However, we can fuse some of these blocks, and it would look more natural if we moved the reduction attribute to blocks:

```python
for vi in tir.cord(i):
    for vj in tir.pos(j):
        with tir.block(name='A_max', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                A_max[vi] = tir.const(-float("inf"), "float32")
            A_max[vi] = tir.max(A_max[vi], A[vi, vj])
    for vj in tir.pos(j):
        with tir.block(name='A_minus_exp', sparse_iters=[vi, vj]):
            A_exp[vi, vj] = tir.exp(A[vi, vj] - A_max[vi])
    for vj in tir.pos(j):
        with tir.block(name='Z', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                Z[vi] = 0.
            Z[vi] = Z[vi] + A_exp[vi, vj]
    for vj in tir.pos(j):
        with tir.block(name='out', sparse_iters=[vi, vj]):
            out[vi, vj] = A_exp[vi, vj] / Z[vi]
```

Then the `A_minus_exp` and `Z` blocks can be fused:

```python
for vi in tir.cord(i):
    for vj in tir.pos(j):
        with tir.block(name='A_max', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                A_max[vi] = tir.const(-float("inf"), "float32")
            A_max[vi] = tir.max(A_max[vi], A[vi, vj])
    for vj in tir.pos(j):
        with tir.block(name='Z_A_minus_exp', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                Z[vi] = 0.
            Z[vi] = Z[vi] + tir.exp(A[vi, vj] - A_max[vi])
    for vj in tir.pos(j):
        with tir.block(name='out', sparse_iters=[vi, vj]):
            out[vi, vj] = tir.exp(A[vi, vj] - A_max[vi]) / Z[vi]
```
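As a numerical cross-check of this schedule, here is a NumPy sketch of row-wise sparse softmax over CSR nonzeros, mirroring the three fused blocks; the function name and concrete arrays are illustrative assumptions.

```python
import numpy as np

# Row-wise softmax over CSR nonzeros with the subtract-the-max trick.
def sparse_softmax(indptr, data):
    out = np.empty_like(data)
    for i in range(len(indptr) - 1):
        s, e = indptr[i], indptr[i + 1]
        if s == e:
            continue                                  # empty row
        a_max = data[s:e].max()                       # block 'A_max'
        z = np.exp(data[s:e] - a_max).sum()           # block 'Z_A_minus_exp'
        out[s:e] = np.exp(data[s:e] - a_max) / z      # block 'out'
    return out

indptr = np.array([0, 3, 3, 5])
data = np.array([1.0, 2.0, 3.0, -1.0, 0.5])
print(sparse_softmax(indptr, data))
```

Each row's outputs sum to 1, independently of the other rows, which is why the per-row reduction blocks can be fused under the shared `vi`/`vj` loops.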
---

For the sparse-variable axis: I think you assume that we record axis dependency information in sparse tensors. But what if we don't do so?

---
## SparseIterVar Proposal 1

This post proposes the design of SparseIterVars, which behave similarly to loop vars in TIR. Note that the class name can change; possible names are "SparseIterVar", "SparseIterator", and "SparseLoopVar". There are two main types of SparseIterVar: BasicSparseIterVar and FusedSparseIterVar. I'll elaborate on them respectively.

### BasicSparseIterVar

A BasicSparseIterVar represents the basic ways we iterate. Just like SparseAxis, which has four types, there are also four kinds of BasicSparseIterVar.

#### Definition

#### How to define a BasicSparseIterVar in TVM script

As described in @yzh119's post, users can define a BasicSparseIterVar via `tir.cord`/`tir.pos`.

### FusedSparseIterVar

A FusedSparseIterVar consists of an ordered array of BasicSparseIterVars, meaning that the FusedSparseIterVar is generated by fusing all of its BasicSparseIterVars in order. We can use `tir.fuse` to create one.

Note that we never expose FusedSparseIterVars to users in TVM script. As in the example below (the same as @yzh119's example), in the frontend we only and always expose BasicSparseIterVars, not letting users notice the existence of FusedSparseIterVars. In the example, users only see the BasicSparseIterVars `vi` and `vj`:

```python
for vi, vj in tir.fuse(tir.cord(i), tir.cord(j)):
    for vk in tir.cord(k):
        for vbi in tir.cord(bi):
            for vbj in tir.cord(bj):
                with tir.block([], 'sddmm'):
                    tir.block_attr({'sparse': True})
                    with tir.init():
                        C[vi, vj, vbi, vbj] = 0.
                    C[vi, vj, vbi, vbj] += A[vi, vbi, vk] * B[vj, vbj, vk]
```

The lowering rules for SparseIterVars and SparseBuffer accesses will be posted in another proposal in the future.

@yzh119 You can take a look, although it's super long 🤦‍♂️.
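To make the fused iteration concrete, here is a hypothetical sketch of how a fused (i, j) iterator could recover both coordinates while walking the flat nonzero positions of a CSR-like structure: the row via a binary search on `indptr`, the column via a direct load from `indices`. This is an illustration of the idea only, not the proposed lowering.

```python
import numpy as np

indptr = np.array([0, 2, 2, 5])      # per-row position ranges
indices = np.array([1, 3, 0, 2, 4])  # column coordinate per position

def fused_coords(indptr, indices):
    """Walk positions 0..nnz-1 and yield the (vi, vj) coordinate pair."""
    nnz = indptr[-1]
    for pos in range(nnz):
        vi = int(np.searchsorted(indptr, pos, side="right")) - 1
        vj = int(indices[pos])
        yield vi, vj

print(list(fused_coords(indptr, indices)))
```

Note that the fused loop naturally skips empty rows (row 1 here), which is the point of fusing a dense axis with its sparse child.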
---

I'm okay about the ...

---
Yes. A dense-fixed SparseIterVar contains an integer length, and all other kinds of SparseIterVar have an axis pointer which points to a SparseAxis of the corresponding kind. In buffer access, the only constraint is that dense-variable SparseIterVars are only allowed to index a dense-variable axis; there's no constraint for SparseIterVars of the other three kinds.

---
## Proposal: gradual lowering of sparse iteration loops

Had some discussion w/ @MasterJH5574 on the possibility of gradually lowering sparse iteration loops.

### Example

Below is an example of my proposal. Let's assume we are trying to lower a program that reduces the last dimension of a sparse tensor.

Original code:

```python
I = tir.match_axis(fmt, "I")
J = tir.match_axis(fmt, "J")
K = tir.match_axis(fmt, "K")
A = tir.match_buffer(a, (I, J, K), "float32", "int32")
B = tir.match_buffer(b, (I, J), "float32", "int32")
for i in tir.cord(I):
    for j in tir.pos(J):
        for k in tir.pos(K):
            with tir.block(name='reduction', sparse_iters=[i, j, tir.reduction(k)]):
                with tir.init():
                    B[i, j] = 0.
                B[i, j] = B[i, j] + A[i, j, k]
```

where `fmt` is created via:

```python
fmt = tir.sparse.format({
    "i": tir.DenseFixedAxis(),
    "j": tir.SparseVariableAxis("i"),
    "k": tir.SparseVariableAxis("j")
})
```

Then we lower the program itervar by itervar. The first step is to lower `i`:

```python
for i in tir.grid(I.length):
    with tir.block([I.length]) as vi:
        tir.bind(vi, i)
        J_i = tir.match_axis(J[vi])
        K_i = tir.match_axis(K[vi])
        B_i = tir.match_buffer(B[vi], (J_i,))
        A_i = tir.match_buffer(A[vi], (J_i, K_i))
        for j in tir.pos(J_i):
            for k in tir.pos(K_i):
                with tir.block(name='reduction', sparse_iters=[j, tir.reduction(k)]):
                    with tir.init():
                        B_i[j] = 0.
                    B_i[j] = B_i[j] + A_i[j, k]
```

Then we lower `j`:

```python
for i in tir.grid(I.length):
    with tir.block([I.length]) as vi:
        tir.bind(vi, i)
        J_i = tir.match_axis(J[vi])
        K_i = tir.match_axis(K[vi])
        B_i = tir.match_buffer(B[vi], (J_i,))
        A_i = tir.match_buffer(A[vi], (J_i, K_i))
        for j in tir.grid(J_i.length):
            with tir.block([J_i.length]) as vj:
                tir.bind(vj, j)
                K_i_j = tir.match_axis(K_i[vj])
                A_i_j = tir.match_buffer(A_i[vj], (K_i_j,))
                B_i_j = tir.match_buffer(B_i[vj], (1,))
                for k in tir.pos(K_i_j):
                    with tir.block(name='reduction', sparse_iters=[tir.reduction(k)]):
                        with tir.init():
                            B_i_j[0] = 0.
                        B_i_j[0] = B_i_j[0] + A_i_j[k]
```

Then we lower `k`:

```python
for i in tir.grid(I.length):
    with tir.block([I.length]) as vi:
        tir.bind(vi, i)
        J_i = tir.match_axis(J[vi])
        K_i = tir.match_axis(K[vi])
        B_i = tir.match_buffer(B[vi], (J_i,))
        A_i = tir.match_buffer(A[vi], (J_i, K_i))
        for j in tir.grid(J_i.length):
            with tir.block([J_i.length]) as vj:
                tir.bind(vj, j)
                K_i_j = tir.match_axis(K_i[vj])
                A_i_j = tir.match_buffer(A_i[vj], (K_i_j,))
                B_i_j = tir.match_buffer(B_i[vj], (1,))
                for k in tir.grid(K_i_j.length):
                    with tir.block([tir.reduce_axis((0, K_i_j.length))]) as vk:
                        tir.bind(vk, k)
                        A_i_j_k = tir.match_buffer(A_i_j[vk], (1,))
                        with tir.init():
                            B_i_j[0] = 0.
                        B_i_j[0] = B_i_j[0] + A_i_j_k[0]
```

What did each step do? It basically assigns pointers: `match_axis`/`match_buffer` bind the new names (`J_i`, `B_i`, ...) to the slices of the parent axes and buffers owned by the current block var.
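To illustrate the end state of this gradual lowering: once every sparse loop has been lowered, the program is just dense loops over slices delimited by indptr-style arrays. The following NumPy sketch mirrors the final loop nest; the array names (`J_indptr`, `K_indptr`, `A_data`) are assumptions for illustration.

```python
import numpy as np

# Fully lowered reduction over axis k. J_indptr partitions positions of
# axis J per i; K_indptr partitions positions of axis K per (i, j)
# position; A_data holds the nonzero values in position order.
def lowered_reduce(I_len, J_indptr, K_indptr, A_data):
    B = np.zeros(J_indptr[-1])                          # one slot per (i, j) position
    for i in range(I_len):                              # for i in tir.grid(I.length)
        for jp in range(J_indptr[i], J_indptr[i + 1]):  # positions of j under i
            acc = 0.0                                   # tir.init()
            for kp in range(K_indptr[jp], K_indptr[jp + 1]):
                acc += A_data[kp]
            B[jp] = acc
    return B

J_indptr = np.array([0, 2, 3])        # i=0 owns 2 j-positions, i=1 owns 1
K_indptr = np.array([0, 1, 3, 6])     # per-j counts of k: 1, 2, 3
A_data = np.array([1., 2., 3., 4., 5., 6.])
print(lowered_reduce(2, J_indptr, K_indptr, A_data))
```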
---

@yzh119 Another question occurs to me about the gradual lowering: is it possible to lower a SparseIterVar that was fused into a FusedSparseIterVar before? For example, in the code below,

```python
for vi, vj in tir.fuse(tir.cord(i), tir.cord(j)):
    blabla
```

is it possible to only lower `vi`?

---
No, IMO we can only lower the fused iterator as a whole.

---
## Discussion & Proposals: Output Buffer's ...

---
## Proposal: Buffer Access Lowering

### Recall of Axis/Iterator Design

According to our design:
### SparseBufferLoad & SparseBufferStore

Like BufferLoad and BufferStore, we now introduce SparseBufferLoad and SparseBufferStore to represent read/write accesses to SparseBuffers.
### Buffer Access Lowering

This section is about the method by which we convert SparseBufferLoad/SparseBufferStore to BufferLoad/BufferStore when lowering a whole SparseTIR program to normal TIR. In this section, a "sparse buffer access" means using an array of PrimExprs as indices to access a SparseBuffer. One of the tasks of SparseTIR lowering is to convert sparse buffer accesses to normal buffer accesses in TIR. Without loss of generality, we suppose a SparseBuffer ...

### Sparse Indices

To lower sparse buffer accesses, we should be able to convert the original indices ...
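One way such an index conversion can work for sparse axes is a binary search (a `lower_bound`) over the sorted coordinates of the enclosing slice, turning a coordinate index into a flat storage position. The sketch below is illustrative only; the function name and arrays are assumptions.

```python
import numpy as np

# Convert a coordinate index j into the flat position of (i, j) in the
# data array of a CSR-like buffer, or -1 if (i, j) is not stored.
def coord_to_pos(indptr, indices, i, j):
    s, e = int(indptr[i]), int(indptr[i + 1])
    pos = s + int(np.searchsorted(indices[s:e], j))  # lower_bound in row i
    if pos < e and indices[pos] == j:
        return pos
    return -1

indptr = np.array([0, 2, 5])
indices = np.array([1, 3, 0, 2, 4])
print(coord_to_pos(indptr, indices, 0, 3))
print(coord_to_pos(indptr, indices, 1, 1))
```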
---
- Design: #368
- `lower_bound` (@yzh119, WIP)
- ∞: tensorization