fix when binary search find invalid value #95

Merged
merged 7 commits into from
Mar 25, 2023

Conversation

Contributor

@qelk123 qelk123 commented Mar 15, 2023

This is a simple fix for the situation where binary search finds invalid indices. It contains:
1. An IfThenElse stmt after the binary search: if the index is invalid, -1 is returned as an invalid index value.
2. A post-processing pass after lower sparse buffer: for sparse data buffers, if the index value is invalid, the load from the sparse data buffer returns 0.
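
The two steps above can be sketched in plain Python. This is only an illustration of the idea, not the lowering code from the PR: `find_position`, `guarded_load`, and the sample row are hypothetical names and data, and the standard-library `bisect_left` stands in for the generated binary search.

```python
from bisect import bisect_left

INVALID = -1  # sentinel index stored when the search fails (step 1)


def find_position(indices, val):
    """Return the position of val in the sorted list, or INVALID if absent."""
    pos = bisect_left(indices, val)
    if pos == len(indices) or indices[pos] != val:
        return INVALID
    return pos


def guarded_load(data, pos, default=0):
    """Read data[pos], substituting the default for the sentinel (step 2)."""
    return default if pos == INVALID else data[pos]


row_indices = [1, 4, 7]        # column ids of the stored entries in one row
row_data = [10.0, 40.0, 70.0]  # values stored at those columns

hit = guarded_load(row_data, find_position(row_indices, 4))
miss = guarded_load(row_data, find_position(row_indices, 5))
```

Without the guard, Python's `row_data[-1]` would silently return the last element, which shows why the sentinel has to be checked before the load.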

Member

yzh119 commented Mar 15, 2023

Good job @qelk123!
Thanks for your contribution. It's a behavior I should have fixed but never did. I'll leave my feedback later.

Contributor Author

qelk123 commented Mar 16, 2023 via email

Member

@yzh119 yzh119 left a comment

Overall a nice contribution. I have a few comments.

Please format the code per the suggestions; that should not take too much of your time.

I suggest adding a binary_search_valid_check key to the block annotation, indicating whether the user wants to perform such a check. If it is set to false, we will not examine whether the search is invalid, which reduces some overhead.

src/tir/transforms/lower_sparse_buffer.cc (2 outdated threads, resolved)
if(find!=find_backup)
{
find=find_backup;
if(op->buffer->dtype.is_float())
Member

The data types are not limited to float32 and int32.
Also, 0 is not the only valid choice of initial value: for a max reduction, -inf would be a better choice.

Member

We can let users specify their desired initial values as block annotations; if no such value is set, we fall back to 0.
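
A minimal sketch of that fallback rule, assuming the annotations are available as a plain dictionary. The key name `default_value` and the helper `resolve_default_value` are hypothetical and only illustrate the idea of a user override with a typed zero as the fallback.

```python
import math


def resolve_default_value(annotations, dtype):
    """Pick the fill value for failed lookups, typed after the buffer dtype."""
    # "default_value" is a hypothetical annotation key used only in this sketch.
    value = annotations.get("default_value", 0)
    if dtype.startswith("float"):
        return float(value)
    if dtype.startswith(("int", "uint")):
        return int(value)
    if dtype == "bool":
        return bool(value)
    raise ValueError(f"unsupported dtype in this sketch: {dtype}")


zero_f16 = resolve_default_value({}, "float16")
zero_i64 = resolve_default_value({}, "int64")
neg_inf = resolve_default_value({"default_value": -math.inf}, "float32")
```

A max reduction would pass `-math.inf` as the annotation value, while a sum reduction can rely on the typed zero.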

Contributor Author

@qelk123 qelk123 Mar 18, 2023

What data types do we currently need to support (besides float and int)? Type info is needed when the user doesn't provide a default value, so we need to set the default to a 0 of the specific type given by the dtype of the sparse buffer.

Stmt then_stmt = BufferStore(mid, -1, mid_indices);
PrimExpr if_stmt = (pivot != val);
body_stmts.push_back(
IfThenElse(if_stmt,then_stmt));
Member

We might not need to emit this IfThenElse if we can confirm the lookup will not fail.

src/tir/transforms/lower_sparse_buffer.cc (outdated thread, resolved)
else
return StmtExprMutator::VisitStmt_(op);
}
std::vector<BufferLoad> buffer_need_process;
Member

Do we really need to store BufferLoad? How about we only store Buffer?

Contributor Author

The BufferLoad is stored here to construct the if-condition expression. If we only store the Buffer, how would we get the indices of the buffer load?

Member

yzh119 commented Mar 18, 2023

I made a proposal months ago (#74); the really tricky part is INVALID flag propagation:

Suppose lookup_result stores the binary search results, and we then use this value as an index to access another buffer, A[lookup_result[...]]. When the lookup fails, the whole expression should be marked as invalid. However, I accept your solution for now and we can improve it later on.

Contributor Author

qelk123 commented Mar 18, 2023

From my perspective, even if the binary search result is not a valid index in the indices buffer, a buffer load such as A[lookup_result[...]] probably should not directly return an invalid flag, but rather the default value of the sparse buffer. In TACO, this default value can be specified when declaring a tensor.
In a union operation such as add, when we co-iterate over two sparse axes, as long as the current index is valid on one axis, the corresponding entry of the output tensor should be set to the correct value instead of an invalid value.
Currently, it only supports 0 as the default value, but maybe we can improve it later by letting users specify the default value?
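
The union case can be illustrated with a small Python sketch of adding two sparse rows. It is an illustration under simplifying assumptions (rows stored as sorted index lists with parallel value lists, default value 0), and `lookup` and `sparse_add` are hypothetical helper names, not functions from the project.

```python
from bisect import bisect_left


def lookup(indices, data, col, default=0.0):
    """Value stored at column col, or the default if col is not stored."""
    pos = bisect_left(indices, col)
    if pos < len(indices) and indices[pos] == col:
        return data[pos]
    return default


def sparse_add(a_idx, a_val, b_idx, b_val):
    """Add two sparse rows by iterating over the union of their columns."""
    out_idx = sorted(set(a_idx) | set(b_idx))
    out_val = [lookup(a_idx, a_val, c) + lookup(b_idx, b_val, c) for c in out_idx]
    return out_idx, out_val


# Column 0 exists only in the first row and column 5 only in the second;
# the failed lookups contribute the default 0 instead of invalidating the sum.
idx, val = sparse_add([0, 3], [1.0, 2.0], [3, 5], [10.0, 20.0])
```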

@qelk123 qelk123 requested a review from yzh119 March 19, 2023 07:59
Member

yzh119 commented Mar 21, 2023

> From my perspective, even if the binary search result is not a valid index in the indices buffer, a buffer load such as A[lookup_result[...]] probably should not directly return an invalid flag, but rather the default value of the sparse buffer. In TACO, this default value can be specified when declaring a tensor.

Yes, we should return a default value. What I mean by "invalid flag" is an indicator that the whole expression containing it should be invalidated and return its default value.

Back to my example A[lookup_result[...]]: if we return the default value 0 for an unsuccessful lookup_result, then we will get A[0], but what we really need is the default value for buffer A, and such information needs to be propagated in some way.

We can leave it for another PR.
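
The difference between the two behaviors can be sketched as follows. This is a hedged illustration only: `lookup`, `load_naive`, and `load_propagating` are hypothetical names, and the validity flag is modeled as a second return value rather than anything the compiler actually emits.

```python
from bisect import bisect_left


def lookup(indices, val):
    """Return (position, valid); position falls back to 0 when the search fails."""
    pos = bisect_left(indices, val)
    valid = pos < len(indices) and indices[pos] == val
    return (pos if valid else 0), valid


def load_naive(a, indices, val):
    """Drop the validity flag: a failed lookup silently reads a[0]."""
    pos, _ = lookup(indices, val)
    return a[pos]


def load_propagating(a, a_default, indices, val):
    """Propagate the flag: a failed lookup yields the default value of a."""
    pos, valid = lookup(indices, val)
    return a[pos] if valid else a_default


A = [5.0, 6.0, 7.0]
indices = [2, 4, 8]

naive = load_naive(A, indices, 3)                      # reads A[0]
propagated = load_propagating(A, 0.0, indices, 3)      # default of A
found = load_propagating(A, 0.0, indices, 4)           # genuine hit
```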

Member

@yzh119 yzh119 left a comment

It looks much better than its initial form, thank you so much. I have some suggestions for further improvements.

include/tvm/tir/sparse.h (resolved)
src/printer/tir_text_printer.cc (resolved)
src/tir/ir/sparse.cc (outdated, resolved)
src/tir/transforms/lower_sparse_iter.cc (outdated, resolved)
Contributor Author

qelk123 commented Mar 21, 2023

> > From my perspective, even if the binary search result is not a valid index in the indices buffer, a buffer load such as A[lookup_result[...]] probably should not directly return an invalid flag, but rather the default value of the sparse buffer. In TACO, this default value can be specified when declaring a tensor.
>
> Yes, we should return a default value. What I mean by "invalid flag" is an indicator that the whole expression containing it should be invalidated and return its default value.
>
> Back to my example A[lookup_result[...]]: if we return the default value 0 for an unsuccessful lookup_result, then we will get A[0], but what we really need is the default value for buffer A, and such information needs to be propagated in some way.
>
> We can leave it for another PR.

Actually, the current solution is to let lookup_result[...] return -1 as an invalid index. Do you mean that returning "-1" is not a proper way to represent an invalid index?

@qelk123 qelk123 requested a review from yzh119 March 21, 2023 11:41
@@ -101,7 +101,7 @@ def csrmm_dense_iter(
low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
low[0] = 0
- high[0] = J_indptr_data[vi + 1] - J_indptr_data[vi]
+ high[0] = J_indptr_data[vi + 1] - J_indptr_data[vi] - 1
Member

In our original implementation we use half-open interval ([a, b)) semantics. Please check our loop invariant:

/* Algorithm:
 * - when left = true
 *   - pre-condition
 *     lb < ub, and the last dimension of buf is sorted.
 *   - loop-invariant
 *     low <= mid < high, buf[..., lb:low] < val, buf[..., high:ub] >= val
 *   - post-condition
 *     low = mid = high, buf[..., lb:low] < val, buf[..., high:ub] >= val
 * - when left = false
 *   - pre-condition
 *     lb < ub, and the last dimension of buf is sorted.
 *   - loop-invariant
 *     low <= mid < high, buf[..., lb:low] <= val, buf[..., high:ub] > val
 *   - post-condition
 *     low = mid = high, buf[..., lb:low] <= val, buf[..., high:ub] > val

So I suppose this change is not necessary?
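
For reference, the invariant quoted above can be exercised with a small Python sketch of a half-open binary search. It is a plain-list illustration, not the TIR emitted by the project; the function name `binary_search` and its argument order are assumptions made for this example, and the standard-library `bisect` functions serve only as a cross-check.

```python
from bisect import bisect_left, bisect_right


def binary_search(buf, lb, ub, val, left):
    """Half-open search over buf[lb:ub], which must be sorted.

    left=True  returns the first position whose element is >= val.
    left=False returns the first position whose element is >  val.
    Either may return ub when no such element exists.
    """
    low, high = lb, ub
    while low < high:
        mid = (low + high) // 2  # low <= mid < high
        pivot = buf[mid]
        go_right = pivot < val if left else pivot <= val
        if go_right:
            low = mid + 1        # everything in buf[lb:low] is left of the answer
        else:
            high = mid           # everything in buf[high:ub] is at or right of it
    return low                   # low == high on exit


# Cross-check against the standard library on a small sorted list.
data = [1, 2, 2, 5]
for key in range(0, 7):
    assert binary_search(data, 0, len(data), key, left=True) == bisect_left(data, key)
    assert binary_search(data, 0, len(data), key, left=False) == bisect_right(data, key)
```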

Contributor Author

PrimExpr BinarySearch(SparseBuffer buf, Array<PrimExpr> prefix_indices, PrimExpr lb, PrimExpr ub,
PrimExpr val, bool left, bool minus_one = false) {

In this function you set the upper bound to indptr[i+1]-indptr[i].
From these lines:
PrimExpr pivot_cmp_cond = left ? (pivot < val) : (pivot > val);
Stmt if_true = left ? BufferStore(low, mid_val + 1, {Integer(0)})
: BufferStore(high, mid_val, {Integer(0)});

we can see that if val is greater than indices[i, indptr[i+1]-indptr[i]-1], it's possible that the mid buffer we return refers to indices[i, indptr[i+1]-indptr[i]], which is wrong.

Contributor Author

What's more, if val happens to equal indices[i, indptr[i+1]-indptr[i]], the search will return a seemingly valid index and cause a wrong result.

Member

@yzh119 yzh119 Mar 23, 2023

Our semantics follow the lower_bound and upper_bound APIs in both C++ (https://en.cppreference.com/w/cpp/algorithm/lower_bound and https://en.cppreference.com/w/cpp/algorithm/upper_bound) and NumPy (https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html).
Both of them return ub when the key being searched is out of bounds, which aligns with our implementation. Note that neither lower_bound nor upper_bound is designed for exact match.

I know the purpose of your J_indptr_data[vi + 1] - J_indptr_data[vi] - 1 change is to avoid the case where the value returned in mid is J_indptr_data[vi + 1] - J_indptr_data[vi] and we then use the mid buffer to index the original buffer for the validity check. However, such a change would break our loop invariant.

A better idea is not to check validity on the caller side, but to create another buffer, valid, that records during the binary search whether the search succeeded. I'll create another PR for that and you can rebase on it.
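
The effect of shrinking the upper bound can be seen with Python's standard `bisect_left`, which follows the same lower_bound convention. The row contents below are made up for illustration.

```python
from bisect import bisect_left

row = [1, 3, 5]  # sorted column indices of one row (illustrative data)
n = len(row)

# Half-open search over row[0:n]: a key larger than every element yields n.
full_past_end = bisect_left(row, 7, 0, n)
full_inside = bisect_left(row, 4, 0, n)

# With the upper bound shrunk to n - 1 only row[0:n-1] is searched, so a key
# past the end and a key that belongs before the last element both yield n - 1.
shrunk_past_end = bisect_left(row, 7, 0, n - 1)
shrunk_inside = bisect_left(row, 4, 0, n - 1)
```

With the original bound the two keys give distinct positions, whereas the shrunk bound collapses them onto the same one, so any caller relying on lower_bound semantics loses information.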

Contributor Author

@qelk123 qelk123 Mar 23, 2023

OK, I can fix this later. For simplicity, I will probably add an extra condition when deciding whether to store an invalid index in mid.

Member

I have created the PR: #96. Let me know if it looks good to you.

Contributor Author

@qelk123 qelk123 Mar 23, 2023

It looks like you chose the same way to fix this problem (adding an extra condition) as I did in my newest commit:

Stmt then_stmt = BufferStore(mid, -1, mid_indices);
PrimExpr if_stmt = (pivot != val || mid_val == ub);
body_stmts.push_back(IfThenElse(if_stmt, then_stmt));

However, I don't understand why you return a pair from binary_search; it seems the second item only contains the valid flag rather than the exact value. Do you mean we should leave the validity check outside the binary search function?

Member

Sorry, I didn't notice your commit.
The valid flag is also returned to tell the caller of the BinarySearch function whether the search was successful, so that we don't need to access buf[..., mid[...]] outside the binary search body.

We don't need the exact value because the valid flag carries that information: if the search is successful, the value equals buf[..., mid[...]].
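
A minimal Python sketch of a search that returns such a pair, assuming a flat sorted list in place of the sparse buffer. The name `search_with_flag` is hypothetical; the point is that the `mid == ub` guard and the equality test both live inside the search, so the caller never reads `buf[mid]` itself.

```python
def search_with_flag(buf, lb, ub, val):
    """Lower-bound search over buf[lb:ub] returning (mid, valid)."""
    low, high = lb, ub
    while low < high:
        mid = (low + high) // 2
        if buf[mid] < val:
            low = mid + 1
        else:
            high = mid
    mid = low
    # Check mid != ub first so buf[mid] is never read outside [lb, ub).
    valid = mid != ub and buf[mid] == val
    return mid, valid


# The row occupies positions [0, 3); position 3 belongs to the next row.
indices = [1, 4, 7, 9]

hit = search_with_flag(indices, 0, 3, 4)
miss = search_with_flag(indices, 0, 3, 5)
# 9 is stored at position 3, just past this row: without the mid != ub
# guard the equality test alone would wrongly report a match.
past_end = search_with_flag(indices, 0, 3, 9)
```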

Contributor Author

@qelk123 qelk123 Mar 24, 2023

> Hi @qelk123, do you need any assistance with the rebase? It seems the codebase disappeared and the PR is closed.

It's OK, I will open a new PR later.

tests/python/sparsetir/sparse_tir_lowered_iter_scripts.py (6 outdated threads, resolved)
"global_symbol": "main",
"tir.noalias": True,
"sparse_tir_level": 1,
"check_invalid_binary_search": False,
Member

I suppose this annotation would not be emitted because we will only emit "check_invalid_binary_search": True when invalid binary search check is enabled.

Contributor Author

@qelk123 qelk123 Mar 21, 2023

Actually, check_invalid_binary_search is a flag that tells the lower_sparse_buffer pass whether it should handle invalid indices created by the lower_iter_pass. This is to reduce overhead. Otherwise, the flag would have to be passed to the lower_sparse_buffer pass in some other way.

Member

@yzh119 yzh119 Mar 24, 2023

I mean we can choose not to print this annotation when check_invalid_binary_search = False (not enabled).
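
One way to read this suggestion, sketched with a plain dictionary standing in for the function attributes: the producer adds the key only when the check is enabled, and the consumer treats a missing key as False. The helper names are hypothetical and do not correspond to actual pass code.

```python
def function_attrs(check_invalid_binary_search=False):
    """Build the attribute dict, emitting the flag only when it is enabled."""
    attrs = {"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}
    if check_invalid_binary_search:
        attrs["check_invalid_binary_search"] = True
    return attrs


def needs_invalid_index_handling(attrs):
    """Consumer side: an absent key means the check is disabled."""
    return bool(attrs.get("check_invalid_binary_search", False))


plain = function_attrs()
checked = function_attrs(check_invalid_binary_search=True)
```

This keeps printed scripts unchanged for the default case while still letting the later pass decide whether to guard loads.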

@yzh119 yzh119 closed this Mar 21, 2023
@yzh119 yzh119 reopened this Mar 21, 2023
Member

yzh119 commented Mar 21, 2023

> > > From my perspective, even if the binary search result is not a valid index in the indices buffer, a buffer load such as A[lookup_result[...]] probably should not directly return an invalid flag, but rather the default value of the sparse buffer. In TACO, this default value can be specified when declaring a tensor.
> >
> > Yes, we should return a default value. What I mean by "invalid flag" is an indicator that the whole expression containing it should be invalidated and return its default value.
> > Back to my example A[lookup_result[...]]: if we return the default value 0 for an unsuccessful lookup_result, then we will get A[0], but what we really need is the default value for buffer A, and such information needs to be propagated in some way.
> > We can leave it for another PR.
>
> Actually, the current solution is to let lookup_result[...] return -1 as an invalid index. Do you mean that returning "-1" is not a proper way to represent an invalid index?

We support non-affine indices, and buffer access indices could be values loaded from another buffer access (e.g. A[B[C[...]]]). My example shows that we need to return the default value for the entire expression A[B[C[...]]] if the binary search is invalid, not only for C[...] or B[C[...]].

Contributor Author

qelk123 commented Mar 21, 2023

> > > > From my perspective, even if the binary search result is not a valid index in the indices buffer, a buffer load such as A[lookup_result[...]] probably should not directly return an invalid flag, but rather the default value of the sparse buffer. In TACO, this default value can be specified when declaring a tensor.
> > >
> > > Yes, we should return a default value. What I mean by "invalid flag" is an indicator that the whole expression containing it should be invalidated and return its default value.
> > > Back to my example A[lookup_result[...]]: if we return the default value 0 for an unsuccessful lookup_result, then we will get A[0], but what we really need is the default value for buffer A, and such information needs to be propagated in some way.
> > > We can leave it for another PR.
> >
> > Actually, the current solution is to let lookup_result[...] return -1 as an invalid index. Do you mean that returning "-1" is not a proper way to represent an invalid index?
>
> We support non-affine indices, and buffer access indices could be values loaded from another buffer access (e.g. A[B[C[...]]]). My example shows that we need to return the default value for the entire expression A[B[C[...]]] if the binary search is invalid, not only for C[...] or B[C[...]].

I thought that in a situation like A[B[C[...]]], if the binary search fails in C, then C[...] will return a default value, and we can use this value to index B, so the expression is still usable?

Member

yzh119 commented Mar 24, 2023

> > > > > From my perspective, even if the binary search result is not a valid index in the indices buffer, a buffer load such as A[lookup_result[...]] probably should not directly return an invalid flag, but rather the default value of the sparse buffer. In TACO, this default value can be specified when declaring a tensor.
> > > >
> > > > Yes, we should return a default value. What I mean by "invalid flag" is an indicator that the whole expression containing it should be invalidated and return its default value.
> > > > Back to my example A[lookup_result[...]]: if we return the default value 0 for an unsuccessful lookup_result, then we will get A[0], but what we really need is the default value for buffer A, and such information needs to be propagated in some way.
> > > > We can leave it for another PR.
> > >
> > > Actually, the current solution is to let lookup_result[...] return -1 as an invalid index. Do you mean that returning "-1" is not a proper way to represent an invalid index?
> >
> > We support non-affine indices, and buffer access indices could be values loaded from another buffer access (e.g. A[B[C[...]]]). My example shows that we need to return the default value for the entire expression A[B[C[...]]] if the binary search is invalid, not only for C[...] or B[C[...]].
>
> I thought that in a situation like A[B[C[...]]], if the binary search fails in C, then C[...] will return a default value, and we can use this value to index B, so the expression is still usable?

I believe you're correct. Let's stick with this solution until such rare cases come up.

Member

yzh119 commented Mar 24, 2023

Hi @qelk123, do you need any assistance with the rebase? It seems the codebase disappeared and the PR is closed.

@qelk123 qelk123 reopened this Mar 24, 2023
@qelk123 qelk123 requested a review from yzh119 March 24, 2023 12:53
Member

@yzh119 yzh119 left a comment

Thank you @qelk123, this version looks good to me!

src/tir/transforms/lower_sparse_iter.cc (outdated, resolved)
src/tir/transforms/lower_sparse_iter.cc (2 threads, resolved)
@yzh119 yzh119 merged commit 9c26e84 into uwsampl:main Mar 25, 2023