From 7eaafc360f61a68a93661811f7e2a807cd868b75 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Tue, 10 Dec 2024 19:44:00 -0500 Subject: [PATCH] MatMulNBits collapse shape input when > 1d (#3698) --- src/onnx/parse_matmulnbits.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index af9f09790aa..e5fb336dfad 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -66,10 +66,10 @@ struct parse_matmulnbits : op_parser to_string_range(expected_b_lens) + ". Actual dims: " + to_string_range(args[1]->get_shape().lens())); - std::vector expected_scales_lens{n * n_blocks_per_col}; - if(args[2]->get_shape().lens() != expected_scales_lens) + const size_t expected_scales_lens = n * n_blocks_per_col; + if(args[2]->get_shape().elements() != expected_scales_lens) MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " + - to_string_range(expected_scales_lens) + + to_string(expected_scales_lens) + ". Actual dims: " + to_string_range(args[2]->get_shape().lens())); if(args.size() > 3)