From 88472af5b471e3d78cbb60c0e76ebed0e90d12a1 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 9 Dec 2024 17:05:08 +0000 Subject: [PATCH 1/4] Fix input shape for matmulnbits to fold input via reshape to 1d input --- src/onnx/parse_matmulnbits.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index af9f09790aa..8356b5d0a57 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -67,7 +67,15 @@ struct parse_matmulnbits : op_parser ". Actual dims: " + to_string_range(args[1]->get_shape().lens())); std::vector expected_scales_lens{n * n_blocks_per_col}; - if(args[2]->get_shape().lens() != expected_scales_lens) + + // Reshape anything larger than 1 dimension into a 1d tensor so we can check if we have the right amount of elements. + auto scale_input = args[2]; + if(scale_input->get_shape().lens().size() > 1) + { + scale_input = info.add_instruction(make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input); + } + + if(scale_input->get_shape().lens() != expected_scales_lens) MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " + to_string_range(expected_scales_lens) + ". Actual dims: " + to_string_range(args[2]->get_shape().lens())); From b595e4b1a7a19b252c524ecc57e030235cb96cd4 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 9 Dec 2024 22:18:03 +0000 Subject: [PATCH 2/4] Fix format --- src/onnx/parse_matmulnbits.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index 8356b5d0a57..acdb6b95334 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -68,11 +68,13 @@ struct parse_matmulnbits : op_parser std::vector expected_scales_lens{n * n_blocks_per_col}; - // Reshape anything larger than 1 dimension into a 1d tensor so we can check if we have the right amount of elements. + // Reshape anything larger than 1 dimension into a 1d tensor so we can check if we have the + // right amount of elements. auto scale_input = args[2]; if(scale_input->get_shape().lens().size() > 1) { - scale_input = info.add_instruction(make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input); + scale_input = info.add_instruction( + make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input); } if(scale_input->get_shape().lens() != expected_scales_lens) From 005d8c7a4758c6857abbe9578af32691b9fa62bb Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 9 Dec 2024 23:11:55 +0000 Subject: [PATCH 3/4] Just compare elements and remove reshape --- src/onnx/parse_matmulnbits.cpp | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index acdb6b95334..869342de543 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -68,16 +68,7 @@ struct parse_matmulnbits : op_parser std::vector expected_scales_lens{n * n_blocks_per_col}; - // Reshape anything larger than 1 dimension into a 1d tensor so we can check if we have the - // right amount of elements. - auto scale_input = args[2]; - if(scale_input->get_shape().lens().size() > 1) - { - scale_input = info.add_instruction( - make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input); - } - - if(scale_input->get_shape().lens() != expected_scales_lens) + if(args[2]->get_shape().elements() != expected_scales_lens[0]) MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " + to_string_range(expected_scales_lens) + ". Actual dims: " + to_string_range(args[2]->get_shape().lens())); From 86d9ee3ff79e30d5ac11717e74dfb168ebf147db Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 10 Dec 2024 15:53:19 +0000 Subject: [PATCH 4/4] Remove the need to use vector --- src/onnx/parse_matmulnbits.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index 869342de543..e5fb336dfad 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -66,11 +66,10 @@ struct parse_matmulnbits : op_parser to_string_range(expected_b_lens) + ". Actual dims: " + to_string_range(args[1]->get_shape().lens())); - std::vector expected_scales_lens{n * n_blocks_per_col}; - - if(args[2]->get_shape().elements() != expected_scales_lens[0]) + const size_t expected_scales_lens = n * n_blocks_per_col; + if(args[2]->get_shape().elements() != expected_scales_lens) MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " + - to_string_range(expected_scales_lens) + + to_string(expected_scales_lens) + ". Actual dims: " + to_string_range(args[2]->get_shape().lens())); if(args.size() > 3)