-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MatMulNBits collapse shape input when > 1d #3698
Conversation
src/onnx/parse_matmulnbits.cpp
Outdated
scale_input = info.add_instruction(make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input); | ||
} | ||
|
||
if(scale_input->get_shape().lens() != expected_scales_lens) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is just to check then we should just use .elements
instead of inserting a reshape: if(args[2]->get_shape().elements() != (n * n_blocks_per_col}))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Easy enough, but don't we need these to be in the correct 1d shape for the input? in the dequantize_b we do another reshape as well on the input scale.
auto scales = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), args[2]);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, because the first thing we do is reshape it to a 2d tensor of{n, -1}
(where -1 is the remaining elements).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, you are not even using the reshaped instruction that is inserted, so there is no reason to add something to be always removed by DCE.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #3698 +/- ##
========================================
Coverage 92.23% 92.23%
========================================
Files 514 514
Lines 21746 21746
========================================
Hits 20057 20057
Misses 1689 1689 ☔ View full report in Codecov by Sentry. |
297e7b9
to
b595e4b
Compare
Check results before merge 🔆 |
🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output |
Fix input shape for matmulnbits to fold input via reshape to 1d input