Skip to content

Commit

Permalink
[Meta Schedule] GPU End-to-end Alignment (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhongyii authored Jun 4, 2021
1 parent 3e77665 commit 49304ac
Show file tree
Hide file tree
Showing 12 changed files with 466 additions and 24 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,12 @@ TVM_DLL Pass AllreduceTransform();
*/
TVM_DLL Pass BufferFlatten();

/*!
* \brief Do part of the lowering process for feature extraction.
* \return The pass.
*/
TVM_DLL Pass PreprocessForFeatureExtraction();

/*!
* \brief Locate the buffer allocation to the exact position (usually is
* the lca of buffer access). This pass will inject opaque block
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/meta_schedule/search_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,15 @@ def cross_thread_reduction() -> SearchRule:
The rule created
"""
return _ffi_api_search_rule.CrossThreadReduction() # pylint: disable=no-member


def special_compute_location_gpu() -> SearchRule:
    """A rule that handles special cases in Winograd transformation for GPU.

    It changes the compute location of the producers of compute ops that
    perform "fake reduction" with const tensors.

    Returns
    -------
    rule: SearchRule
        The search rule created
    """
    return _ffi_api_search_rule.SpecialComputeLocationGPU()  # pylint: disable=no-member
72 changes: 72 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.contrib import nvcc
from tvm.contrib.thrust import can_use_thrust
from tvm.te import SpecializedCondition
Expand Down Expand Up @@ -77,6 +78,11 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="softmax.cuda",
)
strategy.add_tir_implementation(
wrap_compute_softmax(topi.nn.softmax),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_softmax.cuda",
)
if target.kind.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(topi.cuda.softmax_cudnn),
Expand Down Expand Up @@ -131,6 +137,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda",
)
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_conv2d_nchw.cuda",
)
_, _, kh, kw = get_const_tuple(kernel.shape)
if (
(2 < kh < 8 and 2 < kw < 8 and kh == kw)
Expand All @@ -157,6 +168,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
name="conv2d_nhwc.cuda",
)
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_conv2d_nhwc.cuda"
)

N, H, W, _ = get_const_tuple(data.shape)
KH, KW, CI, CO = get_const_tuple(kernel.shape)
Expand Down Expand Up @@ -226,6 +242,14 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
plevel=15,
)

if is_meta_schedule_enabled() and judge_winograd_auto_scheduler:
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_conv2d_nhwc.winograd",
plevel=15,
)

elif layout == "HWNC":
assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"]
_, _, N, in_channels = get_const_tuple(data.shape)
Expand Down Expand Up @@ -282,6 +306,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.cuda",
)
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_depthwise_conv2d_nchw.cuda",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
Expand Down Expand Up @@ -314,6 +343,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda",
)
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_group_conv2d_nchw.cuda",
)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
Expand Down Expand Up @@ -403,6 +437,11 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform),
name="conv2d_nchw_winograd_without_weight_transform.cuda",
)
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_conv2d_nchw_winograd_without_weight_transform.cuda",
)
elif layout == "NHWC":
N, H, W, _ = get_const_tuple(data.shape)
alpha, _, CI, CO = get_const_tuple(kernel.shape)
Expand Down Expand Up @@ -446,6 +485,13 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
),
name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
)
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform),
wrap_topi_schedule(
topi.generic.default_tir_schedule
),
name="tir_conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
)

if is_auto_scheduler_enabled():
strategy.add_implementation(
Expand All @@ -454,6 +500,13 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
name="conv2d_nhwc_winograd_without_weight_transform",
plevel=15,
)
if is_meta_schedule_enabled():
strategy.add_tir_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_conv2d_nhwc_winograd_without_weight_transform",
plevel=15,
)
else:
raise RuntimeError(
"Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
Expand Down Expand Up @@ -671,13 +724,25 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
name="dense_small_batch.cuda",
)

strategy.add_tir_implementation(
wrap_compute_dense(topi.cuda.dense_small_batch),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_dense_small_batch.cuda",
)

with SpecializedCondition(b >= 32):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
plevel=5,
)
strategy.add_tir_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_dense_large_batch.cuda",
plevel=5,
)
if target.kind.name == "cuda":
if nvcc.have_tensorcore(target=target):
if (
Expand Down Expand Up @@ -728,6 +793,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
name="batch_matmul.cuda",
plevel=10,
)

strategy.add_tir_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.generic.default_tir_schedule),
name="tir_batch_matmul.cuda",
plevel=10,
)
if target.kind.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/generic/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def default_schedule(outs, auto_inline):
def default_tir_schedule(outs):
"""Default tir schedule for llvm."""
target = tvm.target.Target.current(allow_none=False)
if target.kind.name not in ("llvm", "c"):
if target.kind.name not in ("llvm", "c", "gpu", "cuda"):
raise RuntimeError("schedule not registered for '%s'" % target)
func = te.create_func(outs)

Expand Down
26 changes: 15 additions & 11 deletions src/meta_schedule/feature/per_block_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/support/parallel_for.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <numeric>
Expand Down Expand Up @@ -294,7 +295,9 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
std::vector<FeatureSet> result;
result.reserve(extractor.ordered_blocks_.size());
for (const tir::BlockRealizeNode* realize : extractor.ordered_blocks_) {
result.push_back(extractor.per_block_feature_.at(realize));
if (!realize->block->name_hint.empty()) {
result.push_back(extractor.per_block_feature_.at(realize));
}
}
return result;
}
Expand Down Expand Up @@ -383,7 +386,7 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
std::vector<int64_t> access_shape;
int64_t num_continuous_bytes = 1;
/*! \brief loop_accessed_numel[i][...] means the number of elements accessed by loops[i] */
std::vector<std::vector<int64_t>> loop_accessed_numel = {};
std::vector<std::unordered_map<const tir::BufferNode*, int64_t>> loop_accessed_numel = {};
// Stride info
int64_t min_stride = 0;
int64_t innermost_stride = 0;
Expand Down Expand Up @@ -427,7 +430,7 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
// Note: `info.access_shape` for `i == n_loops - 1` is the only one preserved,
// while others are discarded
int64_t numel = CalcRegionUnionSize(info.regions, &info.access_shape);
info.loop_accessed_numel[i].push_back(numel);
info.loop_accessed_numel[i][buffer] = numel;
touched_bytes += numel * buffer->dtype.bytes();
buffer_touched_under_loop_[loop][buffer].push_back(numel);
}
Expand Down Expand Up @@ -594,7 +597,7 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
feature.lines = 1;
feature.unique_lines = 1;
} else {
feature.unique_bytes = info.loop_accessed_numel.back().front() * dtype_bytes;
feature.unique_bytes = info.loop_accessed_numel.back().at(iter.first) * dtype_bytes;
double m = static_cast<double>(info.min_stride) * dtype_bytes / kCacheLineBytes;
feature.lines = outer_loop_prod_ / info.prod_non_strided_loop_extent * std::min(1.0, m);
feature.lines = std::max(1.0, feature.lines);
Expand Down Expand Up @@ -855,10 +858,6 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
private:
/******** Visitors ********/
void VisitStmt_(const tir::BlockRealizeNode* realize) override {
// TODO(@jinhongyii): think of better ways of judging init block in the future
if (std::string(realize->block->name_hint).find("_init") != std::string::npos) {
return;
}
if (!scopes_.empty()) {
ordered_blocks_.push_back(realize);
}
Expand All @@ -876,8 +875,6 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
const tir::StmtNode* stmt = *iter;
if (stmt->IsInstance<tir::ForNode>()) {
loops.push_back(static_cast<const tir::ForNode*>(stmt));
} else {
break;
}
}
FeatureSet& feature = per_block_feature_[realize];
Expand Down Expand Up @@ -1090,7 +1087,14 @@ runtime::NDArray PerBlockFeature(const Schedule& sch, int max_num_buffer_access_
size_t kNumFeature = kNumFeatureGroup1 +
kNumFeatureGroup2Subgroup * max_num_buffer_access_features +
kNumFeatureGroup3 + kNumFeatureGroup5;
tir::PrimFunc func = GetOnlyFunc(sch->mod());

IRModule mod = sch->mod();
auto pass_list = Array<tvm::transform::Pass>();
pass_list.push_back(tir::transform::PreprocessForFeatureExtraction());
pass_list.push_back(tir::transform::Simplify());
const auto& optimize = tir::transform::Sequential(pass_list);
mod = optimize(std::move(mod));
tir::PrimFunc func = GetOnlyFunc(mod);
std::vector<FeatureSet> feature_map = PerBlockFeatureExtractor::Extract(func);

DoubleNDArrayPusher ret(
Expand Down
9 changes: 7 additions & 2 deletions src/meta_schedule/space/postproc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,10 @@ class PostprocRewriteUnboundBlocks {
(n_spatial_loops > 0 && loop_sref->seq_index != -1)) {
break;
}
if (!HasSingleChild(loop_sref)) {
n_spatial_loops++;
break;
}
++n_spatial_loops;
}
CHECK_GT(n_spatial_loops, 0) << "ValueError: not supported when spatial loop doesn't exist";
Expand Down Expand Up @@ -748,7 +752,8 @@ class PostprocRewriteReductionBlock {
for (int i = 0; i < n_loops; ++i) {
const LoopRV& loop_rv = loop_rvs[i];
tir::StmtSRef loop_sref = sch->GetSRef(loop_rv);
if (GetLoopIterType(sch->state(), loop_sref) != tir::kDataPar) {
tir::IterVarType type = GetLoopIterType(sch->state(), loop_sref);
if (type == tir::kCommReduce || type == tir::kOpaque) {
// Insert the initializing block above the first reduction or opaque loop.
sch->DecomposeReduction(block_rv, loop_rvs[i]);
break;
Expand Down Expand Up @@ -828,7 +833,7 @@ class PostprocVerifyGPUCode {
{"max_local_memory_per_block", Extract(target, "registers_per_block")},
{"max_threads_per_block", Extract(target, "max_threads_per_block")},
{"max_vthread", Integer(8)},
};
{"max_vector_bytes", Integer(16)}};
return tir::VerifyGPUCode(func, constraints);
}

Expand Down
46 changes: 43 additions & 3 deletions src/meta_schedule/space/search_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ class RuleMultiLevelTiling {
} else {
state.write_cache = write_cache;
}

state.write_cache_is_added = true;
result.push_back(std::move(state));
return result;
Expand Down Expand Up @@ -505,9 +506,6 @@ class RuleMultiLevelTiling {
Array<Schedule> Apply(const SearchTask& task, const Schedule& sch,
const BlockRV& block_rv) const {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
if (HasAnyAnn(block_sref)) {
return {sch};
}
if (!NeedsMultiLevelTiling(sch->state(), block_sref)) {
return {sch};
}
Expand Down Expand Up @@ -1174,6 +1172,46 @@ SearchRule CrossThreadReduction() {
return SearchRule("cross_thread_reduction", f_apply);
};

/********** SpecialComputeLocationGPU **********/
/*!
 * \brief Rule that handles special cases in Winograd transformation for GPU: it changes the
 * compute location of a producer whose only consumer performs "fake reduction" with const tensors.
 */
class RuleSpecialComputeLocationGPU {
 public:
  /*!
   * \brief Apply the rule to a block
   * \param task The search task (not used by this rule)
   * \param sch The schedule the rule is applied on; it is modified in place
   * \param block_rv The block to be considered for relocation
   * \return An array containing only `sch`, whether or not it was changed
   */
  Array<Schedule> Apply(const SearchTask& task, const Schedule& sch,
                        const BlockRV& block_rv) const {
    // Only blocks that themselves have producers are candidates.
    if (sch->GetProducers(block_rv).empty()) {
      return {sch};
    }
    tir::StmtSRef block_sref = sch->GetSRef(block_rv);
    if (!RuleInlinePureSpatial::NeedsInline(sch, block_sref, false)) {
      return {sch};
    }
    // The block must feed exactly one consumer, and that consumer must carry the
    // `simplify_const_tensor_indices` annotation.
    Array<BlockRV> consumers = sch->GetConsumers(block_rv);
    if (consumers.size() != 1 ||
        !sch->Get(consumers[0])
             ->annotations.count(
                 tvm::auto_scheduler::SearchPolicyKey::simplify_const_tensor_indices)) {
      return {sch};
    }
    // Find the first unrolled loop of the consumer and compute the block at the loop
    // right outside of it.
    Array<tir::LoopRV> consumer_loops = sch->GetAxes(consumers[0]);
    for (size_t i = 0; i < consumer_loops.size(); i++) {
      tir::StmtSRef loop_sref = sch->GetSRef(consumer_loops[i]);
      if (tir::GetLoopIterType(sch->state(), loop_sref) != tir::kUnrolled) {
        continue;
      }
      // When the outermost loop is already the unrolled one there is no enclosing loop to
      // compute at; `i - 1` would wrap around for the unsigned index, so leave the schedule as is.
      if (i > 0) {
        sch->ComputeAt(block_rv, consumer_loops[i - 1], true);
        sch->SetScope(block_rv, 0, "local");
      }
      break;
    }
    return {sch};
  }
};

/*!
 * \brief Create the search rule named "special_compute_location_gpu", which forwards to
 * RuleSpecialComputeLocationGPU::Apply.
 * \return The rule created
 */
SearchRule SpecialComputeLocationGPU() {
  auto apply_rule = [](SearchTask search_task, Schedule schedule,
                       BlockRV block_rv) -> Array<Schedule> {
    RuleSpecialComputeLocationGPU rule;
    return rule.Apply(search_task, schedule, block_rv);
  };
  return SearchRule("special_compute_location_gpu", apply_rule);
}

/********** FFI **********/

struct Internal {
Expand Down Expand Up @@ -1227,6 +1265,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.search_rule.SimplifyComputeWithConstTensor")
TVM_REGISTER_GLOBAL("meta_schedule.search_rule.AddRFactor").set_body_typed(AddRFactor);
TVM_REGISTER_GLOBAL("meta_schedule.search_rule.CrossThreadReduction")
.set_body_typed(CrossThreadReduction);
TVM_REGISTER_GLOBAL("meta_schedule.search_rule.SpecialComputeLocationGPU")
.set_body_typed(SpecialComputeLocationGPU);

} // namespace meta_schedule
} // namespace tvm
6 changes: 6 additions & 0 deletions src/meta_schedule/space/search_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ TVM_DLL SearchRule MarkTensorize(Array<tir::TensorIntrin> tensor_intrins);
* \return The rule created
*/
TVM_DLL SearchRule AddRFactor(int max_jobs_per_core, int max_innermost_factor);
/*!
 * \brief Create a rule that handles special cases in Winograd transformation for GPU. The rule
 * changes the compute location of the producers of compute ops that perform "fake reduction"
 * with const tensors.
 * \return The rule created
 */
TVM_DLL SearchRule SpecialComputeLocationGPU();

} // namespace meta_schedule
} // namespace tvm
Expand Down
Loading

0 comments on commit 49304ac

Please sign in to comment.