diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 26af341bfb..60b0825290 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -354,6 +354,12 @@ TVM_DLL Pass AllreduceTransform();
  */
 TVM_DLL Pass BufferFlatten();
 
+/*!
+ * \brief Do part of the lowering process for feature extraction.
+ * \return The pass.
+ */
+TVM_DLL Pass PreprocessForFeatureExtraction();
+
 /*!
  * \brief Locate the buffer allocation to the exact position (usually is
  * the lca of buffer access). This pass will inject opaque block
diff --git a/python/tvm/meta_schedule/search_rule.py b/python/tvm/meta_schedule/search_rule.py
index 2e3c5a8e3e..d27146aefc 100644
--- a/python/tvm/meta_schedule/search_rule.py
+++ b/python/tvm/meta_schedule/search_rule.py
@@ -317,3 +317,15 @@ def cross_thread_reduction() -> SearchRule:
         The rule created
     """
     return _ffi_api_search_rule.CrossThreadReduction()  # pylint: disable=no-member
+
+
+def special_compute_location_gpu() -> SearchRule:
+    """A rule that handles special cases in Winograd transformation for GPU. It changes the
+    compute location of the producers of compute ops that perform "fake reduction" with const tensors.
+
+    Returns
+    -------
+    rule: SearchRule
+        The search rule created
+    """
+    return _ffi_api_search_rule.SpecialComputeLocationGPU()  # pylint: disable=no-member
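A short sketch (editorial, not part of the patch) of how the new rule slots into a search space; it mirrors the `SPACE` definition in the test added at the end of this patch, so all `ms.rule.*` constructors are the ones this branch exposes:

```python
from tvm import meta_schedule as ms

# The Winograd rule must run before pure-spatial inlining: it anchors the
# "fake reduction" const-tensor producers under the unrolled loop of their
# consumer, which inlining would otherwise fold away.
space = ms.space.PostOrderApply(
    stages=[
        ms.rule.simplify_compute_with_const_tensor(),
        ms.rule.special_compute_location_gpu(),
        ms.rule.inline_pure_spatial(strict_mode=False),
    ],
    postprocs=[],
)
```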
elif layout == "HWNC": assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"] _, _, N, in_channels = get_const_tuple(data.shape) @@ -282,6 +306,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.cuda", ) + strategy.add_tir_implementation( + wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.generic.default_tir_schedule), + name="tir_depthwise_conv2d_nchw.cuda", + ) elif layout == "NHWC": assert kernel_layout == "HWOI" strategy.add_implementation( @@ -314,6 +343,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), name="group_conv2d_nchw.cuda", ) + strategy.add_tir_implementation( + wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), + wrap_topi_schedule(topi.generic.default_tir_schedule), + name="tir_group_conv2d_nchw.cuda", + ) elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: assert kernel_layout == "OIHW4o4i" strategy.add_implementation( @@ -403,6 +437,11 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform), name="conv2d_nchw_winograd_without_weight_transform.cuda", ) + strategy.add_tir_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform), + wrap_topi_schedule(topi.generic.default_tir_schedule), + name="tir_conv2d_nchw_winograd_without_weight_transform.cuda", + ) elif layout == "NHWC": N, H, W, _ = get_const_tuple(data.shape) alpha, _, CI, CO = get_const_tuple(kernel.shape) @@ -446,6 +485,13 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty ), name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda", ) + strategy.add_tir_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform), + wrap_topi_schedule( + topi.generic.default_tir_schedule + ), + name="tir_conv2d_nhwc_winograd_direct_without_weight_transform.cuda", + ) if is_auto_scheduler_enabled(): strategy.add_implementation( @@ -454,6 +500,13 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty name="conv2d_nhwc_winograd_without_weight_transform", plevel=15, ) + if is_meta_schedule_enabled(): + strategy.add_tir_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform), + wrap_topi_schedule(topi.generic.default_tir_schedule), + name="tir_conv2d_nhwc_winograd_without_weight_transform", + plevel=15, + ) else: raise RuntimeError( "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) @@ -671,6 +724,12 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): name="dense_small_batch.cuda", ) + strategy.add_tir_implementation( + wrap_compute_dense(topi.cuda.dense_small_batch), + wrap_topi_schedule(topi.generic.default_tir_schedule), + name="tir_dense_small_batch.cuda", + ) + with SpecializedCondition(b >= 32): strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_large_batch), @@ -678,6 +737,12 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): name="dense_large_batch.cuda", plevel=5, ) + strategy.add_tir_implementation( + wrap_compute_dense(topi.cuda.dense_large_batch), + wrap_topi_schedule(topi.generic.default_tir_schedule), + name="tir_dense_large_batch.cuda", + plevel=5, + ) if target.kind.name == "cuda": if nvcc.have_tensorcore(target=target): 
            if (
@@ -728,6 +793,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         name="batch_matmul.cuda",
         plevel=10,
     )
+
+    strategy.add_tir_implementation(
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
+        wrap_topi_schedule(topi.generic.default_tir_schedule),
+        name="tir_batch_matmul.cuda",
+        plevel=10,
+    )
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
diff --git a/python/tvm/topi/generic/default.py b/python/tvm/topi/generic/default.py
index 15d7af4191..3df53ad3fa 100644
--- a/python/tvm/topi/generic/default.py
+++ b/python/tvm/topi/generic/default.py
@@ -37,7 +37,7 @@ def default_schedule(outs, auto_inline):
 def default_tir_schedule(outs):
     """Default tir schedule for llvm."""
     target = tvm.target.Target.current(allow_none=False)
-    if target.kind.name not in ("llvm", "c"):
+    if target.kind.name not in ("llvm", "c", "gpu", "cuda"):
         raise RuntimeError("schedule not registered for '%s'" % target)
     func = te.create_func(outs)
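The relaxed guard above is the only change in this file. As an editorial sketch (the helper name is hypothetical, the APIs are mainline TVM), this is what the check means in practice:

```python
import tvm

def check_target_kind_allowed():
    # Same guard as default_tir_schedule: look up the innermost `with` target.
    target = tvm.target.Target.current(allow_none=False)
    if target.kind.name not in ("llvm", "c", "gpu", "cuda"):
        raise RuntimeError("schedule not registered for '%s'" % target)

with tvm.target.Target("cuda"):
    check_target_kind_allowed()  # accepted after this patch; raised before
```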
diff --git a/src/meta_schedule/feature/per_block_feature.cc b/src/meta_schedule/feature/per_block_feature.cc
index f583fbda27..8e92f48eea 100644
--- a/src/meta_schedule/feature/per_block_feature.cc
+++ b/src/meta_schedule/feature/per_block_feature.cc
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <tvm/tir/transform.h>
 
 #include
 #include
@@ -294,7 +295,9 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
     std::vector<FeatureSet> result;
     result.reserve(extractor.ordered_blocks_.size());
     for (const tir::BlockRealizeNode* realize : extractor.ordered_blocks_) {
-      result.push_back(extractor.per_block_feature_.at(realize));
+      if (!realize->block->name_hint.empty()) {
+        result.push_back(extractor.per_block_feature_.at(realize));
+      }
     }
     return result;
   }
@@ -383,7 +386,7 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
     std::vector<int64_t> access_shape;
     int64_t num_continuous_bytes = 1;
     /*! \brief loop_accessed_numel[i][...] means the number of elements accessed by loops[i] */
-    std::vector<std::vector<int64_t>> loop_accessed_numel = {};
+    std::vector<std::unordered_map<const tir::BufferNode*, int64_t>> loop_accessed_numel = {};
     // Stride info
     int64_t min_stride = 0;
     int64_t innermost_stride = 0;
@@ -427,7 +430,7 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
         // Note: `info.access_shape` for `i == n_loops - 1` is the only one preserved,
         // while others are discarded
         int64_t numel = CalcRegionUnionSize(info.regions, &info.access_shape);
-        info.loop_accessed_numel[i].push_back(numel);
+        info.loop_accessed_numel[i][buffer] = numel;
         touched_bytes += numel * buffer->dtype.bytes();
         buffer_touched_under_loop_[loop][buffer].push_back(numel);
       }
@@ -594,7 +597,7 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
       feature.lines = 1;
       feature.unique_lines = 1;
     } else {
-      feature.unique_bytes = info.loop_accessed_numel.back().front() * dtype_bytes;
+      feature.unique_bytes = info.loop_accessed_numel.back().at(iter.first) * dtype_bytes;
       double m = static_cast<double>(info.min_stride) * dtype_bytes / kCacheLineBytes;
       feature.lines = outer_loop_prod_ / info.prod_non_strided_loop_extent * std::min(1.0, m);
       feature.lines = std::max(1.0, feature.lines);
@@ -855,10 +858,6 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
  private:
   /******** Visitors ********/
   void VisitStmt_(const tir::BlockRealizeNode* realize) override {
-    // TODO(@jinhongyii): think of better ways of judging init block in the future
-    if (std::string(realize->block->name_hint).find("_init") != std::string::npos) {
-      return;
-    }
     if (!scopes_.empty()) {
       ordered_blocks_.push_back(realize);
     }
@@ -876,8 +875,6 @@ class PerBlockFeatureExtractor : public tir::StmtExprVisitor {
       const tir::StmtNode* stmt = *iter;
       if (stmt->IsInstance<tir::ForNode>()) {
         loops.push_back(static_cast<const tir::ForNode*>(stmt));
-      } else {
-        break;
       }
     }
     FeatureSet& feature = per_block_feature_[realize];
@@ -1090,7 +1087,14 @@ runtime::NDArray PerBlockFeature(const Schedule& sch, int max_num_buffer_access_
   size_t kNumFeature = kNumFeatureGroup1 +
                        kNumFeatureGroup2Subgroup * max_num_buffer_access_features +
                        kNumFeatureGroup3 + kNumFeatureGroup5;
-  tir::PrimFunc func = GetOnlyFunc(sch->mod());
+
+  IRModule mod = sch->mod();
+  auto pass_list = Array<tvm::transform::Pass>();
+  pass_list.push_back(tir::transform::PreprocessForFeatureExtraction());
+  pass_list.push_back(tir::transform::Simplify());
+  const auto& optimize = tir::transform::Sequential(pass_list);
+  mod = optimize(std::move(mod));
+  tir::PrimFunc func = GetOnlyFunc(mod);
   std::vector<FeatureSet> feature_map = PerBlockFeatureExtractor::Extract(func);
 
   DoubleNDArrayPusher ret(
diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc
index d5511f0b4c..b0f5bd0cac 100644
--- a/src/meta_schedule/space/postproc.cc
+++ b/src/meta_schedule/space/postproc.cc
@@ -605,6 +605,10 @@ class PostprocRewriteUnboundBlocks {
           (n_spatial_loops > 0 && loop_sref->seq_index != -1)) {
         break;
       }
+      if (!HasSingleChild(loop_sref)) {
+        n_spatial_loops++;
+        break;
+      }
       ++n_spatial_loops;
     }
     CHECK_GT(n_spatial_loops, 0) << "ValueError: not supported when spatial loop doesn't exist";
@@ -748,7 +752,8 @@ class PostprocRewriteReductionBlock {
     for (int i = 0; i < n_loops; ++i) {
       const LoopRV& loop_rv = loop_rvs[i];
       tir::StmtSRef loop_sref = sch->GetSRef(loop_rv);
-      if (GetLoopIterType(sch->state(), loop_sref) != tir::kDataPar) {
+      tir::IterVarType type = GetLoopIterType(sch->state(), loop_sref);
+      if (type == tir::kCommReduce || type == tir::kOpaque) {
         // Insert the initializing block above the first loop which is not data parallel.
         sch->DecomposeReduction(block_rv, loop_rvs[i]);
         break;
@@ -828,7 +833,7 @@ class PostprocVerifyGPUCode {
         {"max_local_memory_per_block", Extract(target, "registers_per_block")},
         {"max_threads_per_block", Extract(target, "max_threads_per_block")},
         {"max_vthread", Integer(8)},
-    };
+        {"max_vector_bytes", Integer(16)}};
     return tir::VerifyGPUCode(func, constraints);
   }
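A sketch of the constraint map that `PostprocVerifyGPUCode` now builds, expressed through the public `tvm.tir.analysis.verify_gpu_code` helper; the numeric limits below are illustrative stand-ins for the values extracted from the target, and `func` stands for a scheduled `tir.PrimFunc`:

```python
import tvm

constraints = {
    "max_shared_memory_per_block": 48 * 1024,
    "max_local_memory_per_block": 64 * 1024,  # mapped from "registers_per_block"
    "max_threads_per_block": 1024,
    "max_vthread": 8,
    "max_vector_bytes": 16,  # the key this patch adds
}
# ok = tvm.tir.analysis.verify_gpu_code(func, constraints)  # returns a bool
```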
diff --git a/src/meta_schedule/space/search_rule.cc b/src/meta_schedule/space/search_rule.cc
index 8839cb5383..946d3845b6 100644
--- a/src/meta_schedule/space/search_rule.cc
+++ b/src/meta_schedule/space/search_rule.cc
@@ -333,6 +333,7 @@ class RuleMultiLevelTiling {
     } else {
       state.write_cache = write_cache;
     }
+    state.write_cache_is_added = true;
     result.push_back(std::move(state));
 
     return result;
@@ -505,9 +506,6 @@ class RuleMultiLevelTiling {
   Array<Schedule> Apply(const SearchTask& task, const Schedule& sch,
                         const BlockRV& block_rv) const {
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-    if (HasAnyAnn(block_sref)) {
-      return {sch};
-    }
     if (!NeedsMultiLevelTiling(sch->state(), block_sref)) {
       return {sch};
     }
@@ -1174,6 +1172,46 @@ SearchRule CrossThreadReduction() {
   return SearchRule("cross_thread_reduction", f_apply);
 }
 
+/********** SpecialComputeLocationGPU **********/
+
+class RuleSpecialComputeLocationGPU {
+ public:
+  Array<Schedule> Apply(const SearchTask& task, const Schedule& sch,
+                        const BlockRV& block_rv) const {
+    if (sch->GetProducers(block_rv).empty()) {
+      return {sch};
+    }
+    tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+    if (!RuleInlinePureSpatial::NeedsInline(sch, block_sref, /*strict_mode=*/false)) {
+      return {sch};
+    }
+    Array<BlockRV> consumers = sch->GetConsumers(block_rv);
+    if (consumers.size() != 1 ||
+        !sch->Get(consumers[0])
+             ->annotations.count(
+                 tvm::auto_scheduler::SearchPolicyKey::simplify_const_tensor_indices)) {
+      return {sch};
+    }
+    // Compute the block at the loop right above the consumer's first unrolled loop, in local scope.
+    Array<LoopRV> consumer_loops = sch->GetAxes(consumers[0]);
+    for (size_t i = 1; i < consumer_loops.size(); i++) {
+      tir::StmtSRef loop_sref = sch->GetSRef(consumer_loops[i]);
+      if (tir::GetLoopIterType(sch->state(), loop_sref) == tir::kUnrolled) {
+        sch->ComputeAt(block_rv, consumer_loops[i - 1], true);
+        sch->SetScope(block_rv, 0, "local");
+        break;
+      }
+    }
+    return {sch};
+  }
+};
+
+SearchRule SpecialComputeLocationGPU() {
+  auto f_apply = [](SearchTask task, Schedule sch, BlockRV block) -> Array<Schedule> {
+    return RuleSpecialComputeLocationGPU().Apply(task, sch, block);
+  };
+  return SearchRule("special_compute_location_gpu", f_apply);
+}
+
 /********** FFI **********/
 
 struct Internal {
@@ -1227,6 +1265,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.search_rule.SimplifyComputeWithConstTensor")
 TVM_REGISTER_GLOBAL("meta_schedule.search_rule.AddRFactor").set_body_typed(AddRFactor);
 TVM_REGISTER_GLOBAL("meta_schedule.search_rule.CrossThreadReduction")
     .set_body_typed(CrossThreadReduction);
+TVM_REGISTER_GLOBAL("meta_schedule.search_rule.SpecialComputeLocationGPU")
+    .set_body_typed(SpecialComputeLocationGPU);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/space/search_rule.h b/src/meta_schedule/space/search_rule.h
index 12ee25dc80..fc08d68c9d 100644
--- a/src/meta_schedule/space/search_rule.h
+++ b/src/meta_schedule/space/search_rule.h
@@ -153,6 +153,12 @@ TVM_DLL SearchRule MarkTensorize(Array<tir::TensorIntrin> tensor_intrins);
  * \return The rule created
  */
 TVM_DLL SearchRule AddRFactor(int max_jobs_per_core, int max_innermost_factor);
+/*!
+ * \brief Handle special cases in Winograd transformation for GPU. We need to change the compute
+ * location of the producers of compute ops that perform "fake reduction" with const tensors.
+ * \return The rule created
+ */
+TVM_DLL SearchRule SpecialComputeLocationGPU();
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc
index cd150c3568..e7b26c6422 100644
--- a/src/meta_schedule/strategy/evolutionary.cc
+++ b/src/meta_schedule/strategy/evolutionary.cc
@@ -500,9 +500,10 @@ Array<Trace> EvolutionaryNode::SampleInitPopulation(const Array<Trace>& support)
   };
   support::parallel_persist_for(0, results.size(), f_proc_measured);
   // Pick unmeasured states
-  std::atomic<int> fail_ct(0);
-  auto f_proc_unmeasured = [this, &results, &thread_samplers, &fail_ct, &task, &space, &support,
-                            thread_workloads](int thread_id, int i) -> void {
+  std::atomic<int> tot_fail_ct(0);
+  std::atomic<int> success_ct(0);
+  auto f_proc_unmeasured = [this, &results, &thread_samplers, &tot_fail_ct, &task, &space,
+                            &support, &success_ct, thread_workloads](int thread_id, int i) -> void {
     Sampler* sampler = &thread_samplers[thread_id];
     for (;;) {
       const Trace& support_trace = support[sampler->SampleInt(0, support.size())]->trace;
@@ -514,20 +515,30 @@ Array<Trace> EvolutionaryNode::SampleInitPopulation(const Array<Trace>& support)
           Trace trace(sch->trace->insts, sch->trace->decisions);
           this->AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0});
           results[i] = std::move(trace);
+          success_ct++;
           break;
         } else {
-          fail_ct++;
+          tot_fail_ct++;
        }
      } catch (const dmlc::Error& e) {
-        fail_ct++;
+        tot_fail_ct++;
+      }
+      if (success_ct > 64) {
+        break;
       }
     }
   };
   num_measured = results.size();
   results.resize(this->population, Trace(nullptr));
   support::parallel_persist_for(num_measured, this->population, f_proc_unmeasured);
-  LOG(INFO) << "fail count: " << fail_ct;
-  return results;
+  std::vector<Trace> pruned_results;
+  for (const auto& result : results) {
+    if (result.defined()) {
+      pruned_results.push_back(result);
+    }
+  }
+  LOG(INFO) << "fail count: " << tot_fail_ct;
+  return pruned_results;
 }
 
 Array<Trace> EvolutionaryNode::EvolveWithCostModel(const Array<Trace>& inits,
diff --git a/src/tir/schedule/analysis.cc b/src/tir/schedule/analysis.cc
index 089c10db94..4eca48450d 100644
--- a/src/tir/schedule/analysis.cc
+++ b/src/tir/schedule/analysis.cc
@@ -641,6 +641,12 @@ IterVarType GetLoopIterType(const ScheduleState& self, const StmtSRef& loop_sref
     return IterVarType::kOpaque;
   } else if (n_reduce) {
     return IterVarType::kCommReduce;
+  } else if (loop->kind == ForKind::kUnrolled) {
+    return IterVarType::kUnrolled;
+  } else if (loop->kind == ForKind::kVectorized) {
+    return IterVarType::kVectorized;
+  } else if (loop->kind == ForKind::kParallel) {
+    return IterVarType::kParallelized;
   }
   return IterVarType::kDataPar;
 }
diff --git a/src/tir/transforms/preprocess_for_feature_extraction.cc b/src/tir/transforms/preprocess_for_feature_extraction.cc
new file mode 100644
index 0000000000..25cbb85d70
--- /dev/null
+++ b/src/tir/transforms/preprocess_for_feature_extraction.cc
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file preprocess_for_feature_extraction.cc
+ */
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/utils.h"
+
+namespace tvm {
+namespace tir {
+
+class SimplifyConstMatrix : public StmtExprMutator {
+ public:
+  static Stmt Simplify(const PrimFunc& func) {
+    SimplifyConstMatrix simplifier;
+    return simplifier.VisitStmt(func->body);
+  }
+
+ private:
+  // Select nodes here come from indexing into constant matrices (e.g. the Winograd
+  // transform matrices); their exact value is irrelevant for feature extraction,
+  // so fold them into a constant.
+  PrimExpr VisitExpr_(const SelectNode* node) final { return make_const(node->dtype, 1.0f); }
+};
+
+PrimFunc PreprocessForFeatureExtraction(PrimFunc f) {
+  auto pass_list = Array<tvm::transform::Pass>();
+  pass_list.push_back(tir::transform::LowerInitBlock());
+  pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+  pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+  pass_list.push_back(tir::transform::CompactBufferAllocation());
+  Map<GlobalVar, BaseFunc> tmp_map;
+  tmp_map.Set(GlobalVar("main"), f);
+  IRModule mod(tmp_map);
+  mod = tir::transform::Sequential(pass_list)(mod);
+  f = Downcast<PrimFunc>(mod->Lookup("main"));
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  fptr->body = SimplifyConstMatrix::Simplify(f);
+  return f;
+}
+
+namespace transform {
+
+Pass PreprocessForFeatureExtraction() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    return PreprocessForFeatureExtraction(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.PreprocessForFeatureExtraction", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.PreprocessForFeatureExtraction")
+    .set_body_typed(PreprocessForFeatureExtraction);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
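A sketch (not part of the patch) of driving the registered pass from Python, reproducing the preprocessing pipeline that `per_block_feature.cc` builds; `mod` is assumed to be an IRModule holding a single scheduled `"main"` PrimFunc:

```python
import tvm

# Fetch the pass through the FFI registration fixed above.
preprocess = tvm.get_global_func("tir.transform.PreprocessForFeatureExtraction")()
pipeline = tvm.transform.Sequential([preprocess, tvm.tir.transform.Simplify()])
# lowered = pipeline(mod)  # feature extraction then runs on `lowered`
```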
+"""End to end resnet-18 GPU test""" +# pylint: disable=missing-function-docstring +import os + +import numpy as np +import pytest +import tvm +import tvm.relay.testing +from tvm import meta_schedule as ms +from tvm import relay, te, auto_scheduler +from tvm.contrib import graph_runtime as runtime +from tvm.contrib.utils import tempdir + + +# import logging +# logging.basicConfig(level=logging.DEBUG) # to dump TVM IR after fusion + + +def get_network(name, batch_size, layout="NHWC", dtype="float32"): + """Get the symbol definition and random weight of a network""" + + # auto-scheduler prefers NHWC layout + if layout == "NHWC": + image_shape = (224, 224, 3) + elif layout == "NCHW": + image_shape = (3, 224, 224) + else: + raise ValueError("Invalid layout: " + layout) + + input_shape = (batch_size,) + image_shape + output_shape = (batch_size, 1000) + + if name.startswith("resnet-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name.startswith("resnet3d-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "mobilenet": + mod, params = relay.testing.mobilenet.get_workload( + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape + ) + elif name == "squeezenet_v1.1": + assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" + mod, params = relay.testing.squeezenet.get_workload( + version="1.1", + batch_size=batch_size, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "mxnet": + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + + assert layout == "NCHW" + + block = get_model("resnet18_v1", pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + net = mod["main"] + net = relay.Function( + net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + ) + mod = tvm.IRModule.from_expr(net) + + return mod, params, input_shape, output_shape + + +RPC_KEY = "rtx-3080" +network = "resnet-50" +batch_size = 1 +layout = "NHWC" +target = tvm.target.Target("nvidia/geforce-rtx-3080") +dtype = "float32" +TARGET_HOST = tvm.target.Target("llvm") +SPACE = ms.space.PostOrderApply( + stages=[ + ms.rule.simplify_compute_with_const_tensor(), + ms.rule.multi_level_tiling( + structure="SSSRRSRS", + must_cache_read=True, + cache_read_scope="shared", + can_cache_write=True, + must_cache_write=True, + cache_write_scope="local", + consumer_inline_strict=False, + fusion_levels=[3], + vector_load_max_len=4, + tile_binds=["blockIdx.x", "vthread", "threadIdx.x"], + ), + ms.rule.special_compute_location_gpu(), + ms.rule.inline_pure_spatial(strict_mode=False), + ms.rule.cross_thread_reduction(), + ms.rule.parallelize_vectorize_unroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + ], + postprocs=[ + ms.postproc.rewrite_cooperative_fetch(), + ms.postproc.rewrite_unbound_blocks(), + ms.postproc.rewrite_parallel_vectorize_unroll(), + 
+        ms.postproc.rewrite_reduction_block(),
+        ms.postproc.disallow_dynamic_loops(),
+        ms.postproc.verify_gpu_code(),
+    ],
+)
+
+
+@pytest.mark.skip(reason="needs RPC")
+def test_end_to_end_resnet(log):
+    os.environ["TVM_TRACKER_KEY"] = RPC_KEY
+    mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)
+
+    lib_std = relay.build_module.build(mod, target, params=params)
+    tir_funcs = ms.extract_tasks(mod["main"], params, target)
+
+    for func in tir_funcs.values():
+        sch = ms.autotune(
+            task=ms.SearchTask(
+                workload=func,
+                target=target,
+                target_host=TARGET_HOST,
+                log_file=log,
+            ),
+            space=SPACE,
+            strategy=ms.strategy.Evolutionary(
+                total_measures=16,
+                num_measures_per_iter=16,
+                population=2048,
+                init_measured_ratio=0.2,
+                genetic_algo_iters=4,
+                p_mutate=0.85,
+                mutator_probs={
+                    ms.mutator.mutate_tile_size(): 0.90,
+                    ms.mutator.mutate_auto_unroll(): 0.10,
+                },
+                cost_model=ms.XGBModel(xgb_eta=0.2),
+                eps_greedy=0.25,
+            ),
+            measurer=ms.ProgramMeasurer(
+                measure_callbacks=[
+                    ms.RecordToFile(),
+                ]
+            ),
+        )
+
+    with ms.ApplyHistoryBest(log, SPACE):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={
+                "relay.with_tir_schedule": True,
+                "relay.backend.use_meta_schedule": True,
+            },
+        ):
+            lib = relay.build_module.build(mod, target, params=params)
+
+    def run_module(lib):
+        ctx = tvm.context(str(target), 0)
+        module = runtime.GraphModule(lib["default"](ctx))
+        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
+        module.set_input("data", data_tvm)
+
+        # Evaluate
+        print("Evaluate inference time cost...")
+        ftimer = module.module.time_evaluator("run", ctx, repeat=3, min_repeat_ms=500)
+        prof_res = np.array(ftimer().results) * 1e3  # convert to millisecond
+        print(
+            "Mean inference time (std dev): %.2f ms (%.2f ms)"
+            % (np.mean(prof_res), np.std(prof_res))
+        )
+
+        module.run()
+        return module.get_output(0)
+
+    std = run_module(lib_std).asnumpy()
+    out = run_module(lib).asnumpy()
+    np.testing.assert_allclose(out, std, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    test_end_to_end_resnet("resnet_cuda.json")