fix when binary search find invalid value #95
Good job @qelk123!
Dear Zihao Ye,
Thank you for your reply. This is the first time I have proposed a PR, so although this modification works fine in my case, I believe there are parts of my code that I have not fully thought through. Please give me your feedback so that I can refine it later.
Michael
Overall a nice contribution, I have a few comments.
Please format the code per suggestions, that would not take too much of your time.
I suggest adding a binary_search_valid_check key in the block annotation, indicating whether the user wants to perform such a check or not. If set to false, we do not examine whether the search is invalid, which reduces some overhead.
if(find!=find_backup)
{
find=find_backup;
if(op->buffer->dtype.is_float())
The data types are not limited to float32 and int32. 0 is also not the only valid choice of initial value: if we want to perform a max reduction, -inf would be a better choice.
We can let users specify their desired initial values as block annotations; if such a value is not set by the user, we fall back to 0 instead.
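A minimal sketch of this fallback, written in plain Python rather than as part of the SparseTIR C++ pass. The helper name `default_fill_value` and its `user_value` parameter are hypothetical and only illustrate the idea: a user-specified initial value wins, otherwise a zero of the buffer's type is used.

```python
import math
from typing import Optional, Union

Number = Union[int, float]


def default_fill_value(dtype: str, user_value: Optional[Number] = None) -> Number:
    """Pick the value used when a lookup is invalid (hypothetical helper)."""
    if user_value is not None:
        # An explicit value from the user takes priority,
        # e.g. -inf as the initial value of a max reduction.
        return user_value
    if dtype.startswith("float"):
        return 0.0
    if dtype.startswith(("int", "uint")):
        return 0
    raise ValueError(f"no default fill value known for dtype {dtype!r}")


max_init = default_fill_value("float32", -math.inf)
```

The explicit `ValueError` reflects the open question in this thread: which dtypes beyond float and int need a typed zero is not settled here.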
What kinds of data types do we need to support currently (besides float and int)? Type info is needed when the user doesn't provide a default value, so we need to set the default value to 0 with a specific type according to the dtype of the sparse buffer.
Stmt then_stmt = BufferStore(mid, -1, mid_indices);
PrimExpr if_stmt = (pivot != val);
body_stmts.push_back(
IfThenElse(if_stmt,then_stmt));
We might not need to emit this IfThenElse if we can confirm the lookup will not fail.
else
return StmtExprMutator::VisitStmt_(op);
}
std::vector<BufferLoad> buffer_need_process;
Do we really need to store BufferLoad? How about we only store Buffer?
BufferLoad here is used to construct the if-condition expr. If we only store the buffer, how should we get the indices of the buffer load?
I had a proposal months before: #74, the really tricky part is Suppose
From my perspective, even if the binary search result is not a valid index in the indices buffer, the buffer load result like
Yes, we should return a default value. What I mean by "invalid flag" is an indicator that the whole expression containing it should be invalidated and return its default value. Back to my example We can leave it for another PR.
It looks much better than its initial form, thank you so much, I have some suggestions for further improvements.
Actually, the current solution is to let lookup_result[...] return -1 as an invalid index. Do you mean returning "-1" is not a proper way to represent an invalid index?
@@ -101,7 +101,7 @@ def csrmm_dense_iter(
 low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
 high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
 low[0] = 0
-high[0] = J_indptr_data[vi + 1] - J_indptr_data[vi]
+high[0] = J_indptr_data[vi + 1] - J_indptr_data[vi] - 1
In our original implementation we use half-open interval ([a, b)) semantics; please check our loop invariant:
SparseTIR/src/tir/transforms/lower_sparse_iter.cc
Lines 830 to 844 in ca59cbe
/* Algorithm:
* - when left = true
* - pre-condition
* lb < ub, and the last dimension of buf is sorted.
* - loop-invariant
* low <= mid < high, buf[..., lb:low] < val, buf[..., high:ub] >= val
* - post-condition
* low = mid = high, buf[..., lb:low] < val, buf[..., high:ub] >= val
* - when left = false
* - pre-condition
* lb < ub, and the last dimension of buf is sorted.
* - loop-invariant
* low <= mid < high, buf[..., lb:low] <= val, buf[..., high:ub] > val
* - post-condition
* low = mid = high, buf[..., lb:low] <= val, buf[..., high:ub] > val
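The invariant quoted above can be sketched in plain Python as follows. This is an illustration of half-open `[low, high)` bounds that mirrors the quoted comment, not the code the pass actually emits; the function name and signature are assumptions made for this example.

```python
from typing import Sequence


def binary_search(buf: Sequence[int], lb: int, ub: int, val: int, left: bool) -> int:
    """Search the sorted slice buf[lb:ub] with half-open [low, high) bounds.

    left=True  -> first position whose element is >= val
    left=False -> first position whose element is >  val
    Returns ub when no such position exists.
    """
    low, high = lb, ub
    while low < high:
        mid = low + (high - low) // 2  # low <= mid < high
        pivot = buf[mid]
        go_right = pivot < val if left else pivot <= val
        if go_right:
            # buf[lb:mid + 1] is < val (left) or <= val (right)
            low = mid + 1
        else:
            # buf[mid:ub] is >= val (left) or > val (right)
            high = mid
    return low
```

Because `high` starts at `ub`, the result can equal `ub`, which is why shrinking the initial `high` by one changes the meaning of the search rather than just guarding the later load.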
SparseTIR/src/tir/transforms/lower_sparse_iter.cc
Lines 828 to 829 in ca59cbe
PrimExpr BinarySearch(SparseBuffer buf, Array<PrimExpr> prefix_indices, PrimExpr lb, PrimExpr ub,
PrimExpr val, bool left, bool minus_one = false) {
In this function you set the upper bound to indptr[i+1]-indptr[i]. From these lines
SparseTIR/src/tir/transforms/lower_sparse_iter.cc
Lines 923 to 925 in ca59cbe
PrimExpr pivot_cmp_cond = left ? (pivot < val) : (pivot > val);
Stmt if_true = left ? BufferStore(low, mid_val + 1, {Integer(0)})
: BufferStore(high, mid_val, {Integer(0)});
we can see that if val is greater than indices[i, indptr[i+1]-indptr[i]-1], it's possible that the mid buffer we return points to indices[i, indptr[i+1]-indptr[i]], which is wrong.
What's more, if val is exactly equal to indices[i, indptr[i+1]-indptr[i]], it will return a valid index and cause a wrong result.
Our semantics follow the lower_bound and upper_bound APIs in both C++ (https://en.cppreference.com/w/cpp/algorithm/lower_bound and https://en.cppreference.com/w/cpp/algorithm/upper_bound) and NumPy (https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html).
Both of them return ub when the key being searched is out of bounds, which aligns with our implementation. Note that neither lower_bound nor upper_bound is designed for exact match.
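To make the referenced semantics concrete, here is a small illustration using Python's standard `bisect` module, whose `bisect_left` and `bisect_right` behave like `lower_bound` and `upper_bound`. The sample `indices` list is made up for this example.

```python
from bisect import bisect_left, bisect_right

indices = [2, 5, 9]  # sorted column indices of one row (made-up data)
ub = len(indices)

pos_present = bisect_left(indices, 5)    # key present: its position
pos_absent = bisect_left(indices, 6)     # key absent: insertion point, not a match
pos_past_end = bisect_left(indices, 10)  # key larger than every element: equals ub
pos_upper = bisect_right(indices, 5)     # first position whose element is > 5
```

An absent key still yields an in-range position whenever it is smaller than the last element, so an exact-match lookup has to compare the element at the returned position with the key, and it has to rule out the `ub` case before doing so.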
I know the purpose of your J_indptr_data[vi + 1] - J_indptr_data[vi] - 1 change is to avoid the case where the value returned in mid is J_indptr_data[vi + 1] - J_indptr_data[vi] and we then use the mid buffer to index the original buffer for the validity check. However, such a change would break our loop invariant.
A better idea is not to check validity on the caller side, but to create another buffer, valid, that stores whether the search is valid or not during the binary search. I'll create another PR for that and you can rebase on it.
OK, I can fix this later. For simplicity, I will probably add a judgment condition when deciding whether to set an invalid index in mid.
I have created the PR: #96 let me know if it looks good to you.
It looks like you chose the same way as I did in my newest commit (adding a judgment condition) to fix this problem:
Stmt then_stmt = BufferStore(mid, -1, mid_indices);
PrimExpr if_stmt = (pivot != val || mid_val == ub);
body_stmts.push_back(IfThenElse(if_stmt, then_stmt));
However, I don't know why you return a pair in binary_search; it seems the second item only contains the valid flag instead of the exact value. Do you mean we should leave the valid check outside the binary search function?
Sorry I didn't notice your commit.
The valid flag is also returned to tell the caller of the BinarySearch function whether the search was successful, so that we don't need to access buf[..., mid[...]] outside the binary search body.
We don't need the exact value because the valid flag carries that information: if the search is successful, the value equals buf[..., mid[...]].
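A rough Python sketch of the "return a pair" idea described here. The function name `binary_search_with_flag` is hypothetical, and the real pass builds TIR statements rather than running a Python loop; the sketch only shows why the caller no longer needs to read `buf` at `mid`.

```python
from typing import Sequence, Tuple


def binary_search_with_flag(
    buf: Sequence[int], lb: int, ub: int, val: int
) -> Tuple[int, bool]:
    """Lower-bound search over buf[lb:ub] that also reports an exact match."""
    low, high = lb, ub
    while low < high:
        mid = low + (high - low) // 2
        if buf[mid] < val:
            low = mid + 1
        else:
            high = mid
    # Test the bound first so that buf[ub] is never read.
    valid = low < ub and buf[low] == val
    return low, valid
```

With the flag computed inside the search, the half-open invariant stays intact and the out-of-range position `ub` is handled in one place.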
Hi @qelk123, do you need any assistance with the rebase? It seems the codebase has disappeared and the PR is closed.
It's OK, I will open a new PR later.
"global_symbol": "main",
"tir.noalias": True,
"sparse_tir_level": 1,
"check_invalid_binary_search": False,
I suppose this annotation would not be emitted, because we only emit "check_invalid_binary_search": True when the invalid binary search check is enabled.
Actually, check_invalid_binary_search is a flag that tells the lower_sparse_buffer pass whether it should consider invalid indices created by the lower_iter_pass. This is to reduce the overhead. Otherwise, a flag would have to be passed to the lower_sparse_buffer pass in some other way.
I mean we can choose not to print this annotation when check_invalid_binary_search = False (not enabled).
We support non-affine indices, and the buffer access indices could be values loaded from another buffer access (e.g. A[B[C[...]]]). My example shows that we need to return the default value for the entire expression A[B[C[...]]] if the binary search is invalid, not only for C[...] or B[C[...]].
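One way to picture this requirement is the sketch below, in plain Python. The helper names `load` and `nested_load` are hypothetical, and the sketch assumes that -1 never occurs as a legitimate stored value, which may not hold for every buffer.

```python
from typing import Sequence

INVALID = -1  # marker for a failed lookup, following the convention in this PR


def load(buf: Sequence[int], idx: int) -> int:
    """Load buf[idx], passing the INVALID marker through instead of indexing."""
    if idx == INVALID:
        return INVALID
    return buf[idx]


def nested_load(A: Sequence, B: Sequence[int], C: Sequence[int], i: int, default=0):
    """Evaluate A[B[C[i]]], returning default if any inner lookup is invalid."""
    c = load(C, i)
    b = load(B, c)
    if b == INVALID:
        return default
    return A[b]
```

The point of the sketch is that invalidity has to travel outward through every level of indirection; guarding only the innermost load would let an invalid index be used to address the next buffer.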
I thought when facing the situation like
I believe you're correct. Let's stick with this solution until we run into such rare cases.
Thank you @qelk123, this version looks good to me!
This is a simple fix for the situation where binary search finds invalid indices. It contains:
1. An IfThenElse stmt after the binary search: if the index is invalid, return -1 as an invalid index value.
2. A postprocess pass after lower sparse buffer: for sparse data buffers, if the index value is invalid, return 0 from the sparse data buffer.
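The two steps can be sketched end to end in plain Python. This is an illustration of the behaviour the description aims for, not the TIR that the passes generate; `lookup_position` and `guarded_load` are hypothetical names and the sample data is made up.

```python
from bisect import bisect_left


def lookup_position(indices, val):
    """Step 1: position of val in sorted indices, or -1 when it is absent."""
    mid = bisect_left(indices, val)
    # Mirrors the condition `pivot != val || mid_val == ub`, with the bound
    # checked first so that indices[len(indices)] is never read.
    if mid == len(indices) or indices[mid] != val:
        return -1
    return mid


def guarded_load(data, pos, default=0.0):
    """Step 2: read data[pos], substituting default for the -1 marker."""
    return default if pos == -1 else data[pos]


indices = [1, 4, 6]
data = [0.5, 1.5, 2.5]
present = guarded_load(data, lookup_position(indices, 4))
missing = guarded_load(data, lookup_position(indices, 5))
```

Here `present` reads the stored element for column 4, while `missing` falls back to the default because column 5 is not stored in the row.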