diff --git a/src/meta_schedule/analysis.cc b/src/meta_schedule/analysis.cc
index afb3b75213..31b5c6d69c 100644
--- a/src/meta_schedule/analysis.cc
+++ b/src/meta_schedule/analysis.cc
@@ -26,8 +26,12 @@
 namespace tvm {
 namespace meta_schedule {
 
+/**************** TIR Nodes ****************/
+using tir::BlockNode;
+using tir::ForNode;
+
 bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   tir::BlockRealize realize = tir::GetBlockRealize(block_sref);
   Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
   const Array<PrimExpr>& bindings = realize->iter_values;
@@ -37,7 +41,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
   int n = loops.size();
   for (int i = 0; i < n; ++i) {
     const PrimExpr& bind = bindings[i];
-    const auto* loop = TVM_SREF_TO_FOR(loop, loops[i]);
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
     if (bind.as<tir::VarNode>() != loop->loop_var.get()) {
       return false;
     }
@@ -51,7 +55,7 @@ bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_s
 }
 
 bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   bool no_child = true;
   tir::PreOrderVisit(block->body, [&no_child](const ObjectRef& obj) -> bool {
     if (!no_child) {
@@ -67,7 +71,7 @@ bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref
 }
 
 Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Array<Integer> result;
   for (const tir::IterVar& iter_var : block->iter_vars) {
     int iter_type = iter_var->iter_type;
@@ -77,7 +81,7 @@ Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtS
 }
 
 bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   for (const tir::IterVar& iter_var : block->iter_vars) {
     if (iter_var->iter_type != tir::IterVarType::kDataPar) {
       return false;
@@ -88,8 +92,8 @@
 
 bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
   tir::StmtSRef parent_sref = tir::GetScopeRoot(block_sref).value();
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  const auto* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
   if (parent_sref->parent == nullptr) {
     const tir::PrimFuncNode* func = tir::GetRootPrimFunc(self, parent_sref);
     for (const tir::BufferRegion& write : block->writes) {
@@ -112,7 +116,7 @@ bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sr
 }
 
 int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const Op& op) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   int count = 0;
   tir::PostOrderVisit(block->body, [&count, &op](const ObjectRef& obj) {
     if (const auto* call = obj.as<tir::CallNode>()) {
@@ -125,7 +129,7 @@ int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, con
 }
 
 bool HasBranch(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   bool has_branch = false;
   arith::Analyzer analyzer;
   auto f_visit = [&has_branch, &analyzer](const ObjectRef& obj) -> bool {
@@ -214,8 +218,8 @@ bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& pro
                         const tir::StmtSRef& consumer_sref) {
   // Assume consumer is the only consumer of the producer
   tir::StmtSRef parent_sref = tir::GetScopeRoot(producer_sref).value();
-  const auto* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
-  const auto* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
+  const BlockNode* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
+  const BlockNode* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
   if (producer->writes.empty()) {
     return false;
   }
@@ -285,7 +289,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
   if (!IsTrivialBinding(self, block_sref)) {
     return false;
   }
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   // Assume complete/reduction block
   if (block->writes.size() != 1) {
     return false;
   }
@@ -333,7 +337,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
 
 bool IsStrictlyInlineable(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
   static const Op& op_tir_exp = Op::Get("tir.exp");
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   // Const tensors are strictly inlineable
   if (block->reads.empty()) {
     return true;
   }
@@ -764,7 +768,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
                                         const tir::StmtSRef& block_sref,
                                         int64_t max_parallel_extent,
                                         int64_t basic_parallel_extent) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
 
   // Cond 1. The block is a reduction block and has trivial binding.
@@ -791,9 +795,9 @@
     }
     // Cond 4.
-    const auto* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
+    const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
     if (i < static_cast<int>(loops.size()) - 1) {
-      const auto* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
+      const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
       if (loop_i->body.get() != loop_i1) {
         return false;
       }
diff --git a/src/meta_schedule/space/search_rule.cc b/src/meta_schedule/space/search_rule.cc
index 7956f7f002..406740cbb3 100644
--- a/src/meta_schedule/space/search_rule.cc
+++ b/src/meta_schedule/space/search_rule.cc
@@ -26,6 +26,10 @@
 namespace tvm {
 namespace meta_schedule {
 
+/**************** TIR Nodes ****************/
+using tir::ForNode;
+using tir::BlockNode;
+
 /********** Constructors **********/
 
 SearchRule::SearchRule(String name, SearchRuleNode::FApply apply) {
@@ -962,7 +966,7 @@ class RuleAddRFactor {
                            const BlockRV& block_rv) const {
     // Check the conditions of the rule.
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-    const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
     if (HasAnyAnn(block_sref)) {
       return {sch};
     }
@@ -1041,7 +1045,7 @@ class RuleCrossThreadReduction {
     // Check the conditions of the rule.
    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
-    const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
     if (HasAnyAnn(block_sref)) {
       return {sch};
     }
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 03b5490d0b..8a5a2420fe 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -42,6 +42,9 @@ using tir::AsArray;
 using tir::AsOptArray;
 using tir::AsVector;
 
+/**************** TIR Nodes ****************/
+using tir::ForNode;
+
 /*!
  * \brief Compute mean of a FloatImm array.
  * Taken from Ansor
@@ -179,7 +182,7 @@ inline Optional<tir::StmtSRef> FindBlockSRef(const tir::ScheduleState& sch, FPre
 /**************** TIR Annotation ****************/
 
 inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag) {
-  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   if (!loop->thread_binding.defined()) {
     return false;
   }
@@ -191,7 +194,7 @@ inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag)
 }
 
 inline Optional<String> GetBinding(const tir::StmtSRef& loop_sref) {
-  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   if (!loop->thread_binding.defined()) {
     return NullOpt;
   }
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index a70ee0f7e3..4409a4d05c 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -111,7 +111,7 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
                      const StmtSRef& scope_root) {
   BlockScope scope = self->GetBlockScope(scope_root);
   // Cond 1. All block vars are data parallel
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   for (const IterVar& iter_var : block->iter_vars) {
     if (iter_var->iter_type != kDataPar) {
       return false;
@@ -307,7 +307,7 @@ bool RegionCoveredConsumer(const ScheduleState& self, const StmtSRef& consumer_b
   if (consumer_block_sref->parent == nullptr) {
     return true;
   }
-  const auto* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
+  const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
   BlockScope scope = self->GetBlockScope(scope_root);
   // Step 1. Gather all the producers
   struct Producer {
@@ -326,7 +330,7 @@ bool RegionCoveredConsumer(const ScheduleState& self, const StmtSRef& consumer_b
       // i.e. the RAW predecessor is producer
       if (edge->kind == DepKind::kRAW) {
         const StmtSRef& producer_block_sref = edge->src;
-        const auto* producer_block = TVM_SREF_TO_BLOCK(producer_block, producer_block_sref);
+        const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_block, producer_block_sref);
         for (const BufferRegion& output_region : producer_block->writes) {
           const VarNode* buffer_var = output_region->buffer->data.get();
           buffer_producers[buffer_var].emplace_back(producer_block_sref, output_region);
@@ -600,7 +600,7 @@ bool CompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
                    const StmtSRef& scope_root) {
   BlockScope scope = self->GetBlockScope(scope_root);
   // Cond 2. Check if all the block vars are data parallel
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   for (const IterVar& iter_var : block->iter_vars) {
     if (iter_var->iter_type != kDataPar) {
       return false;
@@ -631,7 +631,7 @@ bool ReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
   //   return false;
   // }
   // Cond 4. Check whether the block body has the init statement.
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   if (!block->init.defined()) {
     return false;
   }
@@ -677,8 +677,8 @@ bool ReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
 bool CanMergeReduction(const ScheduleState& self, const StmtSRef& init_block_sref,
                        const StmtSRef& update_block_sref, const StmtSRef& scope_root) {
   BlockScope scope = self->GetBlockScope(scope_root);
-  const auto* init = TVM_SREF_TO_BLOCK(init, init_block_sref);
-  const auto* update = TVM_SREF_TO_BLOCK(update, update_block_sref);
+  const BlockNode* init = TVM_SREF_TO_BLOCK(init, init_block_sref);
+  const BlockNode* update = TVM_SREF_TO_BLOCK(update, update_block_sref);
   // Cond 1. Check the binding of update block is valid
   if (!self->IsAffineBlockBinding(update_block_sref)) {
     return false;
   }
@@ -731,7 +731,7 @@ IterVarType GetLoopIterType(const ScheduleState& self, const StmtSRef& loop_sref
   int n_spatial = 0;
   int n_reduce = 0;
   int n_other = 0;
-  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   const Var& loop_var = loop->loop_var;
   auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool {
     if (const auto* realize = obj.as<BlockRealizeNode>()) {
@@ -785,7 +785,7 @@ Array<StmtSRef> CollectComputeLocation(const ScheduleState& self, const StmtSRef
   result.reserve(loop_srefs.size());
   bool visited_reduce = false;
   for (const StmtSRef& loop_sref : loop_srefs) {
-    const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
     IterVarType iter_type = GetLoopIterType(self, loop_sref);
     if (iter_type == IterVarType::kDataPar) {
       if (visited_reduce) {
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 59067986ed..2b536832dc 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -339,7 +339,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   TVM_TIR_SCHEDULE_BEGIN();
   // Prepare for the splitting
   StmtSRef loop_sref = this->GetSRef(loop_rv);
-  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   PrimExpr len = loop->extent;
   // Find out the None
   int n = factor_rvs.size();
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 2db0273dd4..3384d6394e 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -214,13 +214,13 @@ class ConcreteScheduleNode : public ScheduleNode {
 
 inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const {
   StmtSRef sref = this->GetSRef(block_rv);
-  const auto* block = TVM_SREF_TO_BLOCK(block, sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref);
   return GetRef<Block>(block);
 }
 
 inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
   StmtSRef sref = this->GetSRef(loop_rv);
-  const auto* loop = TVM_SREF_TO_FOR(loop, sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
   return GetRef<For>(loop);
 }
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 369e723963..f02e2876f7 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -390,7 +390,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
   std::unordered_map loop_vars;
   Array<StmtSRef> loops = GetLoops(block_sref);
   for (const StmtSRef& l_sref : loops) {
-    const auto* l = TVM_SREF_TO_FOR(l, l_sref);
+    const ForNode* l = TVM_SREF_TO_FOR(l, l_sref);
     if (l == loop) {
       CHECK(!data_par_iters.count(l->loop_var.get()))
           << "ValueError: The rfactor loop cannot be touched by data parallel block vars";
@@ -598,7 +598,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
   // IR replacement later.
   Optional replace_top = NullOpt;
   for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
-    const auto* l = TVM_SREF_TO_FOR(l, loops[i]);
+    const ForNode* l = TVM_SREF_TO_FOR(l, loops[i]);
     if (l->body->IsInstance<SeqStmtNode>()) {
       ICHECK_NE(i, static_cast<int>(loops.size()) - 1)
           << "ValueError: The body of the innermost loop must not be a SeqStmt";
diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc
index c024c6626c..abebc69469 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -26,7 +26,7 @@ std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, Sampler* sampler
                                        const tir::StmtSRef& loop_sref, int n,
                                        int max_innermost_factor,
                                        Optional<Array<Integer>>* decision) {
-  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   int64_t extent = GetLoopIntExtent(loop);
   std::vector<int64_t> result;
   if (extent == -1) {
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index fa2fe4404e..6116eb2294 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -1046,7 +1046,7 @@ void ScheduleStateNode::DebugVerify() const {
 /**************** BlockInfo-related ****************/
 
 BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   auto it = this->block_info.find(block_sref);
   CHECK(it != this->block_info.end())
       << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
diff --git a/src/tir/schedule/utils.cc b/src/tir/schedule/utils.cc
index 726ac65d34..357a363bdf 100644
--- a/src/tir/schedule/utils.cc
+++ b/src/tir/schedule/utils.cc
@@ -173,7 +173,7 @@ BufferRegion SubstituteBufferRegion(const BufferRegion& buffer_region,
 }
 
 BlockRealize GetBlockRealize(const StmtSRef& block_sref) {
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   // We cannot support getting the BlockRealize of the root block, since the parent sref of the root
   // block sref is `nullptr`.
   CHECK(block_sref->parent != nullptr)
@@ -443,7 +443,7 @@ void UpdateAffineFlag(ScheduleState self, const StmtSRef& block_sref) {
     return;
   }
   BlockRealize realize = GetBlockRealize(block_sref);
-  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Map<Var, Range> loop_var_ranges;
   for (StmtSRefNode* loop_sref = block_sref->parent; loop_sref != nullptr;
        loop_sref = loop_sref->parent) {
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 4eeb19ca04..c6e0806a94 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -610,7 +610,7 @@ inline int64_t GetLoopIntExtent(const ForNode* loop) {
 }
 
 inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) {
-  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   return GetLoopIntExtent(loop);
 }
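
A note on the pattern this patch touches everywhere: TVM_SREF_TO_BLOCK(Result, SRef) and TVM_SREF_TO_FOR(Result, SRef) take the name of the variable being declared as their first argument, because the macro expansion first yields the initializer and then a second statement that null-checks the freshly declared variable. That is why `const auto*` compiled before, and why spelling the type out as `const BlockNode*` / `const ForNode*` changes nothing at runtime; it only makes the node type visible at the declaration site. The sketch below is a minimal stand-in under assumed simplifications: SREF_TO_NODE, StmtSRef, and the node structs here are illustrative only, not TVM's real definitions (the real macros live in src/tir/schedule/utils.h and report the offending type key via ICHECK).

// Minimal stand-in for the declare-and-check sref-cast pattern.
// Everything in this sketch is hypothetical, not TVM's actual API.
#include <cassert>
#include <iostream>

struct StmtNode {
  virtual ~StmtNode() = default;  // polymorphic base so dynamic_cast works
};
struct BlockNode : StmtNode {};
struct ForNode : StmtNode {};

// An "sref" points back into the AST at a statement of some concrete type.
struct StmtSRef {
  const StmtNode* stmt;
};

// The first argument is the *name of the variable being declared*: the
// expansion produces the initializer, then a statement checking that variable.
#define SREF_TO_NODE(Result, SRef, Type)   \
  dynamic_cast<const Type*>((SRef).stmt);  \
  assert((Result) != nullptr && "TypeError: sref does not point to " #Type)

int main() {
  BlockNode block_stmt;
  StmtSRef block_sref{&block_stmt};

  // Before this patch: the pointer type is deduced from the macro expansion.
  const auto* block_a = SREF_TO_NODE(block_a, block_sref, BlockNode);

  // After this patch: the node type is spelled out at the use site, so a
  // reader no longer has to know what the macro expands to.
  const BlockNode* block_b = SREF_TO_NODE(block_b, block_sref, BlockNode);

  std::cout << (block_a == block_b) << '\n';  // prints 1
}

Both spellings produce identical code; the explicit form (plus the `using tir::BlockNode` / `using tir::ForNode` declarations added at the top of each meta_schedule file) simply makes the deduced type obvious at every declaration, which is the point of the patch.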