Remove unnecessary max_tokens params
PiperOrigin-RevId: 688557487
alankelly authored and xnnpack-bot committed Oct 22, 2024
1 parent 6515679 commit 743f95f
Showing 7 changed files with 5 additions and 44 deletions.
4 changes: 1 addition & 3 deletions include/xnnpack.h
@@ -1661,7 +1661,7 @@ XNN_DEPRECATED enum xnn_status xnn_define_prelu(
 /// Define a RoPE (Rotary Positional Embeddings) Node and add it to a Subgraph.
 ///
 /// @param subgraph - a Subgraph object that will own the created Node.
-/// @param max_tokens - maximum possible number of tokens (maximum sequence length) of the input/output tensors.
+/// @param max_tokens - deprecated.
 /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
 ///                   with [batch, tokens, heads, channels] dimensions.
 /// @param weights_id - Value ID for the weights tensor. The weights tensor must be a 2D tensor defined in the
@@ -5234,7 +5234,6 @@ enum xnn_status xnn_setup_resize_bilinear2d_nhwc_u8(
   uint8_t* output);
 
 enum xnn_status xnn_create_rope_nthc_f16(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out);
 
@@ -5253,7 +5252,6 @@ enum xnn_status xnn_setup_rope_nthc_f16(
   void* output);
 
 enum xnn_status xnn_create_rope_nthc_f32(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out);
 
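Taken together, the header changes mean the operator-level create functions no longer take a creation-time token capacity, while the subgraph-level xnn_define_rope keeps its max_tokens parameter for source compatibility and merely documents it as deprecated. A minimal caller-side sketch of the operator API change (error handling elided; the pre-commit call is shown as a comment, with 1024 as an illustrative capacity):

  #include "xnnpack.h"

  int main(void) {
    xnn_initialize(/*allocator=*/NULL);

    xnn_operator_t rope_op = NULL;

    // Before this commit, a maximum sequence length had to be fixed at creation:
    //   xnn_create_rope_nthc_f32(/*max_tokens=*/1024, /*flags=*/0, &rope_op);

    // After it, the argument is simply dropped:
    enum xnn_status status = xnn_create_rope_nthc_f32(/*flags=*/0, &rope_op);

    if (status == xnn_status_success) {
      xnn_delete_operator(rope_op);
    }
    return 0;
  }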
21 changes: 0 additions & 21 deletions src/operators/rope-nthc.c
@@ -22,7 +22,6 @@
 #include "pthreadpool.h"
 
 static enum xnn_status create_rope_nthc(
-  size_t max_tokens,
   uint32_t flags,
   enum xnn_operator_type operator_type,
   const struct xnn_cmul_config* config,
@@ -39,13 +38,6 @@ static enum xnn_status create_rope_nthc(
 
   status = xnn_status_invalid_parameter;
 
-  if (max_tokens == 0) {
-    xnn_log_error(
-      "failed to create %s operator with %zu max tokens: maximum number of tokens must be non-zero",
-      xnn_operator_type_to_string(operator_type), max_tokens);
-    goto error;
-  }
-
   status = xnn_status_out_of_memory;
 
   rope_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
@@ -56,8 +48,6 @@ static enum xnn_status create_rope_nthc(
     goto error;
   }
 
-  rope_op->max_tokens = max_tokens;
-
   rope_op->type = operator_type;
   rope_op->flags = flags;
   rope_op->cmul_config = config;
@@ -73,7 +63,6 @@ static enum xnn_status create_rope_nthc(
 }
 
 enum xnn_status xnn_create_rope_nthc_f16(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out)
 {
@@ -85,15 +74,13 @@ enum xnn_status xnn_create_rope_nthc_f16(
   }
 
   return create_rope_nthc(
-    max_tokens,
     flags,
     xnn_operator_type_rope_nthc_f16,
     config,
     rope_op_out);
 }
 
 enum xnn_status xnn_create_rope_nthc_f32(
-  size_t max_tokens,
   uint32_t flags,
   xnn_operator_t* rope_op_out)
 {
@@ -105,7 +92,6 @@ enum xnn_status xnn_create_rope_nthc_f32(
   }
 
   return create_rope_nthc(
-    max_tokens,
     flags,
     xnn_operator_type_rope_nthc_f32,
     config,
@@ -138,13 +124,6 @@ static enum xnn_status reshape_rope_nthc(
     return xnn_status_invalid_parameter;
   }
 
-  if (tokens > rope_op->max_tokens) {
-    xnn_log_error(
-      "failed to reshape %s operator with %zu tokens: number of tokens can not exceed the maximum %zu",
-      xnn_operator_type_to_string(rope_op->type), tokens, rope_op->max_tokens);
-    return xnn_status_invalid_parameter;
-  }
-
   if (heads == 0) {
     xnn_log_error(
       "failed to reshape %s operator with %zu heads: number of heads must be non-zero",
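With the cap gone, the sequence length a RoPE operator handles is decided entirely at reshape time: the removed tokens > rope_op->max_tokens branch means one operator can be reshaped to any token count that passes the remaining checks (e.g. non-zero heads). A lifecycle sketch under the assumption that the reshape entry point follows XNNPACK's usual (op, batch, tokens, heads, channels, threadpool) pattern — that parameter list is inferred, not shown in this diff, and the dimensions are illustrative:

  // Assumption: the exact xnn_reshape_rope_nthc_f32 parameter list is
  // inferred, not taken from this commit.
  xnn_operator_t rope_op = NULL;
  xnn_create_rope_nthc_f32(/*flags=*/0, &rope_op);

  // Previously this returned xnn_status_invalid_parameter whenever tokens
  // exceeded the max_tokens passed at creation; now any positive token
  // count is accepted.
  xnn_reshape_rope_nthc_f32(rope_op, /*batch_size=*/1, /*tokens=*/4096,
                            /*heads=*/8, /*channels=*/64, /*threadpool=*/NULL);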
10 changes: 0 additions & 10 deletions src/subgraph/rope.c
@@ -36,13 +36,11 @@ static enum xnn_status create_rope_operator(
   switch (input_value->datatype) {
     case xnn_datatype_fp16:
       status = xnn_create_rope_nthc_f16(
-        node->params.rope.max_tokens,
         /*flags=*/0,
         &opdata->operator_objects[0]);
       break;
     case xnn_datatype_fp32:
       status = xnn_create_rope_nthc_f32(
-        node->params.rope.max_tokens,
         /*flags=*/0,
         &opdata->operator_objects[0]);
       break;
@@ -170,13 +168,6 @@ enum xnn_status xnn_define_rope(
     return status;
   }
 
-  if (max_tokens == 0) {
-    xnn_log_error(
-      "failed to define %s operator with %zu max tokens: maximum number of tokens must be non-zero",
-      xnn_node_type_to_string(xnn_node_type_rope), max_tokens);
-    return xnn_status_invalid_parameter;
-  }
-
   status = xnn_subgraph_check_input_node_id(xnn_node_type_rope, input_id, subgraph->num_values);
   if (status != xnn_status_success) {
     return status;
@@ -262,7 +253,6 @@ enum xnn_status xnn_define_rope(
 
   node->type = xnn_node_type_rope;
   node->compute_type = compute_type;
-  node->params.rope.max_tokens = max_tokens;
   node->num_inputs = 2;
   node->inputs[0] = input_id;
   node->inputs[1] = weights_id;
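On the subgraph side, xnn_define_rope still accepts max_tokens so existing callers keep compiling, but the value is no longer validated or stored — in particular, passing 0, which the removed check used to reject, is now harmless. A sketch, assuming the trailing output_id/flags parameters follow the pattern of the other define functions; the IDs are illustrative values assumed to come from earlier xnn_define_tensor_value calls:

  // Assumption: subgraph, input_id, weights_id, and output_id already exist.
  enum xnn_status status = xnn_define_rope(
      subgraph,
      /*max_tokens=*/0,  // deprecated: ignored as of this commit
      input_id,          // 4D [batch, tokens, heads, channels] input
      weights_id,        // 2D weights
      output_id,
      /*flags=*/0);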
1 change: 0 additions & 1 deletion src/xnnpack/operator.h
@@ -150,7 +150,6 @@ struct xnn_operator {
   size_t group_input_channels;
   size_t group_output_channels;
   size_t channels;
-  size_t max_tokens;
 
   uint32_t pad_value;
 
3 changes: 0 additions & 3 deletions src/xnnpack/subgraph.h
@@ -319,9 +319,6 @@ struct xnn_node {
       size_t new_height;
       size_t new_width;
     } static_resize;
-    struct {
-      size_t max_tokens;
-    } rope;
     struct {
       size_t num_dims;
       size_t offsets[XNN_MAX_TENSOR_DIMS];
4 changes: 2 additions & 2 deletions test/rope-operator-tester.h
@@ -133,7 +133,7 @@ class RoPEOperatorTester {
     xnn_operator_t rope_op = nullptr;
 
     const xnn_status status = xnn_create_rope_nthc_f16(
-      /*max_tokens=*/tokens(), /*flags=*/0, &rope_op);
+      /*flags=*/0, &rope_op);
     if (status == xnn_status_unsupported_hardware) {
       GTEST_SKIP();
     }
@@ -237,7 +237,7 @@ class RoPEOperatorTester {
     xnn_operator_t rope_op = nullptr;
 
     const xnn_status status = xnn_create_rope_nthc_f32(
-      /*max_tokens=*/tokens(), /*flags=*/0, &rope_op);
+      /*flags=*/0, &rope_op);
     if (status == xnn_status_unsupported_hardware) {
       GTEST_SKIP();
     }
6 changes: 2 additions & 4 deletions test/rope.cc
@@ -98,7 +98,6 @@ TEST_F(RoPETestF16, define)
   const struct xnn_node* node = &subgraph->nodes[0];
   ASSERT_EQ(node->type, xnn_node_type_rope);
   ASSERT_EQ(node->compute_type, xnn_compute_type_fp16);
-  ASSERT_EQ(node->params.rope.max_tokens, max_tokens);
   ASSERT_EQ(node->num_inputs, 2);
   ASSERT_EQ(node->inputs[0], input_id);
   ASSERT_EQ(node->inputs[1], weights_id);
@@ -143,7 +142,6 @@ TEST_F(RoPETestF32, define)
   const struct xnn_node* node = &subgraph->nodes[0];
   ASSERT_EQ(node->type, xnn_node_type_rope);
   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
-  ASSERT_EQ(node->params.rope.max_tokens, max_tokens);
   ASSERT_EQ(node->num_inputs, 2);
   ASSERT_EQ(node->inputs[0], input_id);
   ASSERT_EQ(node->inputs[1], weights_id);
@@ -161,7 +159,7 @@ TEST_F(RoPETestF16, matches_operator_api)
   std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
   std::generate(weights.begin(), weights.end(), [&]() { return f32dist(rng); });
 
-  const xnn_status status = xnn_create_rope_nthc_f16(max_tokens, /*flags=*/0, &op);
+  const xnn_status status = xnn_create_rope_nthc_f16(/*flags=*/0, &op);
   if (status == xnn_status_unsupported_hardware) {
     GTEST_SKIP();
   }
@@ -239,7 +237,7 @@ TEST_F(RoPETestF32, matches_operator_api)
   std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
   std::generate(weights.begin(), weights.end(), [&]() { return f32dist(rng); });
 
-  const xnn_status status = xnn_create_rope_nthc_f32(max_tokens, /*flags=*/0, &op);
+  const xnn_status status = xnn_create_rope_nthc_f32(/*flags=*/0, &op);
   if (status == xnn_status_unsupported_hardware) {
     GTEST_SKIP();
   }
