Skip to content

Commit

Permalink
End-to-end match_buffer support and TensorCore script demo (#414)
Browse files Browse the repository at this point in the history
* add match_buffer support to access detector

(cherry picked from commit b93a593c39a56ebb22c1dce70bfbd4484e581750)

* add match_buffer support to lca detector

* tensorcore script

* finish e2e tensorcore script

* finish e2e tensorcore script
  • Loading branch information
Hzfengsy authored and jinhongyii committed Jul 29, 2021
1 parent a7b8bc7 commit 002348c
Show file tree
Hide file tree
Showing 11 changed files with 579 additions and 79 deletions.
14 changes: 14 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,20 @@ class MatchBufferRegion : public ObjectRef {
public:
TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);

/*!
 * \brief Convert target buffer access indices to original one.
 *
 * The source region may have more dimensions than the target buffer. Every extra
 * leading dimension of the source region must have a provable extent of 1 and
 * contributes its `min`; each remaining dimension contributes `min + index`.
 *
 * \param indices The indices of the target buffer. There must be exactly one index
 *                per dimension of the target buffer.
 * \return The indices of source buffer, one per dimension of the source region.
 */
TVM_DLL Array<PrimExpr> ConvertIndices(const Array<PrimExpr>& indices) const;

/*!
 * \brief Convert target buffer region to original one.
 *
 * The source region may have more dimensions than the target buffer. Every extra
 * leading dimension of the source region must have a provable extent of 1 and is
 * emitted as `[min, min + 1)`; each remaining dimension keeps the extent of the
 * given range and has its `min` shifted by the `min` of the source region.
 *
 * \param region The sub-region of the target buffer. There must be exactly one
 *               range per dimension of the target buffer.
 * \return The region of source buffer, one range per dimension of the source region.
 */
TVM_DLL Region ConvertRegion(const Region& region) const;

TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
};

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,11 @@ def transform_Subscript(self, node):
"Array access index expected int or IntImm, but got " + type(index),
node.span,
)
if int(index) >= len(symbol):
self.report_error(
f"Array access out of bound, size: {len(symbol)}, got index {index}.",
node.span,
)
return symbol[int(index)]
else:
self.report_error(
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
}
pass_list.push_back(tir::transform::BF16Legalize());
Expand Down
89 changes: 65 additions & 24 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ class BlockReadWriteDetector : public StmtExprVisitor {
std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
/*! \brief The write regions of the current block */
std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
/*! \brief The opaque regions of the current block */
std::vector<std::vector<tvm::arith::IntSet>> opaque_regions_;
/*! \brief The outside buffer data mapping to its buffer */
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*! \brief The analyzer for simplifying*/
arith::Analyzer analyzer_;

Expand All @@ -78,14 +82,18 @@ class BlockReadWriteDetector : public StmtExprVisitor {
* \param region The provided region
*/
void Update(std::vector<Buffer>* buffers, std::vector<std::vector<arith::IntSet>>* regions,
const Buffer& buffer, const std::vector<arith::IntSet>& region);
Buffer buffer, std::vector<arith::IntSet> region);

/*! \brief Helper function to collect access regions. */
Array<BufferRegion> CollectRegions(const std::vector<Buffer>& buffers,
const std::vector<std::vector<tvm::arith::IntSet>>& regions);

/*! \brief Helper function to add a opaque buffer. */
void AddOpaque(const Var& buffer_var);
/*! \brief Helper function to convert matched access region to source region. */
std::vector<arith::IntSet> ConvertMatchedRegion(const MatchBufferRegion& match_buffer,
const std::vector<arith::IntSet>& int_sets) const;

/*! \brief Helper function to update an opaque access. */
void UpdateOpaque(const Var& buffer_var);

void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
Expand All @@ -97,8 +105,13 @@ class BlockReadWriteDetector : public StmtExprVisitor {
};

void BlockReadWriteDetector::operator()(const Stmt& stmt) {
ICHECK(stmt.as<BlockNode>() != nullptr)
<< "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
const auto* block = stmt.as<BlockNode>();
ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
}
StmtExprVisitor::operator()(stmt);
}

Expand All @@ -111,18 +124,13 @@ Array<BufferRegion> BlockReadWriteDetector::CollectWrites() {
}

Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
Array<BufferRegion> res;
res.reserve(opaque_buffers_.size());
for (const Buffer& buffer : opaque_buffers_) {
res.push_back(BufferRegion::FullRegion(buffer));
}
return res;
return CollectRegions(opaque_buffers_, opaque_regions_);
}

void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { AddOpaque(GetRef<Var>(op)); }
void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef<Var>(op)); }

void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
AddOpaque(op->buffer_var);
UpdateOpaque(op->buffer_var);
ExprVisitor::VisitExpr_(op);
}

Expand All @@ -143,7 +151,7 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
}

void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
AddOpaque(op->buffer_var);
UpdateOpaque(op->buffer_var);
StmtVisitor::VisitStmt_(op);
}

Expand Down Expand Up @@ -184,11 +192,39 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
}
}

std::vector<arith::IntSet> BlockReadWriteDetector::ConvertMatchedRegion(
const MatchBufferRegion& match_buffer, const std::vector<arith::IntSet>& int_sets) const {
const Buffer& buffer = match_buffer->buffer;

Region region;
region.reserve(int_sets.size());
ICHECK_EQ(buffer->shape.size(), int_sets.size());
for (size_t i = 0; i < int_sets.size(); ++i) {
const tvm::arith::IntSet& int_set = int_sets[i];
region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
}

region = match_buffer.ConvertRegion(region);

std::vector<arith::IntSet> result;
result.reserve(region.size());
for (const Range& range : region) {
result.push_back(arith::EvalSet(range, dom_map_));
}
return result;
}

void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers,
std::vector<std::vector<arith::IntSet>>* regions,
const Buffer& buffer,
const std::vector<arith::IntSet>& region) {
std::vector<std::vector<arith::IntSet>>* regions, Buffer buffer,
std::vector<arith::IntSet> region) {
if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return;
// Handle match_buffer remap
auto it = match_buffers_.find(buffer->data.get());
if (it != match_buffers_.end()) {
const MatchBufferRegion& match_buffer = it->second;
buffer = match_buffer->source->buffer;
region = ConvertMatchedRegion(match_buffer, std::move(region));
}
ICHECK_EQ(buffers->size(), regions->size())
<< " Expected the buffer and regions to have the same size ";
for (size_t i = 0; i < regions->size(); ++i) {
Expand All @@ -200,8 +236,8 @@ void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers,
return;
}
}
buffers->push_back(buffer);
regions->push_back(region);
buffers->push_back(std::move(buffer));
regions->push_back(std::move(region));
}

Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
Expand All @@ -213,23 +249,28 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
for (size_t i = 0; i < regions.size(); ++i) {
Array<Range> region;
region.reserve(regions[i].size());
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
for (size_t j = 0; j < regions[i].size(); j++) {
tvm::arith::IntSet range = regions[i][j];
const tvm::arith::IntSet& range = regions[i][j];
region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
}
res.push_back(BufferRegion(buffers[i], region));
}
return res;
}

void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) {
void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) {
auto it = buffer_var_map_.find(buffer_var);
if (it != buffer_var_map_.end()) {
const Buffer& buffer = (*it).second;
for (const Buffer& opaque_buffer : opaque_buffers_) {
if (buffer.same_as(opaque_buffer)) return;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
const Region& region = buffer_region->region;
std::vector<arith::IntSet> int_set;
int_set.reserve(region.size());
for (const Range& range : region) {
int_set.push_back(arith::EvalSet(range, dom_map_));
}
opaque_buffers_.push_back(buffer);
Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
}
}

Expand Down
21 changes: 21 additions & 0 deletions src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ class LCADetector : public StmtExprVisitor {
for (const Buffer& buf : op->alloc_buffers) {
buffer_var_map_.emplace(buf->data.get(), buf.get());
}

// Update match_buffers
for (const MatchBufferRegion& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_.emplace(target_buffer->data.get(), target_buffer.get());

const Buffer& source_buffer = match_buffer->source->buffer;
auto it = match_buffers_.find(source_buffer.get());
if (it != match_buffers_.end()) {
match_buffers_[target_buffer.get()] = it->second;
} else {
match_buffers_[target_buffer.get()] = source_buffer.get();
}
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
ancestor_scopes_.push_back(current_scope);
Expand Down Expand Up @@ -129,6 +144,10 @@ class LCADetector : public StmtExprVisitor {
}

void UpdateBufferLCA(const BufferNode* buffer) {
auto it = match_buffers_.find(buffer);
if (it != match_buffers_.end()) {
buffer = it->second;
}
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
}
Expand Down Expand Up @@ -164,6 +183,8 @@ class LCADetector : public StmtExprVisitor {
std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
/*! \brief The map from Buffer data to the Buffer. */
std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
/*! \brief The match buffers inside blocks. */
std::unordered_map<const BufferNode*, const BufferNode*> match_buffers_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};
Expand Down
45 changes: 45 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,51 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
data_ = std::move(node);
}

/*!
 * \brief Convert indices of the target (matched) buffer into indices of the source buffer.
 *
 * The source region may have more dimensions than the target buffer. Every extra
 * leading source dimension must have a provable extent of 1 and contributes its
 * `min`; each remaining dimension contributes `min + index`.
 *
 * \param indices The indices of the target buffer, one per target dimension.
 * \return The indices of the source buffer, one per source region dimension.
 */
Array<PrimExpr> MatchBufferRegion::ConvertIndices(const Array<PrimExpr>& indices) const {
  const Buffer& target = (*this)->buffer;
  const BufferRegion& source = (*this)->source;
  ICHECK_EQ(indices.size(), target->shape.size());
  // The subtraction below is unsigned: without this guard a source region with
  // fewer dimensions than the target would wrap `offset` around to a huge value.
  ICHECK_GE(source->region.size(), indices.size())
      << "The source region of a match_buffer must not have fewer dimensions than its target "
         "buffer";

  arith::Analyzer analyzer;
  Array<PrimExpr> result;
  result.reserve(source->region.size());
  const size_t offset = source->region.size() - indices.size();
  // Leading dimensions that only exist in the source region: pinned to their min.
  for (size_t i = 0; i < offset; ++i) {
    const Range& range = source->region[i];
    ICHECK(analyzer.CanProve(range->extent == 1));
    result.push_back(range->min);
  }
  // Dimensions shared with the target: shift the index by the region's min.
  for (size_t i = 0; i < indices.size(); ++i) {
    const Range& range = source->region[i + offset];
    const PrimExpr& index = indices[i];
    result.push_back(range->min + index);
  }
  return result;
}

/*!
 * \brief Convert a region of the target (matched) buffer into a region of the source buffer.
 *
 * The source region may have more dimensions than the target buffer. Every extra
 * leading source dimension must have a provable extent of 1 and is emitted as a
 * range of extent 1 starting at its `min`; each remaining dimension keeps the
 * extent of the given range and has its `min` shifted by the source region's `min`.
 *
 * \param region The sub-region of the target buffer, one range per target dimension.
 * \return The region of the source buffer, one range per source region dimension.
 */
Region MatchBufferRegion::ConvertRegion(const Region& region) const {
  const Buffer& target = (*this)->buffer;
  const BufferRegion& source = (*this)->source;
  ICHECK_EQ(region.size(), target->shape.size());
  // The subtraction below is unsigned: without this guard a source region with
  // fewer dimensions than the target would wrap `offset` around to a huge value.
  ICHECK_GE(source->region.size(), region.size())
      << "The source region of a match_buffer must not have fewer dimensions than its target "
         "buffer";

  arith::Analyzer analyzer;
  Region result;
  result.reserve(source->region.size());
  const size_t offset = source->region.size() - region.size();
  // Leading dimensions that only exist in the source region: kept as unit ranges.
  for (size_t i = 0; i < offset; ++i) {
    const Range& source_range = source->region[i];
    ICHECK(analyzer.CanProve(source_range->extent == 1));
    result.push_back(Range::FromMinExtent(source_range->min, 1));
  }
  // Dimensions shared with the target: translate by the source region's min.
  for (size_t i = 0; i < region.size(); ++i) {
    const Range& source_range = source->region[i + offset];
    const Range& target_range = region[i];
    result.push_back(
        Range::FromMinExtent(source_range->min + target_range->min, target_range->extent));
  }
  return result;
}

TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) {
return MatchBufferRegion(buffer, source);
});
Expand Down
28 changes: 21 additions & 7 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
class StorageAlignCollector : public StmtVisitor {
public:
static std::unordered_map<Buffer, Array<Array<Integer>>, ObjectPtrHash, ObjectPtrEqual> Collect(
const PrimFunc& f) {
const PrimFunc& f) {
StorageAlignCollector collector;
collector(f->body);
return std::move(collector.storage_align_);
Expand All @@ -319,8 +319,8 @@ class StorageAlignCollector : public StmtVisitor {
const auto& storage_align = Downcast<Array<Array<Array<Integer>>>>((*it).second);
ICHECK(storage_align.size() == op->writes.size());
for (size_t i = 0; i < storage_align.size(); ++i) {
CHECK(!storage_align_.count(op->writes[i]->buffer)) <<
"ValueError: Conflicting storage_align for buffer " << op->writes[i]->buffer->name;
CHECK(!storage_align_.count(op->writes[i]->buffer))
<< "ValueError: Conflicting storage_align for buffer " << op->writes[i]->buffer->name;
storage_align_.emplace(op->writes[i]->buffer, storage_align[i]);
}
}
Expand All @@ -334,9 +334,11 @@ class StorageAlignCollector : public StmtVisitor {
/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
public:
static Stmt Compact(const PrimFunc& f,
const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
const std::unordered_map<Buffer, Array<Array<Integer>>, ObjectPtrHash, ObjectPtrEqual>& storage_align) {
static Stmt Compact(
const PrimFunc& f,
const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
const std::unordered_map<Buffer, Array<Array<Integer>>, ObjectPtrHash, ObjectPtrEqual>&
storage_align) {
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;

for (const auto& kv : regions) {
Expand All @@ -363,7 +365,6 @@ class BufferCompactor : public StmtExprMutator {
}

private:

/*! \brief The storage alignment for a dimension */
struct DimAlignInfo {
/*! \brief The factor of the alignment */
Expand Down Expand Up @@ -415,6 +416,7 @@ class BufferCompactor : public StmtExprMutator {
BlockNode* n = block.CopyOnWrite();
RewriteBufferRegions(&n->reads);
RewriteBufferRegions(&n->writes);
RewriteMatchBuffers(&n->match_buffers);
n->alloc_buffers = std::move(alloc_buffers);
return std::move(block);
}
Expand Down Expand Up @@ -507,6 +509,18 @@ class BufferCompactor : public StmtExprMutator {
*regions = std::move(new_regions);
}

/*!
 * \brief Rewrite the source region of every match_buffer in place.
 * \param match_buffers The match_buffer list to update; replaced with the rewritten list.
 */
void RewriteMatchBuffers(Array<MatchBufferRegion>* match_buffers) const {
  Array<MatchBufferRegion> rewritten;
  rewritten.reserve(match_buffers->size());
  for (const auto& original : *match_buffers) {
    // Copy the source region node so the rewrite does not touch the shared original.
    auto source_copy = make_object<BufferRegionNode>(*original->source.get());
    RewriteBufferRegion(&source_copy->buffer, &source_copy->region);
    // The target buffer is kept as is; only the source side is replaced.
    rewritten.push_back(MatchBufferRegion(original->buffer, BufferRegion(source_copy)));
  }
  *match_buffers = std::move(rewritten);
}

/*! \brief The allocation information about each buffer. */
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
};
Expand Down
Loading

0 comments on commit 002348c

Please sign in to comment.