diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index bdc55eedad..6a15ec7ba4 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1060,6 +1060,20 @@ class MatchBufferRegion : public ObjectRef { public: TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source); + /*! + * \brief Convert target buffer access indices to original one. + * \param indices The indices of the target buffer + * \return The indices of source buffer. + */ + TVM_DLL Array ConvertIndices(const Array& indices) const; + + /*! + * \brief Convert target buffer region to original one. + * \param region The sub-region of the target buffer + * \return The region of source buffer. + */ + TVM_DLL Region ConvertRegion(const Region& region) const; + TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode); }; diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 76d98774a7..9acf21b6ba 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -826,6 +826,11 @@ def transform_Subscript(self, node): "Array access index expected int or IntImm, but got " + type(index), node.span, ) + if int(index) >= len(symbol): + self.report_error( + f"Array access out of bound, size: {len(symbol)}, got index {index}.", + node.span, + ) return symbol[int(index)] else: self.report_error( diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 97c64854a3..6c55db2657 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -223,6 +223,7 @@ Array CreatePassList(bool disable_loop_partition, bool for pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::FlattenBuffer()); } pass_list.push_back(tir::transform::BF16Legalize()); diff --git 
a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index b1da536f1d..a610f9d939 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -65,8 +65,12 @@ class BlockReadWriteDetector : public StmtExprVisitor { std::vector> read_regions_; /*! \brief The write regions of the current block */ std::vector> write_regions_; + /*! \brief The opaque regions of the current block */ + std::vector> opaque_regions_; /*! \brief The outside buffer data mapping to its buffer */ Map buffer_var_map_; + /*! \brief The target buffer var mapping to its matching */ + std::unordered_map match_buffers_; /*! \brief The analyzer for simplifying*/ arith::Analyzer analyzer_; @@ -78,14 +82,18 @@ class BlockReadWriteDetector : public StmtExprVisitor { * \param region The provided region */ void Update(std::vector* buffers, std::vector>* regions, - const Buffer& buffer, const std::vector& region); + Buffer buffer, std::vector region); /*! \brief Helper function to collect access regions. */ Array CollectRegions(const std::vector& buffers, const std::vector>& regions); - /*! \brief Helper function to add a opaque buffer. */ - void AddOpaque(const Var& buffer_var); + /*! \brief Helper function to convert matched access region to source region. */ + std::vector ConvertMatchedRegion(const MatchBufferRegion& match_buffer, + const std::vector& int_sets) const; + + /*! \brief Helper function to update a opaque access. 
*/ + void UpdateOpaque(const Var& buffer_var); void VisitStmt_(const ForNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -97,8 +105,13 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { - ICHECK(stmt.as() != nullptr) - << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + const auto* block = stmt.as(); + ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + for (const MatchBufferRegion& match_buffer : block->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } StmtExprVisitor::operator()(stmt); } @@ -111,18 +124,13 @@ Array BlockReadWriteDetector::CollectWrites() { } Array BlockReadWriteDetector::CollectOpaques() { - Array res; - res.reserve(opaque_buffers_.size()); - for (const Buffer& buffer : opaque_buffers_) { - res.push_back(BufferRegion::FullRegion(buffer)); - } - return res; + return CollectRegions(opaque_buffers_, opaque_regions_); } -void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { AddOpaque(GetRef(op)); } +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); ExprVisitor::VisitExpr_(op); } @@ -143,7 +151,7 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { } void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); StmtVisitor::VisitStmt_(op); } @@ -184,11 +192,39 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { } } +std::vector BlockReadWriteDetector::ConvertMatchedRegion( + const MatchBufferRegion& match_buffer, const std::vector& int_sets) const { + const Buffer& buffer = match_buffer->buffer; + + 
Region region; + region.reserve(int_sets.size()); + ICHECK_EQ(buffer->shape.size(), int_sets.size()); + for (size_t i = 0; i < int_sets.size(); ++i) { + const tvm::arith::IntSet& int_set = int_sets[i]; + region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); + } + + region = match_buffer.ConvertRegion(region); + + std::vector result; + result.reserve(region.size()); + for (const Range& range : region) { + result.push_back(arith::EvalSet(range, dom_map_)); + } + return result; +} + void BlockReadWriteDetector::Update(std::vector* buffers, - std::vector>* regions, - const Buffer& buffer, - const std::vector& region) { + std::vector>* regions, Buffer buffer, + std::vector region) { if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return; + // Handle match_buffer remap + auto it = match_buffers_.find(buffer->data.get()); + if (it != match_buffers_.end()) { + const MatchBufferRegion& match_buffer = it->second; + buffer = match_buffer->source->buffer; + region = ConvertMatchedRegion(match_buffer, std::move(region)); + } ICHECK_EQ(buffers->size(), regions->size()) << " Expected the buffer and regions to have the same size "; for (size_t i = 0; i < regions->size(); ++i) { @@ -200,8 +236,8 @@ void BlockReadWriteDetector::Update(std::vector* buffers, return; } } - buffers->push_back(buffer); - regions->push_back(region); + buffers->push_back(std::move(buffer)); + regions->push_back(std::move(region)); } Array BlockReadWriteDetector::CollectRegions( @@ -213,8 +249,9 @@ Array BlockReadWriteDetector::CollectRegions( for (size_t i = 0; i < regions.size(); ++i) { Array region; region.reserve(regions[i].size()); + ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { - tvm::arith::IntSet range = regions[i][j]; + const tvm::arith::IntSet& range = regions[i][j]; region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } res.push_back(BufferRegion(buffers[i], 
region)); @@ -222,14 +259,18 @@ Array BlockReadWriteDetector::CollectRegions( return res; } -void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) { +void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { auto it = buffer_var_map_.find(buffer_var); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; - for (const Buffer& opaque_buffer : opaque_buffers_) { - if (buffer.same_as(opaque_buffer)) return; + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + const Region& region = buffer_region->region; + std::vector int_set; + int_set.reserve(region.size()); + for (const Range& range : region) { + int_set.push_back(arith::EvalSet(range, dom_map_)); } - opaque_buffers_.push_back(buffer); + Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); } } diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 6f2622f3a6..8f39de4c96 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -85,6 +85,21 @@ class LCADetector : public StmtExprVisitor { for (const Buffer& buf : op->alloc_buffers) { buffer_var_map_.emplace(buf->data.get(), buf.get()); } + + // Update match_buffers + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_.emplace(target_buffer->data.get(), target_buffer.get()); + + const Buffer& source_buffer = match_buffer->source->buffer; + auto it = match_buffers_.find(source_buffer.get()); + if (it != match_buffers_.end()) { + match_buffers_[target_buffer.get()] = it->second; + } else { + match_buffers_[target_buffer.get()] = source_buffer.get(); + } + } + const ScopeInfo* parent_scope = ancestor_scopes_.back(); auto* current_scope = arena_.make(parent_scope, op, n); ancestor_scopes_.push_back(current_scope); @@ -129,6 +144,10 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const 
BufferNode* buffer) { + auto it = match_buffers_.find(buffer); + if (it != match_buffers_.end()) { + buffer = it->second; + } const ScopeInfo*& lca = buffer_lca_[buffer]; lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); } @@ -164,6 +183,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; + /*! \brief The match buffers inside blocks. */ + std::unordered_map match_buffers_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 428abd1bf1..d72fd8f72d 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -743,6 +743,51 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { data_ = std::move(node); } +Array MatchBufferRegion::ConvertIndices(const Array& indices) const { + const Buffer& target = (*this)->buffer; + const BufferRegion& source = (*this)->source; + ICHECK_EQ(indices.size(), target->shape.size()); + + arith::Analyzer analyzer; + Array result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - indices.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& range = source->region[i]; + ICHECK(analyzer.CanProve(range->extent == 1)); + result.push_back(range->min); + } + for (size_t i = 0; i < indices.size(); ++i) { + const Range& range = source->region[i + offset]; + const PrimExpr& index = indices[i]; + result.push_back(range->min + index); + } + return result; +} + +Region MatchBufferRegion::ConvertRegion(const Region& region) const { + const Buffer& target = (*this)->buffer; + const BufferRegion& source = (*this)->source; + ICHECK_EQ(region.size(), target->shape.size()); + + arith::Analyzer analyzer; + Region result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - region.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& source_range = 
source->region[i]; + ICHECK(analyzer.CanProve(source_range->extent == 1)); + result.push_back(Range::FromMinExtent(source_range->min, 1)); + } + for (size_t i = 0; i < region.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const Range& target_range = region[i]; + result.push_back( + Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); + } + return result; +} + TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) { return MatchBufferRegion(buffer, source); }); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 1df5c70873..e06a6c953c 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -306,7 +306,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { class StorageAlignCollector : public StmtVisitor { public: static std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> Collect( - const PrimFunc& f) { + const PrimFunc& f) { StorageAlignCollector collector; collector(f->body); return std::move(collector.storage_align_); @@ -319,8 +319,8 @@ class StorageAlignCollector : public StmtVisitor { const auto& storage_align = Downcast>>>((*it).second); ICHECK(storage_align.size() == op->writes.size()); for (size_t i = 0; i < storage_align.size(); ++i) { - CHECK(!storage_align_.count(op->writes[i]->buffer)) << - "ValueError: Conflicting storage_align for buffer " << op->writes[i]->buffer->name; + CHECK(!storage_align_.count(op->writes[i]->buffer)) + << "ValueError: Conflicting storage_align for buffer " << op->writes[i]->buffer->name; storage_align_.emplace(op->writes[i]->buffer, storage_align[i]); } } @@ -334,9 +334,11 @@ class StorageAlignCollector : public StmtVisitor { /*! \brief Reallocate the buffers with minimal region. 
*/ class BufferCompactor : public StmtExprMutator { public: - static Stmt Compact(const PrimFunc& f, - const std::unordered_map& regions, - const std::unordered_map>, ObjectPtrHash, ObjectPtrEqual>& storage_align) { + static Stmt Compact( + const PrimFunc& f, + const std::unordered_map& regions, + const std::unordered_map>, ObjectPtrHash, ObjectPtrEqual>& + storage_align) { std::unordered_map buffer_info; for (const auto& kv : regions) { @@ -363,7 +365,6 @@ class BufferCompactor : public StmtExprMutator { } private: - /*! \brief The storage alignment for a dimension */ struct DimAlignInfo { /*! \brief The factor of the alignment */ @@ -415,6 +416,7 @@ class BufferCompactor : public StmtExprMutator { BlockNode* n = block.CopyOnWrite(); RewriteBufferRegions(&n->reads); RewriteBufferRegions(&n->writes); + RewriteMatchBuffers(&n->match_buffers); n->alloc_buffers = std::move(alloc_buffers); return std::move(block); } @@ -507,6 +509,18 @@ class BufferCompactor : public StmtExprMutator { *regions = std::move(new_regions); } + void RewriteMatchBuffers(Array* match_buffers) const { + Array result; + result.reserve(match_buffers->size()); + for (const auto& match_buffer : *match_buffers) { + const BufferRegion& buffer_region = match_buffer->source; + auto p = make_object(*buffer_region.get()); + RewriteBufferRegion(&p->buffer, &p->region); + result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p))); + } + *match_buffers = std::move(result); + } + /*! \brief The allocation information about each buffer. 
*/ std::unordered_map buffer_info_; }; diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index acb5691fda..9ad115c647 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -93,7 +93,7 @@ class MatchBufferLower : public StmtExprMutator { const BufferRegion& source = it->second; auto n = CopyOnWrite(op); - n->indices = ConvertIndices(op->indices, MatchBufferRegion(buffer, source)); + n->indices = MatchBufferRegion(buffer, source).ConvertIndices(op->indices); n->buffer = source->buffer; return Stmt(n); } @@ -110,7 +110,7 @@ class MatchBufferLower : public StmtExprMutator { } else { const Buffer& buffer = it->first; const BufferRegion& source = it->second; - Array indices = ConvertIndices(op->indices, MatchBufferRegion(buffer, source)); + Array indices = MatchBufferRegion(buffer, source).ConvertIndices(op->indices); return BufferLoad(source->buffer, indices); } } @@ -122,7 +122,7 @@ class MatchBufferLower : public StmtExprMutator { return buffer_region; } else { const BufferRegion& source = it->second; - Region region = ConvertRegion(buffer_region->region, MatchBufferRegion(buffer, source)); + Region region = MatchBufferRegion(buffer, source).ConvertRegion(buffer_region->region); return BufferRegion(source->buffer, std::move(region)); } } @@ -197,51 +197,6 @@ class MatchBufferLower : public StmtExprMutator { } } - Array ConvertIndices(const Array indices, - const MatchBufferRegion& match_buffer) { - const Buffer& target = match_buffer->buffer; - const BufferRegion& source = match_buffer->source; - ICHECK_EQ(indices.size(), target->shape.size()); - - Array result; - result.reserve(source->region.size()); - size_t offset = source->region.size() - indices.size(); - for (size_t i = 0; i < offset; ++i) { - const Range& range = source->region[i]; - ICHECK(analyzer_.CanProve(range->extent == 1)); - result.push_back(range->min); - } - for (size_t i = 0; i < indices.size(); ++i) { - 
const Range& range = source->region[i + offset]; - const PrimExpr& index = indices[i]; - result.push_back(range->min + index); - } - return result; - } - - Region ConvertRegion(const Region region, const MatchBufferRegion& match_buffer) { - const Buffer& target = match_buffer->buffer; - const BufferRegion& source = match_buffer->source; - ICHECK_EQ(region.size(), target->shape.size()); - - Region result; - result.reserve(source->region.size()); - size_t offset = source->region.size() - region.size(); - for (size_t i = 0; i < offset; ++i) { - const Range& source_range = source->region[i]; - ICHECK(analyzer_.CanProve(source_range->extent == 1)); - result.push_back(Range::FromMinExtent(source_range->min, 1)); - } - for (size_t i = 0; i < region.size(); ++i) { - const Range& source_range = source->region[i + offset]; - const Range& target_range = region[i]; - ICHECK(analyzer_.CanProve(source_range->extent >= target_range->min + target_range->extent)); - result.push_back( - Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); - } - return result; - } - void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { CHECK_EQ(arg.dtype(), value.dtype()) << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py new file mode 100644 index 0000000000..9053b35348 --- /dev/null +++ b/tests/python/integration/test_lower.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument +"""Test workload for lowering and build""" +import tvm +from tvm import tir +from tvm.script import ty +import tvm.testing +import numpy as np + + +@tvm.script.tir +def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # match buffer + A = tir.match_buffer(a, [1024, 1024], "float16") + B = tir.match_buffer(b, [1024, 1024], "float16") + C = tir.match_buffer(c, [1024, 1024], "float32") + + # body + for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"): + for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"): + with tir.block([16, 8]) as [bx, by]: + tir.bind(bx, blockIdx_x) + tir.bind(by, blockIdx_y) + shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") + wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b") + wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator") + for ty in tir.thread_binding(0, 2, "threadIdx.y"): + for tz in tir.thread_binding(0, 2, "threadIdx.z"): + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads([]) + tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + C0 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", 
+ offset_factor=1, + ) + tir.evaluate( + tir.tvm_fill_fragment( + C0.data, + 16, + 16, + 16, + i * 4 + j, + tir.float32(0), + dtype="handle", + ) + ) + + for ko in range(0, 32): + # copy data from global to shared + for tx in tir.thread_binding(0, 32, "threadIdx.x"): + for i0, j0 in tir.grid(1, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, bx * 64 + ty * 32 + tx + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_A[vi, vj + 8] = A[vi, vj] + + for i0, j0 in tir.grid(2, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_B[vi, vj + 8] = B[vi, vj] + + for ki in range(0, 2): + for i in range(0, 2): + with tir.block([64, 64]) as [vi, vk]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ] + ) + tir.writes( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + A0 = tir.match_buffer( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ], + (16, 16 + 8), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_A0 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_a", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + wmma_A0.data, + 16, + 16, + 16, + i, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset + 8, + A0.strides[0], + 1, + dtype="handle", + ), + A0.strides[0], + "row_major", + dtype="handle", + ) + ) + for j in range(0, 4): + with tir.block([64, 64]) as [vj, vk]: + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ] + ) + tir.writes( + 
wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + B0 = tir.match_buffer( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ], + (16, 16 + 8), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_B0 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_b", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + wmma_B0.data, + 16, + 16, + 16, + j, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + B0.data, + B0.elem_offset + 8, + B0.strides[0], + 1, + dtype="handle", + ), + B0.strides[0], + "col_major", + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [ + vi, + vj, + vk, + ]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + [ + wmma_A[ + vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_B[ + vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_C[ + vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16 + ], + ] + ) + tir.writes( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16] + ) + wmma_A1 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_a", + offset_factor=1, + ) + wmma_B1 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_b", + offset_factor=1, + ) + wmma_C1 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_mma_sync( + wmma_C1.data, + i * 4 + j, + wmma_A1.data, + i, + wmma_B1.data, + j, + wmma_C1.data, + i * 4 + j, + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + 
with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + s0 = tir.var("int32") + s1 = tir.var("int32") + wmma_C2 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + C1 = tir.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[s0, s1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_store_matrix_sync( + wmma_C2.data, + 16, + 16, + 16, + i * 4 + j, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float32"), + C1.data, + C1.elem_offset, + C1.strides[0], + 1, + dtype="handle", + ), + C1.strides[0], + "row_major", + dtype="handle", + ) + ) + + +def test_gemm_tensorcore(): + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(1024, 1024)).astype("float16") + b_np = np.random.uniform(size=(1024, 1024)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev) + print(tvm.script.asscript(tvm.lower(tensorcore_gemm, simple_mode=True))) + f = tvm.build(tensorcore_gemm, target="cuda", name="dense") + f(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + evaluator = f.time_evaluator(f.entry_name, dev, number=100) + t = evaluator(a, b, c).mean + num_flops = 2 * 1024 * 1024 * 1024 + gflops = num_flops / (t * 1e3) / 1e6 + print("gemm with tensor core: %f ms" % (t * 1e3)) + print("GFLOPS: %f" % gflops) + + +if __name__ == "__main__": + test_gemm_tensorcore() \ No newline at end of file diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 
36fd80fd07..c7cf2f6edf 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -70,6 +70,23 @@ def lca_is_func_root(a: ty.handle) -> None: A.data[0] = 1.0 +@tvm.script.tir +def match_buffer_func(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.match_buffer(b, (128, 128), "float32") + with tir.block([8, 8], "block") as [vi, vj]: + tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + AAA = tir.match_buffer(AA[i, j], ()) + AAA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + def test_buffer_load_store(): func = buffer_load_store_func A, B = [func.buffer_map[x] for x in func.params] @@ -115,7 +132,24 @@ def test_lca_func_root(): assert lca[A] is None +def test_match_buffer(): + func = match_buffer_func + A, B = [func.buffer_map[x] for x in func.params] + lca = tir.analysis.detect_buffer_access_lca(func) + + root_block = func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + + # LCA of Buffer A is the inner block + assert lca[A] == block_inner + + # LCA of Buffer B is the main block + assert lca[B] == block + + if __name__ == "__main__": test_buffer_load_store() test_opaque_access() test_lca_func_root() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 7e4d7d87c1..70de437280 100644 ---
a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -39,6 +39,29 @@ def func() -> None: tir.evaluate(D.data) +@tvm.script.tir +def match_buffer_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + with tir.block([8, 8], "block") as [vi, vj]: + tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + tir.reads([]) + tir.writes(AA[i, j]) + AAA = tir.match_buffer(AA[i, j], ()) + AAA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -53,5 +76,25 @@ def test_block_access_region_detector(): ) +def test_match_buffer(): + root_block = match_buffer_func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + alloc_buffers = func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + # Check inner block AAA + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + + # Check block + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + # B is opaque access + 
tvm.ir.assert_structural_equal(block.reads, ret[2]) + + if __name__ == "__main__": test_block_access_region_detector() + test_match_buffer()