Skip to content

Commit

Permalink
Explicitly set types for TVM_SREF_TO_ERROR/FOR/BLOCK (#406)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored Jun 30, 2021
1 parent e7421a1 commit 763b3c1
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 39 deletions.
36 changes: 20 additions & 16 deletions src/meta_schedule/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@
namespace tvm {
namespace meta_schedule {

/**************** TIR Nodes ****************/
using tir::BlockNode;
using tir::ForNode;

bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
tir::BlockRealize realize = tir::GetBlockRealize(block_sref);
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
const Array<PrimExpr>& bindings = realize->iter_values;
Expand All @@ -37,7 +41,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
int n = loops.size();
for (int i = 0; i < n; ++i) {
const PrimExpr& bind = bindings[i];
const auto* loop = TVM_SREF_TO_FOR(loop, loops[i]);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
if (bind.as<tir::VarNode>() != loop->loop_var.get()) {
return false;
}
Expand All @@ -51,7 +55,7 @@ bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_s
}

bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
bool no_child = true;
tir::PreOrderVisit(block->body, [&no_child](const ObjectRef& obj) -> bool {
if (!no_child) {
Expand All @@ -67,7 +71,7 @@ bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref
}

Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Array<Integer> result;
for (const tir::IterVar& iter_var : block->iter_vars) {
int iter_type = iter_var->iter_type;
Expand All @@ -77,7 +81,7 @@ Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtS
}

bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
for (const tir::IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != tir::IterVarType::kDataPar) {
return false;
Expand All @@ -88,8 +92,8 @@ bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref)

bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
tir::StmtSRef parent_sref = tir::GetScopeRoot(block_sref).value();
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const auto* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
if (parent_sref->parent == nullptr) {
const tir::PrimFuncNode* func = tir::GetRootPrimFunc(self, parent_sref);
for (const tir::BufferRegion& write : block->writes) {
Expand All @@ -112,7 +116,7 @@ bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sr
}

int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const Op& op) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
int count = 0;
tir::PostOrderVisit(block->body, [&count, &op](const ObjectRef& obj) {
if (const auto* call = obj.as<tir::CallNode>()) {
Expand All @@ -125,7 +129,7 @@ int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, con
}

bool HasBranch(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
bool has_branch = false;
arith::Analyzer analyzer;
auto f_visit = [&has_branch, &analyzer](const ObjectRef& obj) -> bool {
Expand Down Expand Up @@ -214,8 +218,8 @@ bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& pro
const tir::StmtSRef& consumer_sref) {
// Assume consumer is the only consumer of the producer
tir::StmtSRef parent_sref = tir::GetScopeRoot(producer_sref).value();
const auto* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
const auto* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
const BlockNode* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
const BlockNode* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
if (producer->writes.empty()) {
return false;
}
Expand Down Expand Up @@ -285,7 +289,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
if (!IsTrivialBinding(self, block_sref)) {
return false;
}
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Assume complete/reduction block
if (block->writes.size() != 1) {
return false;
Expand Down Expand Up @@ -333,7 +337,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&

bool IsStrictlyInlineable(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
static const Op& op_tir_exp = Op::Get("tir.exp");
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Const tensors are strictly inlineable
if (block->reads.empty()) {
return true;
Expand Down Expand Up @@ -764,7 +768,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
int64_t max_parallel_extent,
int64_t basic_parallel_extent) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);

// Cond 1. The block is a reduction block and has trivial binding.
Expand All @@ -791,9 +795,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
}

// Cond 4.
const auto* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
if (i < static_cast<int>(loops.size()) - 1) {
const auto* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
if (loop_i->body.get() != loop_i1) {
return false;
}
Expand Down
8 changes: 6 additions & 2 deletions src/meta_schedule/space/search_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
namespace tvm {
namespace meta_schedule {

/**************** TIR Nodes ****************/
using tir::ForNode;
using tir::BlockNode;

/********** Constructors **********/

SearchRule::SearchRule(String name, SearchRuleNode::FApply apply) {
Expand Down Expand Up @@ -962,7 +966,7 @@ class RuleAddRFactor {
const BlockRV& block_rv) const {
// Check the conditions of the rule.
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
if (HasAnyAnn(block_sref)) {
return {sch};
}
Expand Down Expand Up @@ -1041,7 +1045,7 @@ class RuleCrossThreadReduction {

// Check the conditions of the rule.
const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
if (HasAnyAnn(block_sref)) {
return {sch};
}
Expand Down
7 changes: 5 additions & 2 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ using tir::AsArray;
using tir::AsOptArray;
using tir::AsVector;

/**************** TIR Nodes ****************/
using tir::ForNode;

/*!
* \brief Compute mean of a FloatImm array.
* Taken from Ansor
Expand Down Expand Up @@ -179,7 +182,7 @@ inline Optional<tir::StmtSRef> FindBlockSRef(const tir::ScheduleState& sch, FPre
/**************** TIR Annotation ****************/

inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag) {
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
if (!loop->thread_binding.defined()) {
return false;
}
Expand All @@ -191,7 +194,7 @@ inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag)
}

inline Optional<String> GetBinding(const tir::StmtSRef& loop_sref) {
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
if (!loop->thread_binding.defined()) {
return NullOpt;
}
Expand Down
18 changes: 9 additions & 9 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root) {
BlockScope scope = self->GetBlockScope(scope_root);
// Cond 1. All block vars are data parallel
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
for (const IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != kDataPar) {
return false;
Expand Down Expand Up @@ -307,7 +307,7 @@ bool RegionCoveredConsumer(const ScheduleState& self, const StmtSRef& consumer_b
if (consumer_block_sref->parent == nullptr) {
return true;
}
const auto* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
BlockScope scope = self->GetBlockScope(scope_root);
// Step 1. Gather all the producers
struct Producer {
Expand All @@ -326,7 +326,7 @@ bool RegionCoveredConsumer(const ScheduleState& self, const StmtSRef& consumer_b
// i.e. the RAW predecessor is producer
if (edge->kind == DepKind::kRAW) {
const StmtSRef& producer_block_sref = edge->src;
const auto* producer_block = TVM_SREF_TO_BLOCK(producer_block, producer_block_sref);
const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_block, producer_block_sref);
for (const BufferRegion& output_region : producer_block->writes) {
const VarNode* buffer_var = output_region->buffer->data.get();
buffer_producers[buffer_var].emplace_back(producer_block_sref, output_region);
Expand Down Expand Up @@ -600,7 +600,7 @@ bool CompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root) {
BlockScope scope = self->GetBlockScope(scope_root);
// Cond 2. Check if all the block vars are data parallel
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
for (const IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != kDataPar) {
return false;
Expand Down Expand Up @@ -631,7 +631,7 @@ bool ReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
// return false;
// }
// Cond 4. Check whether the block body has the init statement.
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
if (!block->init.defined()) {
return false;
}
Expand Down Expand Up @@ -677,8 +677,8 @@ bool ReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
bool CanMergeReduction(const ScheduleState& self, const StmtSRef& init_block_sref,
const StmtSRef& update_block_sref, const StmtSRef& scope_root) {
BlockScope scope = self->GetBlockScope(scope_root);
const auto* init = TVM_SREF_TO_BLOCK(init, init_block_sref);
const auto* update = TVM_SREF_TO_BLOCK(update, update_block_sref);
const BlockNode* init = TVM_SREF_TO_BLOCK(init, init_block_sref);
const BlockNode* update = TVM_SREF_TO_BLOCK(update, update_block_sref);
// Cond 1. Check the binding of update block is valid
if (!self->IsAffineBlockBinding(update_block_sref)) {
return false;
Expand Down Expand Up @@ -731,7 +731,7 @@ IterVarType GetLoopIterType(const ScheduleState& self, const StmtSRef& loop_sref
int n_spatial = 0;
int n_reduce = 0;
int n_other = 0;
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const Var& loop_var = loop->loop_var;
auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool {
if (const auto* realize = obj.as<BlockRealizeNode>()) {
Expand Down Expand Up @@ -785,7 +785,7 @@ Array<StmtSRef> CollectComputeLocation(const ScheduleState& self, const StmtSRef
result.reserve(loop_srefs.size());
bool visited_reduce = false;
for (const StmtSRef& loop_sref : loop_srefs) {
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
IterVarType iter_type = GetLoopIterType(self, loop_sref);
if (iter_type == IterVarType::kDataPar) {
if (visited_reduce) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
TVM_TIR_SCHEDULE_BEGIN();
// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
PrimExpr len = loop->extent;
// Find out the None
int n = factor_rvs.size();
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,13 @@ class ConcreteScheduleNode : public ScheduleNode {

inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const {
StmtSRef sref = this->GetSRef(block_rv);
const auto* block = TVM_SREF_TO_BLOCK(block, sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref);
return GetRef<Block>(block);
}

inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
StmtSRef sref = this->GetSRef(loop_rv);
const auto* loop = TVM_SREF_TO_FOR(loop, sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
return GetRef<For>(loop);
}

Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
std::unordered_map<const VarNode*, For> loop_vars;
Array<StmtSRef> loops = GetLoops(block_sref);
for (const StmtSRef& l_sref : loops) {
const auto* l = TVM_SREF_TO_FOR(l, l_sref);
const ForNode* l = TVM_SREF_TO_FOR(l, l_sref);
if (l == loop) {
CHECK(!data_par_iters.count(l->loop_var.get()))
<< "ValueError: The rfactor loop cannot be touched by data parallel block vars";
Expand Down Expand Up @@ -598,7 +598,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
// IR replacement later.
Optional<StmtSRef> replace_top = NullOpt;
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
const auto* l = TVM_SREF_TO_FOR(l, loops[i]);
const ForNode* l = TVM_SREF_TO_FOR(l, loops[i]);
if (l->body->IsInstance<SeqStmtNode>()) {
ICHECK_NE(i, static_cast<int>(loops.size()) - 1)
<< "ValueError: The body of the innermost loop must not be a SeqStmt";
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, Sampler* sampler
const tir::StmtSRef& loop_sref, int n,
int max_innermost_factor,
Optional<Array<Integer>>* decision) {
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
int64_t extent = GetLoopIntExtent(loop);
std::vector<int64_t> result;
if (extent == -1) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ void ScheduleStateNode::DebugVerify() const {
/**************** BlockInfo-related ****************/

BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
auto it = this->block_info.find(block_sref);
CHECK(it != this->block_info.end())
<< "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ BufferRegion SubstituteBufferRegion(const BufferRegion& buffer_region,
}

BlockRealize GetBlockRealize(const StmtSRef& block_sref) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// We cannot support getting the BlockRealize of the root block, since the parent sref of the root
// block sref is `nullptr`.
CHECK(block_sref->parent != nullptr)
Expand Down Expand Up @@ -443,7 +443,7 @@ void UpdateAffineFlag(ScheduleState self, const StmtSRef& block_sref) {
return;
}
BlockRealize realize = GetBlockRealize(block_sref);
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Map<Var, Range> loop_var_ranges;
for (StmtSRefNode* loop_sref = block_sref->parent; loop_sref != nullptr;
loop_sref = loop_sref->parent) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ inline int64_t GetLoopIntExtent(const ForNode* loop) {
}

inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) {
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
return GetLoopIntExtent(loop);
}

Expand Down

0 comments on commit 763b3c1

Please sign in to comment.