
StaticWhisperPipeline change to work with optimum models #1103

Merged
merged 3 commits into openvinotoolkit:master on Oct 31, 2024

Conversation

@eshiryae (Contributor) commented Oct 29, 2024

Fixes #895

@github-actions bot added the category: whisper (Whisper pipeline) and category: sampling (Sampling / Decoding algorithms) labels on Oct 29, 2024
Review threads:
src/cpp/src/whisper_pipeline.cpp (outdated, resolved)
src/cpp/src/whisper_pipeline_static.cpp (resolved)
src/cpp/src/whisper_pipeline_static.cpp (outdated, resolved)
src/cpp/src/whisper_pipeline.cpp (outdated, resolved)
@github-actions bot added the category: samples (GenAI samples) label and removed the category: whisper (Whisper pipeline) label on Oct 30, 2024
@eshiryae eshiryae marked this pull request as ready for review October 30, 2024 16:26
@ilya-lavrenov ilya-lavrenov added this to the 2024.5 milestone Oct 30, 2024
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();

// if (tensor.get_any_name().find(".value") != std::string::npos) {
Contributor: Seems redundant, do we need to keep it?

preprocessor.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.output(tensor.get_any_name()).postprocess().convert_element_type();

// if (tensor.get_any_name().find(".value") != std::string::npos) {
Contributor: Same here.
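
For context, a minimal self-contained sketch of how these PrePostProcessor calls compose; the helper name and the tensor-name filters are assumptions for illustration, not the code in this PR:

#include "openvino/core/model.hpp"
#include "openvino/preprocess/pre_post_process.hpp"

// Hypothetical helper: expose the KV-cache tensors as f16 at the request boundary
// and let PrePostProcessor insert the element-type conversions into the model.
std::shared_ptr<ov::Model> set_kv_cache_precision(std::shared_ptr<ov::Model> model) {
    ov::preprocess::PrePostProcessor preprocessor(model);
    for (const auto& input : model->inputs()) {
        const auto& name = input.get_any_name();
        if (name.find("past_key_values") != std::string::npos) {  // assumed name filter
            preprocessor.input(name).tensor().set_element_type(ov::element::f16);
            preprocessor.input(name).preprocess().convert_element_type();
        }
    }
    for (const auto& output : model->outputs()) {
        const auto& name = output.get_any_name();
        if (name.find("present") != std::string::npos) {  // assumed name filter
            preprocessor.output(name).tensor().set_element_type(ov::element::f16);
            preprocessor.output(name).postprocess().convert_element_type();
        }
    }
    return preprocessor.build();  // rebuilds the model with the conversions applied
}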

@@ -160,10 +171,9 @@ int64_t decode_with_past(ov::InferRequest& decoder_with_past,
const std::vector<int64_t>& generated_tokens) {
// FIXME: Avoid this cast to i32. Why isn't the model using i64 precision?
decoder_with_past.get_tensor("input_ids").data<int32_t>()[0] = static_cast<int32_t>(input_id);

Since optimum creates i64 dtype for input_ids, do we still need this cast? This cast was initially required to support NPU-friendly models.

Collaborator: I believe it's an attempt to align the model generated on the fly with the NPU-friendly one. Perhaps just leaving it as i64 is fine.
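
As a sketch of that suggestion (assuming the optimum-exported model keeps its native i64 input_ids; this is not the merged code), the cast would simply disappear:

// Hypothetical: write the token directly once input_ids stays i64.
decoder_with_past.get_tensor("input_ids").data<int64_t>()[0] = input_id;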

@TolyaTalamanov (Collaborator) left a comment:

A slight revision of the preprocessing code is perhaps still needed here, but it's not critical for now, so let's merge it!

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

preprocessor.input("input_features").tensor().set_element_type(ov::element::Type_t::f32);
Collaborator: It's already f32, isn't it?

pm.run_passes(model);
}

void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
Collaborator: Perhaps there should be a separate function for every model to avoid confusion.

I agree. Having separate reshape_to_static for decoder and decoder with past will be helpful.
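
A hypothetical shape of that split (function names, tensor names, and dimension layout are illustrative assumptions, not this PR's code):

#include <map>
#include "openvino/core/model.hpp"

// Sketch: one reshape helper per model instead of a single shared one.
void reshape_to_static_decoder(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
    std::map<std::string, ov::PartialShape> new_shapes;
    for (const auto& input : model->inputs()) {
        const auto& input_name = input.get_any_name();
        ov::PartialShape new_shape = input.get_partial_shape();
        if (input_name.find("input_ids") != std::string::npos) {
            new_shape = ov::PartialShape({1, input_size});
        } else if (input_name.find("encoder_hidden_states") != std::string::npos) {
            new_shape[0] = 1;     // batch
            new_shape[1] = 1500;  // encoder sequence length
        }
        new_shapes.emplace(input_name, new_shape);
    }
    model->reshape(new_shapes);
}

void reshape_to_static_decoder_with_past(std::shared_ptr<ov::Model> model, const uint32_t kvcache_size) {
    std::map<std::string, ov::PartialShape> new_shapes;
    for (const auto& input : model->inputs()) {
        const auto& input_name = input.get_any_name();
        ov::PartialShape new_shape = input.get_partial_shape();
        if (input_name.find("input_ids") != std::string::npos) {
            new_shape = ov::PartialShape({1, 1});  // one token per step
        } else if (input_name.find("attention_mask") != std::string::npos) {
            new_shape = ov::PartialShape({1, kvcache_size});
        } else if (input_name.find("past_key_values") != std::string::npos) {
            new_shape[0] = 1;                 // batch
            new_shape[2] = kvcache_size - 1;  // past length (assumed [B, H, S, D] layout)
        }
        new_shapes.emplace(input_name, new_shape);
    }
    model->reshape(new_shapes);
}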

// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCWH"));
// preprocessor.output(tensor.get_any_name()).model().set_layout(ov::Layout("NCHW"));
//} else if (tensor.get_any_name().find(".key") != std::string::npos) {
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCHW"));
Collaborator: Let's remove it in the refactoring PR.

const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // batch_dim
new_shape[1] = 1500; // FIXME: should this be taken from the encoder output 'last_hidden_state'?

Encoder hidden states are not needed as an input to the decoder with past.

It is, however, a required input for the decoder model, and the static shapes are 1 for batch and 1500 for encoder sequence length, as you have added. However, the last dimension is also dynamic and varies with the model (checked with an optimum-exported model using transformers v4.45.2 and optimum-intel v1.20.0).

Ideally, the encoder output 'last_hidden_state' dimension can be used to reshape the encoder_hidden_states input to the decoder. This would be straightforward.
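
A sketch of that idea (the variable names here are assumptions): query the encoder output's shape once and reuse its dimensions instead of hard-coding them.

// Illustrative only: derive the decoder's encoder_hidden_states shape from the
// encoder model's 'last_hidden_state' output rather than hard-coding 1500.
const auto encoder_out_shape = encoder_model->output("last_hidden_state").get_partial_shape();
ov::PartialShape new_shape = input.get_partial_shape();
new_shape[0] = 1;                     // batch
new_shape[1] = encoder_out_shape[1];  // encoder sequence length (1500 for Whisper)
new_shape[2] = encoder_out_shape[2];  // hidden size, varies across model variants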

Collaborator: This function should be split into 3 (encoder, decoder, decoder_with_past) to avoid further confusion; let's do this cleanup after the main part is merged.

pm.run_passes(model);
}

void add_attention_mask_input(std::shared_ptr<ov::Model> model) {

Why are there 2 add_attention_mask functions?

new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("attention_mask") != std::string::npos) {
new_shape = ov::PartialShape({1, kvcache_size + 1});
} else if (input_name.find("position_ids") != std::string::npos) {

position_ids is now deprecated as an input, replaced with cache_position. Can it be removed?

Collaborator: I believe it's left over from non-optimum-cli models; it should be removed, of course.

if (input_name.find("input_ids") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("attention_mask") != std::string::npos) {
new_shape = ov::PartialShape({1, kvcache_size + 1});

Since 448 is passed as kvcache_size, using kvcache_size + 1 creates an attention_mask of size [1, 449], which is wrong, as 448 is the maximum supported by the model.

Use [1, kvcache_size] for the mask, as past_key_values is later reshaped to kvcache_size - 1.

Collaborator: I believe it's a matter of naming. The real KV cache size is 449 in this case, though.
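
To make the naming question concrete, a minimal illustration (variable names assumed, not the merged code): whatever kvcache_size denotes, the attention mask must span the past cache plus the token currently being decoded, so the mask length is always the past length plus one.

// Illustration of the relationship under discussion.
const uint32_t past_kv_len = kvcache_size - 1;  // length of the past_key_values tensors
const uint32_t mask_len    = past_kv_len + 1;   // == kvcache_size positions in total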

@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 30, 2024
@ilya-lavrenov ilya-lavrenov added this pull request to the merge queue Oct 31, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 31, 2024
@ilya-lavrenov ilya-lavrenov added this pull request to the merge queue Oct 31, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 31, 2024
@Wovchena Wovchena added this pull request to the merge queue Oct 31, 2024
@andrei-kochin andrei-kochin removed this pull request from the merge queue due to a manual request Oct 31, 2024
@andrei-kochin andrei-kochin merged commit cb2d527 into openvinotoolkit:master Oct 31, 2024
49 checks passed
Labels: category: samples (GenAI samples), category: sampling (Sampling / Decoding algorithms), Code Freeze

Successfully merging this pull request may close this issue: Dynamic Shape Issue When Run Whisper On NPU (#895)

6 participants