Skip to content

Commit

Permalink
MatMulNBits collapse shape input when > 1d (#3698)
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous authored Dec 11, 2024
1 parent 64fe0c5 commit 7eaafc3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/onnx/parse_matmulnbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
to_string_range(expected_b_lens) +
". Actual dims: " + to_string_range(args[1]->get_shape().lens()));

std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
if(args[2]->get_shape().lens() != expected_scales_lens)
const size_t expected_scales_lens = n * n_blocks_per_col;
if(args[2]->get_shape().elements() != expected_scales_lens)
MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
to_string_range(expected_scales_lens) +
to_string(expected_scales_lens) +
". Actual dims: " + to_string_range(args[2]->get_shape().lens()));

if(args.size() > 3)
Expand Down

0 comments on commit 7eaafc3

Please sign in to comment.