From 743f95f0c34b02d6d2cdb9e87da21caffe9c668f Mon Sep 17 00:00:00 2001
From: Alan Kelly
Date: Tue, 22 Oct 2024 08:20:33 -0700
Subject: [PATCH] Remove unnecessary max_tokens params

PiperOrigin-RevId: 688557487
---
 include/xnnpack.h           |  4 +---
 src/operators/rope-nthc.c   | 21 ---------------------
 src/subgraph/rope.c         | 10 ----------
 src/xnnpack/operator.h      |  1 -
 src/xnnpack/subgraph.h      |  3 ---
 test/rope-operator-tester.h |  4 ++--
 test/rope.cc                |  6 ++----
 7 files changed, 5 insertions(+), 44 deletions(-)

diff --git a/include/xnnpack.h b/include/xnnpack.h
index 0acfcb5ab1d..95a7057c0a2 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -1661,7 +1661,7 @@ XNN_DEPRECATED enum xnn_status xnn_define_prelu(
 /// Define a RoPE (Rotary Positional Embeddings) Node and add it to a Subgraph.
 ///
 /// @param subgraph - a Subgraph object that will own the created Node.
-/// @param max_tokens - maximum possible number of tokens (maximum sequence length) of the input/output tensors.
+/// @param max_tokens - deprecated.
 /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
 ///                   with [batch, tokens, heads, channels] dimensions.
 /// @param weights_id - Value ID for the weights tensor. The weights tensor must be a 2D tensor defined in the
@@ -5234,7 +5234,6 @@ enum xnn_status xnn_setup_resize_bilinear2d_nhwc_u8(
   uint8_t* output);
 
 enum xnn_status xnn_create_rope_nthc_f16(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out);
 
@@ -5253,7 +5252,6 @@ enum xnn_status xnn_setup_rope_nthc_f16(
   void* output);
 
 enum xnn_status xnn_create_rope_nthc_f32(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out);
 
diff --git a/src/operators/rope-nthc.c b/src/operators/rope-nthc.c
index 5de450de70f..43d4e33e8fc 100644
--- a/src/operators/rope-nthc.c
+++ b/src/operators/rope-nthc.c
@@ -22,7 +22,6 @@
 #include "pthreadpool.h"
 
 static enum xnn_status create_rope_nthc(
-  size_t max_tokens,
   uint32_t flags,
   enum xnn_operator_type operator_type,
   const struct xnn_cmul_config* config,
@@ -39,13 +38,6 @@ static enum xnn_status create_rope_nthc(
 
   status = xnn_status_invalid_parameter;
 
-  if (max_tokens == 0) {
-    xnn_log_error(
-      "failed to create %s operator with %zu max tokens: maximum number of tokens must be non-zero",
-      xnn_operator_type_to_string(operator_type), max_tokens);
-    goto error;
-  }
-
   status = xnn_status_out_of_memory;
 
   rope_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
@@ -56,8 +48,6 @@
     goto error;
   }
 
-  rope_op->max_tokens = max_tokens;
-
   rope_op->type = operator_type;
   rope_op->flags = flags;
   rope_op->cmul_config = config;
@@ -73,7 +63,6 @@
 }
 
 enum xnn_status xnn_create_rope_nthc_f16(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out)
 {
@@ -85,7 +74,6 @@
   }
 
   return create_rope_nthc(
-    max_tokens,
     flags,
     xnn_operator_type_rope_nthc_f16,
     config,
@@ -93,7 +81,6 @@
 }
 
 enum xnn_status xnn_create_rope_nthc_f32(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out)
 {
@@ -105,7 +92,6 @@
   }
 
   return create_rope_nthc(
-    max_tokens,
     flags,
     xnn_operator_type_rope_nthc_f32,
     config,
@@ -138,13 +124,6 @@ static enum xnn_status reshape_rope_nthc(
     return xnn_status_invalid_parameter;
   }
 
-  if (tokens > rope_op->max_tokens) {
-    xnn_log_error(
-      "failed to reshape %s operator with %zu tokens: number of tokens can not exceed the maximum %zu",
-      xnn_operator_type_to_string(rope_op->type), tokens, rope_op->max_tokens);
-    return xnn_status_invalid_parameter;
-  }
-
   if (heads == 0) {
     xnn_log_error(
       "failed to reshape %s operator with %zu heads: number of heads must be non-zero",
diff --git a/src/subgraph/rope.c b/src/subgraph/rope.c
index 8d424b79981..3f6e3b64121 100644
--- a/src/subgraph/rope.c
+++ b/src/subgraph/rope.c
@@ -36,13 +36,11 @@ static enum xnn_status create_rope_operator(
   switch (input_value->datatype) {
     case xnn_datatype_fp16:
       status = xnn_create_rope_nthc_f16(
-        node->params.rope.max_tokens,
         /*flags=*/0,
         &opdata->operator_objects[0]);
       break;
     case xnn_datatype_fp32:
       status = xnn_create_rope_nthc_f32(
-        node->params.rope.max_tokens,
         /*flags=*/0,
         &opdata->operator_objects[0]);
       break;
@@ -170,13 +168,6 @@ enum xnn_status xnn_define_rope(
     return status;
   }
-  if (max_tokens == 0) {
-    xnn_log_error(
-      "failed to define %s operator with %zu max tokens: maximum number of tokens must be non-zero",
-      xnn_node_type_to_string(xnn_node_type_rope), max_tokens);
-    return xnn_status_invalid_parameter;
-  }
-
   status = xnn_subgraph_check_input_node_id(xnn_node_type_rope, input_id, subgraph->num_values);
   if (status != xnn_status_success) {
     return status;
   }
@@ -262,7 +253,6 @@
 
   node->type = xnn_node_type_rope;
   node->compute_type = compute_type;
-  node->params.rope.max_tokens = max_tokens;
   node->num_inputs = 2;
   node->inputs[0] = input_id;
   node->inputs[1] = weights_id;
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 8546d7985c1..0fbb6c36670 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -150,7 +150,6 @@ struct xnn_operator {
   size_t group_input_channels;
   size_t group_output_channels;
   size_t channels;
-  size_t max_tokens;
 
   uint32_t pad_value;
 
diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h
index e8000c085b8..d60e34bad28 100644
--- a/src/xnnpack/subgraph.h
+++ b/src/xnnpack/subgraph.h
@@ -319,9 +319,6 @@ struct xnn_node {
       size_t new_height;
       size_t new_width;
     } static_resize;
-    struct {
-      size_t max_tokens;
-    } rope;
     struct {
       size_t num_dims;
       size_t offsets[XNN_MAX_TENSOR_DIMS];
diff --git a/test/rope-operator-tester.h b/test/rope-operator-tester.h
index e887fec4892..33194b17d63 100644
--- a/test/rope-operator-tester.h
+++ b/test/rope-operator-tester.h
@@ -133,7 +133,7 @@ class RoPEOperatorTester {
 
     xnn_operator_t rope_op = nullptr;
     const xnn_status status = xnn_create_rope_nthc_f16(
-      /*max_tokens=*/tokens(), /*flags=*/0, &rope_op);
+      /*flags=*/0, &rope_op);
     if (status == xnn_status_unsupported_hardware) {
       GTEST_SKIP();
     }
@@ -237,7 +237,7 @@ class RoPEOperatorTester {
 
     xnn_operator_t rope_op = nullptr;
    const xnn_status status = xnn_create_rope_nthc_f32(
-      /*max_tokens=*/tokens(), /*flags=*/0, &rope_op);
+      /*flags=*/0, &rope_op);
     if (status == xnn_status_unsupported_hardware) {
       GTEST_SKIP();
     }
diff --git a/test/rope.cc b/test/rope.cc
index e377d59562e..54ce6a1bcdc 100644
--- a/test/rope.cc
+++ b/test/rope.cc
@@ -98,7 +98,6 @@ TEST_F(RoPETestF16, define)
   const struct xnn_node* node = &subgraph->nodes[0];
   ASSERT_EQ(node->type, xnn_node_type_rope);
   ASSERT_EQ(node->compute_type, xnn_compute_type_fp16);
-  ASSERT_EQ(node->params.rope.max_tokens, max_tokens);
   ASSERT_EQ(node->num_inputs, 2);
   ASSERT_EQ(node->inputs[0], input_id);
   ASSERT_EQ(node->inputs[1], weights_id);
@@ -143,7 +142,6 @@ TEST_F(RoPETestF32, define)
   const struct xnn_node* node = &subgraph->nodes[0];
   ASSERT_EQ(node->type, xnn_node_type_rope);
   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
-  ASSERT_EQ(node->params.rope.max_tokens, max_tokens);
   ASSERT_EQ(node->num_inputs, 2);
   ASSERT_EQ(node->inputs[0], input_id);
   ASSERT_EQ(node->inputs[1], weights_id);
@@ -161,7 +159,7 @@ TEST_F(RoPETestF16, matches_operator_api)
   std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
   std::generate(weights.begin(), weights.end(), [&]() { return f32dist(rng); });
 
-  const xnn_status status = xnn_create_rope_nthc_f16(max_tokens, /*flags=*/0, &op);
+  const xnn_status status = xnn_create_rope_nthc_f16(/*flags=*/0, &op);
   if (status == xnn_status_unsupported_hardware) {
     GTEST_SKIP();
   }
@@ -239,7 +237,7 @@ TEST_F(RoPETestF32, matches_operator_api)
   std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
   std::generate(weights.begin(), weights.end(), [&]() { return f32dist(rng); });
 
-  const xnn_status status = xnn_create_rope_nthc_f32(max_tokens, /*flags=*/0, &op);
+  const xnn_status status = xnn_create_rope_nthc_f32(/*flags=*/0, &op);
   if (status == xnn_status_unsupported_hardware) {
     GTEST_SKIP();
   }
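
Migration note (a sketch, not part of the patch itself): for operator-API callers the
only change at a call site is dropping the leading max_tokens argument from
xnn_create_rope_nthc_f16/xnn_create_rope_nthc_f32; xnn_define_rope keeps its
max_tokens parameter, now documented as deprecated, so subgraph call sites need no
edits. The former reshape-time check that tokens not exceed max_tokens is also gone,
so the token count is no longer capped at creation time. A minimal sketch of the
post-patch create call, with an illustrative wrapper name:

    #include "xnnpack.h"  // public header from include/ (include path assumed)

    // Creates an f32 RoPE operator with the post-patch two-argument signature.
    // Before this patch the call took a leading max_tokens:
    //   xnn_create_rope_nthc_f32(/*max_tokens=*/..., /*flags=*/0, rope_op_out);
    enum xnn_status create_rope_example(xnn_operator_t* rope_op_out) {
      return xnn_create_rope_nthc_f32(/*flags=*/0, rope_op_out);
    }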