Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MatMulNBits collapse shape input when > 1d #3698

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/onnx/parse_matmulnbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
". Actual dims: " + to_string_range(args[1]->get_shape().lens()));

std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
if(args[2]->get_shape().lens() != expected_scales_lens)

// Reshape anything larger than 1 dimension into a 1d tensor so we can check if we have the right amount of elements.
auto scale_input = args[2];
if(scale_input->get_shape().lens().size() > 1)
{
scale_input = info.add_instruction(make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input);
}

if(scale_input->get_shape().lens() != expected_scales_lens)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is just to check then we should just use .elements instead of inserting a reshape: if(args[2]->get_shape().elements() != (n * n_blocks_per_col}))

Copy link
Collaborator Author

@TedThemistokleous TedThemistokleous Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Easy enough, but don't we need these to be in the correct 1d shape for the input? in the dequantize_b we do another reshape as well on the input scale.

auto scales = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), args[2]);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because the first thing we do is reshape it to a 2d tensor of{n, -1}(where -1 is the remaining elements).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, you are not even using the reshaped instruction that is inserted, so there is no reason to add something to be always removed by DCE.

MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
to_string_range(expected_scales_lens) +
". Actual dims: " + to_string_range(args[2]->get_shape().lens()));
Expand Down
Loading