Commit

sync
jinhongyii committed Jul 29, 2021
1 parent 002348c commit a08a791
Showing 26 changed files with 448 additions and 1,262 deletions.
86 changes: 43 additions & 43 deletions python/tvm/tir/schedule/schedule.py
@@ -354,62 +354,62 @@ def sample_compute_location(
     ########## Schedule: loops ##########

     def fuse(self, *loops: List[LoopRV]) -> LoopRV:
-        """Fuse a list of consecutive loops into one. It requires:
-        1) The loops can't have annotations or thread bindings.
-        2) The (i+1)-th loop must be the only child of the i-th loop.
-        3) All loops must start with 0.
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.

-        Parameters
-        ----------
-        *loops : List[LoopRV]
-            The loops to be fused
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused

-        Returns
-        ----------
-        fused_loop : LoopRV
-            The new loop after fusion
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion

-        Examples
-        --------
+        Examples
+        --------

-        Before applying fuse, in TensorIR, the IR is:
+        Before applying fuse, in TensorIR, the IR is:

-        .. code-block:: python
+        .. code-block:: python

-            @tvm.script.tir
-            def before_fuse(a: ty.handle, b: ty.handle) -> None:
-                A = tir.match_buffer(a, (128, 128))
-                B = tir.match_buffer(b, (128, 128))
-                for i, j in tir.grid(128, 128):
-                    with tir.block([128, 128], "B") as [vi, vj]:
-                        B[vi, vj] = A[vi, vj] * 2.0
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        B[vi, vj] = A[vi, vj] * 2.0

-        Create the schedule and do fuse:
+        Create the schedule and do fuse:

-        .. code-block:: python
+        .. code-block:: python

-            sch = tir.Schedule(before_fuse)
-            i, j = sch.get_loops(sch.get_block("B"))
-            sch.fuse(i, j)
-            print(tvm.script.asscript(sch.mod["main"]))
+            sch = tir.Schedule(before_fuse)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))

-        After applying fuse, the IR becomes:
+        After applying fuse, the IR becomes:

-        .. code-block:: python
+        .. code-block:: python

-            @tvm.script.tir
-            def after_fuse(a: ty.handle, b: ty.handle) -> None:
-                A = tir.match_buffer(a, (128, 128))
-                B = tir.match_buffer(b, (128, 128))
-                # the 2 loops are fused into 1
-                for i_j_fused in tir.serial(0, 16384):
-                    with tir.block([128, 128], "B") as [vi, vj]:
-                        tir.bind(vi, tir.floordiv(i_j_fused, 128))
-                        tir.bind(vj, tir.floormod(i_j_fused, 128))
-                        B[vi, vj] = A[vi, vj] * 2.0
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                # the 2 loops are fused into 1
+                for i_j_fused in tir.serial(0, 16384):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, tir.floordiv(i_j_fused, 128))
+                        tir.bind(vj, tir.floormod(i_j_fused, 128))
+                        B[vi, vj] = A[vi, vj] * 2.0

-        """
-        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # pylint: disable=no-member
+        """
+        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # pylint: disable=no-member

     def split(
         self,
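
The `fuse` docstring in this hunk binds `vi` and `vj` to `floordiv(i_j_fused, 128)` and `floormod(i_j_fused, 128)`. A minimal plain-Python sketch of that index arithmetic follows. It does not use TVM, and the helper names (`unfuse_index`, `double_nested`, `double_fused`) are hypothetical, chosen only for illustration. It checks that a single fused loop visits the same elements as the original two-level loop nest, assuming both loops start at 0 as the docstring requires.

```python
from typing import List, Tuple


def unfuse_index(fused: int, inner_extent: int) -> Tuple[int, int]:
    """Recover (outer, inner) indices from a fused loop index.

    Mirrors the floordiv/floormod bindings in the docstring example;
    assumes both original loops start at 0.
    """
    return fused // inner_extent, fused % inner_extent


def double_nested(a: List[List[float]]) -> List[List[float]]:
    """Reference computation: B[i][j] = A[i][j] * 2.0 with two nested loops."""
    rows, cols = len(a), len(a[0])
    b = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            b[i][j] = a[i][j] * 2.0
    return b


def double_fused(a: List[List[float]]) -> List[List[float]]:
    """Same computation with the two loops fused into one."""
    rows, cols = len(a), len(a[0])
    b = [[0.0] * cols for _ in range(rows)]
    for i_j_fused in range(rows * cols):
        vi, vj = unfuse_index(i_j_fused, cols)
        b[vi][vj] = a[vi][vj] * 2.0
    return b


if __name__ == "__main__":
    sample = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    assert double_fused(sample) == double_nested(sample)
```

With an inner extent of 128, the fused range is 128 * 128 = 16384 iterations, which matches the `tir.serial(0, 16384)` loop in the `after_fuse` example.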
2 changes: 1 addition & 1 deletion src/meta_schedule/space/postproc.cc
@@ -1075,7 +1075,7 @@ class PostProcRewriteLayout {
       }
       // Step 1: create a new buffer
       tir::Buffer new_buffer(buffer->data, buffer->dtype, new_shape, Array<PrimExpr>(),
-                             buffer->elem_offset, buffer->name, buffer->scope,
+                             buffer->elem_offset, buffer->name,
                              buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
       // Step 2: do the rewrite to the buffer access
       // the rule is as below: