Auto memory planning #434

Open
jinhongyii opened this issue Aug 10, 2021 · 0 comments

Background

We want to introduce auto data movement, that is, automatically generating a schedule that copies a region of data from one buffer to another. This involves automatically deciding the layouts of the intermediate buffers and automatically binding threads.

We first focus on the layouts of the intermediate buffers and ignore thread bindings entirely.

Process

A very simple approach is to keep a list of candidate layouts and sample one from it. The list could include:

f(i, j) = (i, j)
f(i, j) = (j, i)
f(i, j) = pad m elements after every n rows
f(i, j) = (i, i ^ j)
...
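A minimal Python sketch of such a candidate list, assuming each layout is represented as a plain function from a logical coordinate (i, j) to a flat offset in a 16 x 16 shared-memory tile. The names identity, transpose, make_padded, xor_swizzle and sample_layout are hypothetical and not part of TVM; a real implementation would presumably express layouts as TIR index expressions instead.

```python
import random

# Logical tile shape; each candidate layout maps (i, j) to a flat offset.
ROWS, COLS = 16, 16


def identity(i, j):
    # f(i, j) = (i, j), stored row-major
    return i * COLS + j


def transpose(i, j):
    # f(i, j) = (j, i), stored row-major, i.e. column-major w.r.t. (i, j)
    return j * ROWS + i


def make_padded(m, n):
    # f(i, j) = pad m elements after every n rows
    def padded(i, j):
        return i * COLS + j + (i // n) * m
    return padded


def xor_swizzle(i, j):
    # f(i, j) = (i, i ^ j); with COLS a power of two, j -> (i ^ j) % COLS
    # is a permutation of the columns within each row
    return i * COLS + (i ^ j) % COLS


CANDIDATES = [identity, transpose, make_padded(4, 2), xor_swizzle]


def sample_layout(rng):
    # Uniformly sample one candidate layout for an intermediate buffer.
    return rng.choice(CANDIDATES)


f = sample_layout(random.Random(0))
print(f.__name__, f(3, 5))
```

Every candidate must be injective over the tile (no two logical elements share a slot); padding layouts simply leave some physical slots unused.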

Suppose the warp memory has the identity layout. Then the process is as follows:

The original IR is:

with tir.block([1024, 1024, tir.reduce_axis(1024)]) as [vi, vj, vk]:
    C[vi, vj] += A[vi, vk] * B[vk, vj]

After early tensorize:

for i, j, k in tir.grid(64, 64, 64):
    with tir.block([64, 64, tir.reduce_axis(64)]) as [vi, vj, vk]:
        for ii, jj, kk in tir.grid(16, 16, 16):
            with tir.block([16, 16, tir.reduce_axis(16)]) as [vii, vjj, vkk]:
                tir.bind(...)
                C[vii, vjj] += A[vii, vkk] * B[vkk, vjj]
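To make the loop split concrete, here is a scaled-down pure-Python check (N = 8 with 4 x 4 x 4 tiles instead of 1024 with 16 x 16 x 16 tiles) that the tiled nest computes the same result as the original reduction. It assumes the elided tir.bind(...) calls bind each global index as outer * T + inner; the function names are hypothetical.

```python
# Scaled-down sizes: the issue uses N = 1024 split into 64 x 16.
N, T = 8, 4


def matmul_reference(A, B):
    # The original block: C[vi, vj] += A[vi, vk] * B[vk, vj]
    C = [[0] * N for _ in range(N)]
    for vi in range(N):
        for vj in range(N):
            for vk in range(N):
                C[vi][vj] += A[vi][vk] * B[vk][vj]
    return C


def matmul_tiled(A, B):
    # Outer (N/T)^3 block over inner T x T x T "tensorized" blocks.
    C = [[0] * N for _ in range(N)]
    for i in range(N // T):
        for j in range(N // T):
            for k in range(N // T):
                for ii in range(T):
                    for jj in range(T):
                        for kk in range(T):
                            # assumed binding: global = outer * T + inner
                            vi, vj, vk = i * T + ii, j * T + jj, k * T + kk
                            C[vi][vj] += A[vi][vk] * B[vk][vj]
    return C


A = [[i * N + j for j in range(N)] for i in range(N)]
B = [[(i + 2 * j) % 5 for j in range(N)] for i in range(N)]
assert matmul_tiled(A, B) == matmul_reference(A, B)
```

Integer inputs keep the comparison exact even though the tiled nest accumulates over k in a different order.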

After multi-level tiling (no cache_read/cache_write):

for i0, j0, ... in tir.grid(...):  # SSSRRSRS
    with tir.block([64, 64, tir.reduce_axis(64)]) as [vi, vj, vk]:
        tir.bind(...)
        for ii, jj, kk in tir.grid(16, 16, 16):
            with tir.block([16, 16, tir.reduce_axis(16)]) as [vii, vjj, vkk]:
                tir.bind(...)
                C[vii, vjj] += A[vii, vkk] * B[vkk, vjj]

Rewrite the computation part: obtain the warp load/store tensor intrinsics and generate intermediate buffers whose layouts are constrained by the warp-load tensor intrinsic:

for i0, j0, i1, j1, i2, j2 in tir.grid(...):  # SSS
    for k0:  # R
        A -> A_shared (sampled layout f)
        B -> B_shared (sampled layout g)
        for k1, i3, j3, k2:  # RSR
            A_shared[i, i ^ j] -> A_warp
            B_shared -> B_warp
            for i4, j4:  # S
                wmma
    wmma -> C
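The staged movement A -> A_shared -> A_warp can be checked in isolation. The sketch below assumes the warp fragment uses the identity layout and that the sampled f is the "pad 4 elements after every 2 rows" layout; it writes a 16 x 16 tile into a flat shared buffer through f and reads it back through the same f, leaving padding slots as None. All names are hypothetical.

```python
ROWS, COLS = 16, 16


def padded_layout(i, j, m=4, n=2):
    # Hypothetical sampled layout f: pad m elements after every n rows.
    return i * COLS + j + (i // n) * m


def copy_to_shared(tile, f):
    # A -> A_shared: allocate enough room for the largest offset f produces.
    size = max(f(i, j) for i in range(ROWS) for j in range(COLS)) + 1
    shared = [None] * size  # slots never written by f are padding
    for i in range(ROWS):
        for j in range(COLS):
            shared[f(i, j)] = tile[i][j]
    return shared


def load_to_warp(shared, f):
    # A_shared -> A_warp: the warp fragment is assumed to use the identity layout.
    return [[shared[f(i, j)] for j in range(COLS)] for i in range(ROWS)]


tile = [[i * COLS + j for j in range(COLS)] for i in range(ROWS)]
shared = copy_to_shared(tile, padded_layout)
assert load_to_warp(shared, padded_layout) == tile
```

The round trip is correct for any injective f; the layout only changes where the data sits in shared memory, which is exactly what the warp-load intrinsic has to cope with.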

After tensorizing the warp load/store:

for i0, j0, i1, j1, i2, j2 in tir.grid(...):  # SSS
    for k0:  # R
        A -> A_shared (sampled layout f)
        B -> B_shared (sampled layout g)
        for k1, i3, j3, k2:  # RSR
            wmma_load_sync
            for i4, j4:  # S
                wmma
    wmma_store_sync

Note that this algorithm has some potential problems:

Problem 1. How can we ensure that a layout fits the pre-registered tensor intrinsic (warp load)?

For example, suppose a layout pads 4 elements after every two rows, and the shared memory buffer for A has shape 128 x 64:

f(i, j) = i * 64 + j + floor(i / 2) * 4

If the pre-registered tensor intrinsic is wmma_load_sync, this layout cannot be applied, because the underlying function takes a single, constant stride as an argument:

// here `ldm` is the leading dimension: one stride shared by all rows
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
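A short Python check of why this layout defeats a single ldm: for a row-major fragment, load_matrix_sync reads row r starting at mptr + r * ldm, so consecutive rows must be a constant distance apart. Computing the row starts of f shows that the distance alternates between 64 and 68 elements. The helper name constant_row_stride is hypothetical.

```python
def f(i, j):
    # Padded layout from Problem 1: A_shared is 128 x 64, with 4 elements of
    # padding after every two rows (i // 2 is integer division).
    return i * 64 + j + (i // 2) * 4


def constant_row_stride(layout, rows):
    # Return the single stride between consecutive row starts, or None if
    # the rows are not evenly spaced (so no single `ldm` can describe them).
    starts = [layout(i, 0) for i in range(rows)]
    strides = {b - a for a, b in zip(starts, starts[1:])}
    return strides.pop() if len(strides) == 1 else None


print([f(i, 0) for i in range(4)])                       # [0, 64, 132, 196]
print(constant_row_stride(f, 16))                        # None
print(constant_row_stride(lambda i, j: i * 64 + j, 16))  # 64
```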

However, if the pre-registered tensor intrinsic is the PTX ldmatrix instruction (the warp-level load that feeds mma):

ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];

.shape  = {.m8n8};
.num    = {.x1, .x2, .x4};
.ss     = {.shared};
.type   = {.b16};

Here each thread passes in a pointer to the start of one row, so there is no constant-stride constraint, and the layout can be applied in this situation.
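The difference can be simulated in a few lines of Python. The sketch below fills a 128 x 64 shared buffer through the padded layout f, storing each element's logical coordinate as its value, then loads an 8 x 8 tile two ways: with one start address per row (ldmatrix style, ignoring its thread-to-register mapping and alignment rules) and with a single constant stride of 64 (load_matrix_sync style). All function names are hypothetical.

```python
ROWS, COLS = 128, 64


def f(i, j):
    # Padded layout from Problem 1.
    return i * COLS + j + (i // 2) * 4


def build_shared(layout):
    # Store each element's logical coordinate at its physical offset;
    # padding slots stay None.
    shared = [None] * (layout(ROWS - 1, COLS - 1) + 1)
    for i in range(ROWS):
        for j in range(COLS):
            shared[layout(i, j)] = (i, j)
    return shared


def load_with_row_pointers(shared, row_addrs, width=8):
    # ldmatrix style: every row's start address is supplied separately.
    return [[shared[addr + c] for c in range(width)] for addr in row_addrs]


def load_with_stride(shared, base, ldm, height=8, width=8):
    # load_matrix_sync style: row r starts at base + r * ldm.
    return [[shared[base + r * ldm + c] for c in range(width)] for r in range(height)]


shared = build_shared(f)
expected = [[(i, j) for j in range(8)] for i in range(8)]
by_rows = load_with_row_pointers(shared, [f(i, 0) for i in range(8)])
by_stride = load_with_stride(shared, f(0, 0), ldm=64)
print(by_rows == expected)    # True
print(by_stride == expected)  # False: row 2 starts in the padding
```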

Therefore, for every tensor intrinsic we need either a description of the constraints it imposes on layouts, or an exact specification of which layouts it accepts.
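One possible shape for such a description, sketched in Python under the assumption that layouts are functions from (i, j) to a flat offset: each intrinsic carries a predicate over layouts, so a sampled layout can be rejected before tensorization is attempted. IntrinConstraint, rows_contiguous and evenly_strided are hypothetical names, and the predicates only approximate the hardware rules (alignment and the ldmatrix .trans variant are ignored).

```python
from dataclasses import dataclass
from typing import Callable

Layout = Callable[[int, int], int]


def rows_contiguous(f: Layout, rows: int, cols: int) -> bool:
    # Each row of the tile occupies consecutive offsets.
    return all(f(i, j + 1) - f(i, j) == 1 for i in range(rows) for j in range(cols - 1))


def evenly_strided(f: Layout, rows: int, cols: int) -> bool:
    # Rows are contiguous AND all row starts are a constant distance apart.
    starts = [f(i, 0) for i in range(rows)]
    strides = {b - a for a, b in zip(starts, starts[1:])}
    return rows_contiguous(f, rows, cols) and len(strides) <= 1


@dataclass(frozen=True)
class IntrinConstraint:
    name: str
    rows: int
    cols: int
    accepts: Callable[[Layout, int, int], bool]

    def check(self, f: Layout) -> bool:
        return self.accepts(f, self.rows, self.cols)


# Hypothetical descriptors for the two loads discussed above.
WMMA_LOAD = IntrinConstraint("wmma_load_sync", 16, 16, evenly_strided)
LDMATRIX = IntrinConstraint("ldmatrix.m8n8", 8, 8, rows_contiguous)

identity = lambda i, j: i * 64 + j
padded = lambda i, j: i * 64 + j + (i // 2) * 4
transposed = lambda i, j: j * 128 + i

for intrin in (WMMA_LOAD, LDMATRIX):
    print(intrin.name, [intrin.check(f) for f in (identity, padded, transposed)])
```

A predicate keeps the check per intrinsic rather than per (layout, intrinsic) pair, but it only answers "is this layout acceptable"; it does not by itself produce the tensorize description discussed next.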

Problem 2. How do we describe a tensor intrinsic once layouts are introduced?

Previously we dealt with simple data movement: copying A_shared[i, j] to A_warp[i, j], that is, a 16 x 16 tile to a 16 x 16 tile.

@tvm.script.tir
def wmma_load_a_desc(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16), "float16", align=128, offset_factor=256,
                         scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256,
                         scope="wmma.matrix_a")

    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        for i, j in tir.grid(16, 16):
            with tir.block([16, 16], "load") as [vii, vjj]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                C[vii, vjj] = A[vii, vjj]
                
@tvm.script.tir
def wmma_load_a_impl(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16), "float16", align=128, offset_factor=256, scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256, scope="wmma.matrix_a")

    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.reads(A[0: 16, 0: 16])
        tir.writes(C[0: 16, 0: 16])
        tir.evaluate(tir.tvm_load_matrix_sync(
            C.data, 16, 16, 16, C.elem_offset // 256, A.access_ptr("r"), 16, "row_major",
            dtype="handle"))

However, if we use several different layouts, there will be different descriptions of the form A_shared[f(i, j)] -> A_warp[i, j], and each different f requires its own implementation for the tensorize rewrite. Take the layout discussed in Problem 1 as an example.

The description and implementation would be expected to look like the following, but this cannot actually be expressed:

@tvm.script.tir
def mma_load_a_desc(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (256), "float16", align=128, offset_factor=256,
                         scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256,
                         scope="mma.matrix_a")

    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        for i, j in tir.grid(16, 16):
            with tir.block([16, 16], "load") as [vii, vjj]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                # The layout can't be expressed in a 2-D way.
                # This example is wrong because `stride` can't be inferred.
                C[vii, vjj] = A[vii * stride + vjj + vii // 2 * 4]
                
@tvm.script.tir
def mma_load_a_impl(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (256), "float16", align=128, offset_factor=256, scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256, scope="mma.matrix_a")

    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.reads(A[0: 16, 0: 16])
        tir.writes(C[0: 16, 0: 16])
        tir.evaluate(tir.ldmatrix(C.data, 16, 16, 16, C.elem_offset // 256, A.access_ptr("r"), 2, 4, 64))
        # This is merely an illustration; ldmatrix is currently not supported.

Another practical problem: with m layouts and n tensor intrinsics, we would need m * n description/implementation pairs.

I will post some intermediate TIR later to clarify the transformations.
