Skip to content

Commit

Permalink
make position_ids optional
Browse files (browse the repository at this point in the history)
  • Loading branch information
shubhambhokare1 committed Jan 29, 2025
1 parent 78ac6df commit 820dbbf
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 62 deletions.
25 changes: 16 additions & 9 deletions onnxruntime/contrib_ops/cpu/onnx_std_exp/rotary_embedding_onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,20 @@ Status RunRotaryEmbeddingONNX(concurrency::ThreadPool* tp, RotaryParameters para
const T* input_data = input + block_offset;
T* output_data = output + block_offset;

// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(position_ids[0]) + s
: static_cast<int>(position_ids[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
const T* cos_data;
const T* sin_data;
int cache_offset;
if (position_ids_format == -1) {
cache_offset = (b * sequence_length + s) * half_rotary_emb_dim;
} else {
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(position_ids[0]) + s
: static_cast<int>(position_ids[b * sequence_length + s]);
cache_offset = position_id * half_rotary_emb_dim;
}
cos_data = cos_cache + cache_offset;
sin_data = sin_cache + cache_offset;

MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);

Expand Down Expand Up @@ -103,7 +110,7 @@ Status RotaryEmbeddingONNX<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* cos_cache = context->Input<Tensor>(1);
const Tensor* sin_cache = context->Input<Tensor>(2);
const Tensor* position_ids = context->Input<Tensor>(3);
const Tensor* position_ids = context->Input<Tensor>(3); // position_ids are optional

RotaryParameters parameters = {};
ORT_RETURN_IF_ERROR(rotary_embedding_onnx_helper::CheckInputs<Tensor>(input,
Expand All @@ -124,7 +131,7 @@ Status RotaryEmbeddingONNX<T>::Compute(OpKernelContext* context) const {
const T* input_src = input->Data<T>();
const T* cos_cache_data = cos_cache->Data<T>();
const T* sin_cache_data = sin_cache->Data<T>();
const int64_t* pos_ids_data = position_ids->Data<int64_t>();
const int64_t* pos_ids_data = (nullptr == position_ids) ? nullptr : position_ids->Data<int64_t>();
T* output_dest = output->MutableData<T>();

AllocatorPtr allocator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,12 @@ Status CheckInputs(const T* input,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
input_dims.size());
}
// Check cos_cache and sin_cache
const auto& cos_cache_dims = cos_cache->Shape().GetDims();
if (cos_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
cos_cache_dims.size());
}
const auto& sin_cache_dims = sin_cache->Shape().GetDims();
if (sin_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
sin_cache_dims.size());
}
if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
"the same shape");
}
// Check position_ids
const auto& position_ids_dims = position_ids->Shape().GetDims();
if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ",
"dimensions, got ", position_ids_dims.size());
}

// Check num_heads and rotary_embedding_dim
if (rotary_embedding_dim > 0 && num_heads == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
"specified");
"specified");
}

// Get attributes from inputs
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
Expand All @@ -83,41 +61,102 @@ Status CheckInputs(const T* input,
hidden_size = static_cast<int>(input_dims[2]) * static_cast<int>(input_dims[3]);
transposed = true;
}
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
: static_cast<int>(hidden_size / num_heads);
if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
"head_size");
}

int position_ids_format = -1;
int max_sequence_length;
int head_size;

// Check position_ids input shapes
if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) {
if (batch_size != static_cast<int>(position_ids_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ",
"batch_size, got ", position_ids_dims[0]);
if (nullptr == position_ids) {
// Check cos_cache and sin_cache
const auto& cos_cache_dims = cos_cache->Shape().GetDims();
if (cos_cache_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 3 dimensions, got ",
cos_cache_dims.size());
}
if (sequence_length != static_cast<int>(position_ids_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ",
"sequence_length, got ", position_ids_dims[1]);
const auto& sin_cache_dims = sin_cache->Shape().GetDims();
if (sin_cache_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 3 dimensions, got ",
sin_cache_dims.size());
}
if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1] || cos_cache_dims[2] != sin_cache_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
"the same shape");
}

max_sequence_length = static_cast<int>(cos_cache_dims[1]);
head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[2]) * 2
: static_cast<int>(hidden_size / num_heads);
if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
"head_size");
}
// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(cos_cache_dims[2]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[2]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 2 should be same as ",
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}
position_ids_format = 1;
} else {
position_ids_format = 0;
}
// Check cos_cache and sin_cache
const auto& cos_cache_dims = cos_cache->Shape().GetDims();
if (cos_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
cos_cache_dims.size());
}
const auto& sin_cache_dims = sin_cache->Shape().GetDims();
if (sin_cache_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
sin_cache_dims.size());
}
if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
"the same shape");
}
// Check position_ids
const auto& position_ids_dims = position_ids->Shape().GetDims();
if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ",
"dimensions, got ", position_ids_dims.size());
}

// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
max_sequence_length = static_cast<int>(cos_cache_dims[0]);
head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
: static_cast<int>(hidden_size / num_heads);
if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
"head_size");
}

// Check position_ids input shapes
if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) {
if (batch_size != static_cast<int>(position_ids_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ",
"batch_size, got ", position_ids_dims[0]);
}
if (sequence_length != static_cast<int>(position_ids_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ",
"sequence_length, got ", position_ids_dims[1]);
}
position_ids_format = 1;
} else {
position_ids_format = 0;
}

// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}
}


num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
// Calculate stride values
int head_stride;
Expand Down
107 changes: 103 additions & 4 deletions onnxruntime/test/contrib_ops/rotary_embedding_onnx_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ static void RunTest(
int hidden_size = num_heads * head_size;
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> pos_dims;
std::vector<int64_t> cache_dims = {max_sequence_length, head_size / 2};
std::vector<int64_t> cache_dims;
if (position_ids.size() != 0) {
cache_dims = {max_sequence_length, head_size / 2};
} else {
cache_dims = {batch_size, sequence_length, head_size / 2};
}

assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0);
assert(max_sequence_length >= sequence_length);
if (position_ids.size() == 1) {
if (position_ids.size() == 0) {
pos_dims = {};
} else if (position_ids.size() == 1) {
pos_dims = {1};
} else {
pos_dims = {batch_size, sequence_length};
Expand Down Expand Up @@ -69,13 +76,21 @@ static void RunTest(
test.AddInput<float>("input", input_dims, input_data);
test.AddInput<float>("cos_cache", cache_dims, cos_cache);
test.AddInput<float>("sin_cache", cache_dims, sin_cache);
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
if (position_ids.size()) {
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
} else {
test.AddOptionalInputEdge<int64_t>();
}
test.AddOutput<float>("output", input_dims, output_data);
} else {
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
test.AddInput<MLFloat16>("cos_cache", cache_dims, ToFloat16(cos_cache));
test.AddInput<MLFloat16>("sin_cache", cache_dims, ToFloat16(sin_cache));
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
if (position_ids.size()) {
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
} else {
test.AddOptionalInputEdge<int64_t>();
}
test.AddOutput<MLFloat16>("output", input_dims, ToFloat16(output_data));
}
test.SetOutputAbsErr("output", 0.002f);
Expand Down Expand Up @@ -629,5 +644,89 @@ TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_SmallData_Llama
interleaved);
}

// Interleaved = false, pos ids = nullptr
TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_NoPosIds_SmallData_LlamaMSFT) {
  // Non-interleaved rotary embedding driven with an EMPTY position_ids vector.
  // NOTE(review): presumably RunTests treats the empty vector as "optional input
  // omitted" and expects per-(batch, sequence) cos/sin caches -- confirm against
  // the RunTests/RunTest helpers.
  int64_t interleaved{0};  // false
  int batch_size{1};
  int sequence_length{2};
  int num_heads{3};
  int head_size{6};
  int max_sequence_length{4};

  // No position ids are supplied in this scenario.
  std::vector<int64_t> position_ids{};

  // One row per sequence position, head_size / 2 values each.
  // The first row is the identity rotation (cos = 1, sin = 0).
  std::vector<float> cos_cache{
      1.0000f, 1.0000f, 1.0000f,
      0.5403f, 0.9989f, 1.0000f};

  std::vector<float> sin_cache{
      0.0000f, 0.0000f, 0.0000f,
      0.8415f, 0.0464f, 0.0022f};

  // 36 values laid out as head_size (6) values per row.
  std::vector<float> input_data{
      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
      1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f,
      -0.8663f, -0.2656f, 0.1665f, 0.7911f, -0.9320f, -0.8579f,
      -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f,
      -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f,
      -0.9647f, -0.0991f, -0.2994f, -0.0650f, -1.5720f, -1.3211f};

  // Expected result: the first 18 values are identical to the input (identity
  // rotation from the first cache row); the remaining 18 are rotated.
  std::vector<float> output_data{
      -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f,
      1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f,
      -0.8663f, -0.2656f, 0.1665f, 0.7911f, -0.9320f, -0.8579f,
      -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f,
      -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f,
      -0.4666f, -0.0261f, -0.2965f, -0.8469f, -1.5749f, -1.3217f};

  RunTests(input_data, cos_cache, sin_cache, position_ids, output_data,
           batch_size, sequence_length, head_size, num_heads,
           max_sequence_length, interleaved);
}

// Interleaved = true, pos ids = nullptr
TEST(RotaryEmbeddingONNXTest, RotaryEmbedding_ONNX_Interleaved_NoPosIds_SmallData_LlamaMSFT) {
  // Interleaved rotary embedding driven with an EMPTY position_ids vector.
  // NOTE(review): presumably RunTests treats the empty vector as "optional input
  // omitted" and expects per-(batch, sequence) cos/sin caches -- confirm against
  // the RunTests/RunTest helpers.
  int64_t interleaved{1};  // true
  int batch_size{1};
  int sequence_length{3};
  int num_heads{2};
  int head_size{4};
  int max_sequence_length{8};

  // No position ids are supplied in this scenario.
  std::vector<int64_t> position_ids{};

  // One row per sequence position, head_size / 2 values each.
  // The first row is the identity rotation (cos = 1, sin = 0).
  std::vector<float> cos_cache{
      1.0000f, 1.0000f,
      0.5403f, 0.9999f,
      -0.4161f, 0.9998f};

  std::vector<float> sin_cache{
      0.0000f, 0.0000f,
      0.8415f, 0.0100f,
      0.9093f, 0.0200f};

  // 24 values laid out as head_size (4) values per row.
  std::vector<float> input_data{
      -1.0408f, 0.9166f, -1.3042f, -1.1097f,
      -0.1320f, -0.2751f, -0.2350f, 0.0937f,
      -1.2188f, 1.1676f, -1.0574f, -0.1188f,
      -0.7396f, -1.2425f, -0.1752f, 0.6990f,
      -0.8110f, 0.6737f, -1.1233f, -0.0919f,
      -0.6861f, 0.7202f, 0.1963f, 0.6142f};

  // Expected result: the first 8 values are identical to the input (identity
  // rotation from the first cache row); the remaining 16 are rotated.
  std::vector<float> output_data{
      -1.0408f, 0.9166f, -1.3042f, -1.1097f,
      -0.1320f, -0.2751f, -0.2350f, 0.0937f,
      -1.6411f, -0.3948f, -1.0561f, -0.1294f,
      0.6460f, -1.2937f, -0.1822f, 0.6972f,
      -0.2751f, -1.0178f, -1.1212f, -0.1143f,
      -0.3694f, -0.9235f, 0.1840f, 0.6180f};

  RunTests(input_data, cos_cache, sin_cache, position_ids, output_data,
           batch_size, sequence_length, head_size, num_heads,
           max_sequence_length, interleaved);
}

} // namespace test
} // namespace onnxruntime

0 comments on commit 820dbbf

Please sign in to comment.