Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Jul 5, 2021
1 parent 8d67a62 commit 6fa423e
Show file tree
Hide file tree
Showing 23 changed files with 223 additions and 1,324 deletions.
1 change: 1 addition & 0 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/data_type.h>

#include <string>
#include <unordered_map>

namespace tvm {

Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,9 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
// The original shape of the weight matrix,
// used to recover the compute after transforming the input's layout .
Array<PrimExpr> meta_schedule_original_shape;

TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/runtime/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <utility>

#include "./base.h"
#include "optional.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -1344,6 +1345,14 @@ class Map : public ObjectRef {
iterator end() const { return iterator(GetMapNode()->end()); }
/*! \return find the key and returns the associated iterator */
iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }
/*! \return The value associated with the key, NullOpt if not found */
Optional<V> Get(const K& key) const {
MapNode::iterator iter = GetMapNode()->find(key);
if (iter == GetMapNode()->end()) {
return NullOptType{};
}
return DowncastNoCheck<V>(iter->second);
}

void erase(const K& key) { CopyOnWrite()->erase(key); }

Expand Down
29 changes: 0 additions & 29 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,35 +198,6 @@ class ScheduleNode : public runtime::Object {
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
* 2) The block must not be the only leaf in the scope.
* 3) The body of the block must be a BufferStore statement in the form of,
* A[i, j, k, ...] = ...
* where the indices of the LHS are all distinct atomic variables,
* and no variables other than those indexing variables are allowed in the statement.
* \param block The block to be inlined to its consumer(s)
*/
virtual void ComputeInline(const BlockRV& block) = 0;
/*!
* \brief Inline a block into its only producer. It requires:
* 1) The block is a complete non-root block, which only produces and consumers one buffer
* 2) The block must not be the only leaf in the scope.
* 3) The only producer of the block is a read-after-write producer and a complete non-root block
* 4) The body of the block must be a BufferStore statement in the form of,
* B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
* where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
* and no variables other than those indexing variables are allowed in the statement.
* \param block The block to be inlined to its producer
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/******** Schedule: blockize & tensorize ********/
/*!
* \brief Get the leaf blocks of a specific scope
* \param block_rv The block where the scope is rooted
Expand Down
1 change: 0 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
from .schedule import create_schedule, validate_hierarchy

from . import schedule
from . import ir_builder
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
from .state import ScheduleDebugMask, ScheduleState
from .inst import Inst, InstKind
from .trace import Trace
from .schedule import RAND_VAR_TYPE, BlockRV, ExprRV, LoopRV, Schedule
from .schedule import RAND_VAR_TYPE, BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
231 changes: 132 additions & 99 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
@property
def mod(self) -> IRModule:
"""Returns the AST of the module being scheduled"""
return _ffi_api_schedule.ScheduleModule(self) # type: ignore # pylint: disable=no-member
return _ffi_api_schedule.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member

@property
def state(self) -> ScheduleState:
Expand All @@ -165,7 +165,7 @@ def copy(self, seed: int = -1) -> "Schedule":
copy : Schedule
A new copy of the schedule
"""
return _ffi_api_schedule.ScheduleCopy(self) # type: ignore # pylint: disable=no-member
return _ffi_api_schedule.ScheduleCopy(self, seed) # type: ignore # pylint: disable=no-member

def seed(self, seed: int) -> None:
"""Seed the randomness
Expand Down Expand Up @@ -254,7 +254,6 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
return _ffi_api_schedule.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member

########## Block/Loop relation ##########
########## Block/Loop relation ##########

def get_block(
self,
Expand Down Expand Up @@ -398,100 +397,6 @@ def reverse_compute_at(
self, block, loop, preserve_unit_loop
)

def compute_inline(self, block: BlockRV) -> None:
_ffi_api_schedule.ScheduleComputeInline(self, block) # pylint: disable=no-member

def reverse_compute_inline(self, block: BlockRV) -> None:
_ffi_api_schedule.ScheduleReverseComputeInline(self, block) # pylint: disable=no-member

########## Schedule: parallelize / annotate ##########

def vectorize(self, loop: LoopRV) -> None:
_ffi_api_schedule.ScheduleVectorize(self, loop) # pylint: disable=no-member

def parallel(self, loop: LoopRV) -> None:
_ffi_api_schedule.ScheduleParallel(self, loop) # pylint: disable=no-member

def unroll(self, loop: LoopRV) -> None:
_ffi_api_schedule.ScheduleUnroll(self, loop) # pylint: disable=no-member

def bind(self, loop: LoopRV, thread: Union[str, IterVar]) -> None:
if isinstance(thread, str):
thread = String(thread)
_ffi_api_schedule.ScheduleBind(self, loop, thread) # pylint: disable=no-member

def double_buffer(self, block: BlockRV) -> None:
_ffi_api_schedule.ScheduleDoubleBuffer(self, block) # pylint: disable=no-member

def set_scope(self, block: BlockRV, i: int, storage_scope: str) -> None:
_ffi_api_schedule.ScheduleSetScope( # pylint: disable=no-member
self, block, i, storage_scope
)

def pragma(self, loop: LoopRV, pragma_type: str, pragma_value: ExprRV) -> None:
if isinstance(pragma_value, bool):
pragma_value = IntImm("bool", pragma_value)
_ffi_api_schedule.SchedulePragma( # pylint: disable=no-member
self, loop, pragma_type, pragma_value
)

def storage_align(
self,
block: BlockRV,
buffer_index: int,
axis: int,
factor: int,
offset: int,
) -> None:
_ffi_api_schedule.ScheduleStorageAlign( # pylint: disable=no-member
self, block, buffer_index, axis, factor, offset
)

########## Schedule: cache read/write ##########

def cache_read(self, block: BlockRV, i: int, storage_scope: str) -> BlockRV:
return _ffi_api_schedule.ScheduleCacheRead( # pylint: disable=no-member
self, block, i, storage_scope
)

def cache_write(self, block: BlockRV, i: int, storage_scope: str) -> BlockRV:
return _ffi_api_schedule.ScheduleCacheWrite( # pylint: disable=no-member
self, block, i, storage_scope
)

########## Schedule: reduction ##########

def rfactor(self, loop: LoopRV, factor: int) -> LoopRV:
return _ffi_api_schedule.ScheduleRFactor(self, loop, factor) # pylint: disable=no-member

def decompose_reduction(self, block: BlockRV, loop: Optional[LoopRV]) -> BlockRV:
return _ffi_api_schedule.ScheduleDecomposeReduction( # pylint: disable=no-member
self, block, loop
)

def merge_reduction(self, init: BlockRV, update: BlockRV) -> None:
_ffi_api_schedule.ScheduleMergeReduction(self, init, update) # pylint: disable=no-member

########## Schedule: blockize / tensorize ##########

def blockize(self, loop: LoopRV) -> BlockRV:
return _ffi_api_schedule.ScheduleBlockize(self, loop) # pylint: disable=no-member

def get_loops(self, block: BlockRV) -> List[LoopRV]:
"""Get the parent loops of the block in its scope, from outer to inner
Parameters
----------
block : BlockRV
The query block
Returns
----------
loops : List[LoopRV]
A list of loops above the given block in its scope, from outer to inner
"""
return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: loops manipulation ##########
########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
Expand Down Expand Up @@ -609,10 +514,138 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
"""
_ffi_api_schedule.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: loop binding/annotation ##########
########## Schedule: parallelize / annotate ##########

def vectorize(self, loop: LoopRV) -> None:
_ffi_api_schedule.ScheduleVectorize(self, loop) # pylint: disable=no-member

def parallel(self, loop: LoopRV) -> None:
_ffi_api_schedule.ScheduleParallel(self, loop) # pylint: disable=no-member

def unroll(self, loop: LoopRV) -> None:
_ffi_api_schedule.ScheduleUnroll(self, loop) # pylint: disable=no-member

def bind(self, loop: LoopRV, thread: Union[str, IterVar]) -> None:
if isinstance(thread, str):
thread = String(thread)
_ffi_api_schedule.ScheduleBind(self, loop, thread) # pylint: disable=no-member

def double_buffer(self, block: BlockRV) -> None:
_ffi_api_schedule.ScheduleDoubleBuffer(self, block) # pylint: disable=no-member

def set_scope(self, block: BlockRV, i: int, storage_scope: str) -> None:
_ffi_api_schedule.ScheduleSetScope( # pylint: disable=no-member
self, block, i, storage_scope
)

def pragma(self, loop: LoopRV, pragma_type: str, pragma_value: ExprRV) -> None:
if isinstance(pragma_value, bool):
pragma_value = IntImm("bool", pragma_value)
_ffi_api_schedule.SchedulePragma( # pylint: disable=no-member
self, loop, pragma_type, pragma_value
)

def storage_align(
self,
block: BlockRV,
buffer_index: int,
axis: int,
factor: int,
offset: int,
) -> None:
_ffi_api_schedule.ScheduleStorageAlign( # pylint: disable=no-member
self, block, buffer_index, axis, factor, offset
)

########## Schedule: cache read/write ##########

def cache_read(self, block: BlockRV, i: int, storage_scope: str) -> BlockRV:
return _ffi_api_schedule.ScheduleCacheRead( # pylint: disable=no-member
self, block, i, storage_scope
)

def cache_write(self, block: BlockRV, i: int, storage_scope: str) -> BlockRV:
return _ffi_api_schedule.ScheduleCacheWrite( # pylint: disable=no-member
self, block, i, storage_scope
)

########## Schedule: reduction ##########
########## Schedule: blockize & tensorize ##########

def rfactor(self, loop: LoopRV, factor: int) -> LoopRV:
return _ffi_api_schedule.ScheduleRFactor(self, loop, factor) # pylint: disable=no-member

def decompose_reduction(self, block: BlockRV, loop: Optional[LoopRV]) -> BlockRV:
return _ffi_api_schedule.ScheduleDecomposeReduction( # pylint: disable=no-member
self, block, loop
)

def merge_reduction(self, init: BlockRV, update: BlockRV) -> None:
_ffi_api_schedule.ScheduleMergeReduction(self, init, update) # pylint: disable=no-member

########## Schedule: blockize / tensorize ##########

def blockize(self, loop: LoopRV) -> BlockRV:
return _ffi_api_schedule.ScheduleBlockize(self, loop) # pylint: disable=no-member

def tensorize(self, loop: LoopRV, intrin: Union[str, TensorIntrin]) -> None:
if isinstance(intrin, str):
intrin = String(intrin)
_ffi_api_schedule.ScheduleTensorize(self, loop, intrin) # pylint: disable=no-member

########## Schedule: Marks and NO-OPs ##########

def mark_loop(
self,
loop: LoopRV,
ann_key: str,
ann_val: str,
) -> None:
"""Mark a range of loops with the specific mark
Parameters
----------
loop: LoopRV
The loops to be marked
ann_key : str
The annotation key
ann_val : str
The annotation value
"""
if isinstance(ann_val, str):
ann_val = String(ann_val)
elif isinstance(ann_val, int):
ann_val = IntImm("int64", ann_val)
_ffi_api_schedule.ScheduleMarkLoop( # pylint: disable=no-member
self, loop, ann_key, ann_val
)

def mark_block(
self,
block: BlockRV,
ann_key: str,
ann_val: ExprRV,
) -> None:
"""Mark a block
Parameters
----------
block : BlockRV
The block to be marked
ann_key : str
The annotation key
ann_val : ExprRV
The annotation value
"""
if isinstance(ann_val, str):
ann_val = String(ann_val)
elif isinstance(ann_val, int):
ann_val = IntImm("int64", ann_val)
_ffi_api_schedule.ScheduleMarkBlock( # pylint: disable=no-member
self, block, ann_key, ann_val
)

########## Schedule: Misc ##########

def inline_argument(self, i: int, func_name: str = "main"):
_ffi_api_schedule.ScheduleInlineArgument(self, i, func_name) # pylint: disable=no-member


@_register_object("tir.ConcreteSchedule")
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
} else {
pass_list.push_back(tir::transform::AllreduceTransform());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
Expand Down
5 changes: 1 addition & 4 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <functional>

Expand Down Expand Up @@ -140,13 +139,11 @@ class PrimFuncSpecializer : public StmtExprMutator {
op = expr.as<BufferLoadNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
if (it == buffer_map_.end() && indices.same_as(op->indices)) {
if (it == buffer_map_.end()) {
return GetRef<BufferLoad>(op);
} else {
auto n = make_object<BufferLoadNode>(*op);
n->buffer = it->second;
n->indices = std::move(indices);
return PrimExpr(n);
}
}
Expand Down
Loading

0 comments on commit 6fa423e

Please sign in to comment.