Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix when binary search find invalid value #95

Merged
merged 7 commits into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,24 +271,30 @@ class SparseBufferNode : public BufferNode {
PrimExpr GetNNZ() const;

Buffer flattened;

/*!
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
* \brief The default value in the sparse buffer.
*/
Optional<PrimExpr> default_value;
void VisitAttrs(AttrVisitor* v) {
BufferNode::VisitAttrs(v);
v->Visit("axes", &axes);
v->Visit("extra_storage", &extra_storage);
v->Visit("flattened", &flattened);
v->Visit("default_value", &default_value);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return BufferNode::SEqualReduce(other, equal) && equal(axes, other->axes) &&
equal(extra_storage, other->extra_storage) && equal(flattened, other->flattened);
equal(extra_storage, other->extra_storage) && equal(flattened, other->flattened) &&
equal(default_value, other->default_value);
}

void SHashReduce(SHashReducer hash_reduce) const {
BufferNode::SHashReduce(hash_reduce);
hash_reduce(axes);
hash_reduce(extra_storage);
hash_reduce(flattened);
hash_reduce(default_value);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
Expand All @@ -304,7 +310,8 @@ class SparseBufferNode : public BufferNode {
class SparseBuffer : public Buffer {
public:
TVM_DLL explicit SparseBuffer(Var data, Array<Axis> axes, DataType dtype, String name,
Optional<PrimExpr> extra_storage, Span span = Span());
Optional<PrimExpr> extra_storage,
Optional<PrimExpr> default_value = NullOpt, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, Buffer, SparseBufferNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferNode);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ TVM_DLL Pass RenormalizeSplitPattern();
* \brief Lower sparse iterations in Sparse TIR.
* \return The pass.
*/
TVM_DLL Pass LowerSparseIter();
TVM_DLL Pass LowerSparseIter(bool check_invalid_binary_search = false);

/*!
* \brief Lower sparse buffers in Sparse TIR.
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,6 @@ def preflattened_buffer(
buffer_type="default",
span=None,
):

param = None
for key, value in self.context.func_buffer_map.items():
if value.same_as(postflattened):
Expand Down Expand Up @@ -1194,6 +1193,7 @@ def match_sparse_buffer(
axes: List[Axis],
dtype: str = "float32",
extra_storage: Optional[PrimExpr] = None,
default_value: Optional[PrimExpr] = None,
span: Optional[Span] = None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
Expand All @@ -1214,7 +1214,7 @@ def match_sparse_buffer(
storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
data = Var(buffer_name, PointerType(storage_type, "global"), span)
buffer = tvm.tir.sparse.SparseBuffer(
data, axes, dtype, buffer_name, extra_storage, span
data, axes, dtype, buffer_name, extra_storage, default_value, span
)
self.context.func_buffer_map[param] = buffer
self.context.update_symbol(buffer_name, buffer, self.node)
Expand All @@ -1235,6 +1235,7 @@ def alloc_sparse_buffer(
axes: List[Axis],
dtype: str = "float32",
scope: str = "global",
default_value: Optional[PrimExpr] = None,
span: Optional[Span] = None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
Expand All @@ -1245,7 +1246,9 @@ def alloc_sparse_buffer(
buffer_name: str = self.node.lhs[0].id.name

data = Var(buffer_name, PointerType(PrimType(dtype), scope), span)
buffer = tvm.tir.sparse.SparseBuffer(data, axes, dtype, buffer_name, 0, span)
buffer = tvm.tir.sparse.SparseBuffer(
data, axes, dtype, buffer_name, 0, default_value, span
)
if self.context.current_block_scope():
self.context.current_block_scope().alloc_buffers.append(buffer)
else:
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/sparse/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@
from tvm.tir.transform import LowerSparseBuffer, LowerSparseIter


def lower_sparse_iter(mod: IRModule):
def lower_sparse_iter(mod: IRModule, check_invalid_binary_search: bool = False):
"""Lower sparse iterators in Sparse TIR.

Parameters
----------
mod : IRModule
The IRModule to lower.
check_invalid_binary_search : bool
Whether check invalid indices made by binary search.
"""
if not isinstance(mod, IRModule):
raise TypeError("Expected IRModule, but got {}".format(type(mod)))
return LowerSparseIter()(mod)
return LowerSparseIter(check_invalid_binary_search)(mod)


def lower_sparse_buffer(mod: IRModule):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ class SparseBuffer(Buffer):
The name of the sparse buffer
extra_storage : Optional[PrimExpr]
Required extra storage (e.g. for indptr)
default_value : Optional[PrimExpr]
The default value about missing value of the the sparse buffer
span : Span
"""

Expand All @@ -306,10 +308,11 @@ class SparseBuffer(Buffer):
dtype: str
name: str
extra_storage: Optional[PrimExpr]
default_value: Optional[PrimExpr]
span: Span

def __init__(self, data, axes, dtype, name, extra_storage, span):
self.__init_handle_by_constructor__(_ffi_api.SparseBuffer, data, axes, dtype, name, extra_storage, span) # type: ignore
def __init__(self, data, axes, dtype, name, extra_storage, default_value, span):
self.__init_handle_by_constructor__(_ffi_api.SparseBuffer, data, axes, dtype, name, extra_storage, default_value, span) # type: ignore


@tvm._ffi.register_object("tir.sparse.SpIterVar")
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,15 +806,20 @@ def RenormalizeSplitPattern():
return _ffi_api.RenormalizeSplitPattern() # type: ignore


def LowerSparseIter():
def LowerSparseIter(check_invalid_binary_search: bool = False):
"""Lower iterations in Sparse TIR

Parameters
----------
check_invalid_binary_search : bool
Whether check invalid indices made by binary search.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerSparseIter() # type: ignore
return _ffi_api.LowerSparseIter(check_invalid_binary_search) # type: ignore


def LowerSparseBuffer():
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const WhileNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmt_(const BlockRealizeNode* op) override;
Doc VisitStmt_(const SparseIterationNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

Doc VisitType_(const PrimTypeNode* node) override;
Expand Down
21 changes: 21 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,27 @@ Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const SparseIterationNode* op) {
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
Doc doc;
doc << "sparse_iteration " << op->name << "(";
doc << Print(op->sp_iter_vars[0]->var);
for (int i = 1; i < static_cast<int>(op->sp_iter_vars.size()); ++i) {
doc << "," << Print(op->sp_iter_vars[i]->var);
}
doc << ")";
Doc body;
if (op->init.defined()) {
Doc init_block;
init_block << "with init()";
init_block << PrintBody(op->init.value());
body << init_block << Doc::NewLine();
}
// Print body
body << Print(op->body);
doc << " {" << Doc::Indent(2, Doc::NewLine() << body) << Doc::NewLine() << "}";
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
const auto* block_op = op->block.as<BlockNode>();
// print block name and block vars
Expand Down
5 changes: 5 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
if (sp_buf->extra_storage.defined()) {
doc << ", extra_storage=" << Print(sp_buf->extra_storage.value());
}

// default value
if (sp_buf->default_value.defined()) {
doc << ", default_value=" << Print(sp_buf->default_value.value());
}
// scope
const auto* ptr_type = sp_buf->data->type_annotation.as<PointerTypeNode>();
ICHECK(ptr_type) << "Buffer variable is not of pointer type";
Expand Down
18 changes: 15 additions & 3 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ PrimExpr SparseBufferNode::GetNNZ() const { return flattened->shape[0]; }

/*! \brief Default constructor of SparseBuffer */
SparseBuffer::SparseBuffer(Var data, Array<Axis> axes, DataType dtype, String name,
Optional<PrimExpr> extra_storage, Span span) {
Optional<PrimExpr> extra_storage, Optional<PrimExpr> default_value,
Span span) {
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
CHECK_GT(static_cast<int>(axes.size()), 0)
<< "ValueError: A SparseBuffer should have at least one dimension";
Expand Down Expand Up @@ -275,6 +276,13 @@ SparseBuffer::SparseBuffer(Var data, Array<Axis> axes, DataType dtype, String na
node->extra_storage = extra_storage;
node->name = name;
node->dtype = dtype;
if (!default_value) {
node->default_value = Cast(dtype, Integer(0));
} else {
ICHECK(default_value.value()->dtype == dtype)
<< "sparse buffer default value should match buffer data type";
node->default_value = default_value;
}
// collect shape
Array<PrimExpr> shape;
for (const Axis& axis : axes) {
Expand Down Expand Up @@ -307,9 +315,10 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
.set_body_typed([](Var data, Array<Axis> axes, DataType dtype, String name,
Optional<PrimExpr> extra_storage, Span span) {
Optional<PrimExpr> extra_storage, Optional<PrimExpr> default_value,
Span span) {
return SparseBuffer(std::move(data), std::move(axes), std::move(dtype), std::move(name),
std::move(extra_storage), std::move(span));
std::move(extra_storage), std::move(default_value), std::move(span));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -327,6 +336,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
if (op->extra_storage.defined()) {
p->stream << ", " << op->extra_storage.value();
}
if (op->default_value.defined()) {
p->stream << ", " << op->default_value.value();
}
p->stream << ")";
});

Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class PrimFuncSpecializer : public StmtExprMutator {
return buffer;
} else {
return SparseBuffer(sp_buf->data, std::move(axes), sp_buf->dtype, sp_buf->name,
sp_buf->extra_storage, sp_buf->span);
sp_buf->extra_storage, sp_buf->default_value, sp_buf->span);
}
} else {
Array<PrimExpr> shape =
Expand Down
Loading