StaticWhisperPipeline change to work with optimum models #1103
Conversation
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();

// if (tensor.get_any_name().find(".value") != std::string::npos) {
Seems redundant, do we need to keep it?
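For reference, a minimal sketch of what the input-side conversion in this hunk amounts to, with the commented-out ".value" branch dropped as suggested; the helper name and the loop over all inputs are assumptions, not the PR's actual code:

#include <openvino/openvino.hpp>
#include <openvino/core/preprocess/pre_post_process.hpp>

// Sketch only: declare every model input as f16 at runtime and let the
// PrePostProcessor insert a conversion back to the model's own element type.
void convert_inputs_to_f16(std::shared_ptr<ov::Model>& model) {
    ov::preprocess::PrePostProcessor preprocessor(model);
    for (const auto& tensor : model->inputs()) {
        preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
        preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
    }
    model = preprocessor.build();
}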
preprocessor.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.output(tensor.get_any_name()).postprocess().convert_element_type();

// if (tensor.get_any_name().find(".value") != std::string::npos) {
Same here
@@ -160,10 +171,9 @@ int64_t decode_with_past(ov::InferRequest& decoder_with_past,
                         const std::vector<int64_t>& generated_tokens) {
    // FIXME: Avoid this cast to i32. Why it's not i64 precision in model?
    decoder_with_past.get_tensor("input_ids").data<int32_t>()[0] = static_cast<int32_t>(input_id);
Since optimum creates i64 dtype for input_ids, do we still need this cast? This cast was initially required to support NPU-friendly models.
I believe it's an attempt to align the model generated on the fly with the NPU-friendly one. Perhaps just leaving it as i64 is fine.
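If the optimum-exported decoder indeed keeps input_ids as i64, the cast could simply be dropped, roughly like this (a sketch, not the merged change):

// Assumption: the decoder-with-past "input_ids" tensor is i64, as exported by optimum-cli,
// so the token id can be written directly without a narrowing cast.
decoder_with_past.get_tensor("input_ids").data<int64_t>()[0] = input_id;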
A slight revision of the preprocessing is perhaps still needed here, but it's not critical for now, let's merge it!
void preprocess_encoder(std::shared_ptr<ov::Model> model) {
    ov::preprocess::PrePostProcessor preprocessor(model);

    preprocessor.input("input_features").tensor().set_element_type(ov::element::Type_t::f32);
it's already f32, isn't it?
    pm.run_passes(model);
}

void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
Perhaps there should be a separate function for every model to avoid confusion.
I agree. Having a separate reshape_to_static for the decoder and the decoder with past will be helpful.
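A rough sketch of how that split could look; the function names, the matched input names, and the exact shape choices are illustrative assumptions rather than code from this PR:

#include <openvino/openvino.hpp>
#include <map>
#include <string>

// Sketch: one static-reshape helper per decoder model, instead of a single shared one.
void reshape_to_static_decoder(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
    std::map<std::string, ov::PartialShape> new_shapes;
    for (const auto& input : model->inputs()) {
        const auto& name = input.get_any_name();
        if (name.find("input_ids") != std::string::npos) {
            new_shapes.emplace(name, ov::PartialShape({1, input_size}));
        } else if (name.find("encoder_hidden_states") != std::string::npos) {
            auto shape = input.get_partial_shape();
            shape[0] = 1;     // batch
            shape[1] = 1500;  // encoder sequence length (see the last_hidden_state discussion below)
            new_shapes.emplace(name, shape);
        }
    }
    model->reshape(new_shapes);
}

void reshape_to_static_decoder_with_past(std::shared_ptr<ov::Model> model,
                                         const uint32_t input_size,
                                         const uint32_t kvcache_size) {
    std::map<std::string, ov::PartialShape> new_shapes;
    for (const auto& input : model->inputs()) {
        const auto& name = input.get_any_name();
        if (name.find("input_ids") != std::string::npos) {
            new_shapes.emplace(name, ov::PartialShape({1, input_size}));
        } else if (name.find("attention_mask") != std::string::npos) {
            new_shapes.emplace(name, ov::PartialShape({1, kvcache_size}));  // see the mask-size thread below
        } else if (name.find("past_key_values") != std::string::npos) {
            auto shape = input.get_partial_shape();
            shape[0] = 1;                 // batch
            shape[2] = kvcache_size - 1;  // past length, assuming [batch, heads, past_len, head_dim]
            new_shapes.emplace(name, shape);
        }
    }
    model->reshape(new_shapes);
}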
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCWH"));
// preprocessor.output(tensor.get_any_name()).model().set_layout(ov::Layout("NCHW"));
//} else if (tensor.get_any_name().find(".key") != std::string::npos) {
// preprocessor.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("NCHW"));
Let's remove it in the refactoring PR.
const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // batch_dim
new_shape[1] = 1500; // FIXME: is it got from encoder output{'last_hidden_state'}
Encoder hidden states are not needed as an input to the decoder with past.
It is, however, a required input for the decoder model, and the static shapes are 1 for batch and 1500 for encoder sequence length, as you have added. However, even the last dimension is dynamic and varies with the model (checked an optimum-exported model using transformers v4.45.2 and optimum-intel v1.20.0).
Ideally, the encoder output 'last_hidden_state' dimension can be used to reshape the encoder_hidden_states input to the decoder. This would be straightforward.
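A hedged sketch of that idea: read the sequence-length dimension from the encoder's 'last_hidden_state' output and reuse it when reshaping the decoder, instead of hard-coding 1500 (variable names here are illustrative, and the encoder is assumed to have been reshaped to its static input already):

// Assumption: encoder_model and decoder_model are the loaded ov::Model objects and
// 'last_hidden_state' has shape [batch, encoder_seq_len, hidden_size].
const ov::PartialShape enc_out_shape = encoder_model->output("last_hidden_state").get_partial_shape();
const ov::Dimension encoder_seq_len = enc_out_shape[1];  // 1500 for the stock Whisper encoders

ov::PartialShape hs_shape = decoder_model->input("encoder_hidden_states").get_partial_shape();
hs_shape[0] = 1;                // batch
hs_shape[1] = encoder_seq_len;  // taken from the encoder output rather than hard-coded
std::map<std::string, ov::PartialShape> new_shapes{{"encoder_hidden_states", hs_shape}};
decoder_model->reshape(new_shapes);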
This function should be split into 3 (encoder, decoder, decoder_with_past) to avoid further confusion - let's do this cleanup after the main part is merged.
    pm.run_passes(model);
}

void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
Why are there 2 add_attention_mask functions?
    new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("attention_mask") != std::string::npos) {
    new_shape = ov::PartialShape({1, kvcache_size + 1});
} else if (input_name.find("position_ids") != std::string::npos) {
position_ids is now deprecated as an input, replaced with cache_position. Maybe it can be removed?
I believe it's left over from non-optimum-cli models; it should be removed, of course.
if (input_name.find("input_ids") != std::string::npos) {
    new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("attention_mask") != std::string::npos) {
    new_shape = ov::PartialShape({1, kvcache_size + 1});
Since 448 is passed as kvcache_size, using kvcache_size + 1 creates an attention_mask of shape [1, 449], which is wrong as 448 is the maximum supported by the model.
Use [1, kvcache_size] for the mask, as the past_key_values are later reshaped to kvcache_size - 1.
I believe it's a matter of naming. The real kv cache size is 449 in this case, though.
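To make the arithmetic concrete, under the first comment's convention where kvcache_size = 448 is the model's maximum number of positions (the values here are assumed from the discussion above):

const uint32_t kvcache_size = 448;  // assumed: Whisper's maximum decoder positions
// past_key_values hold kvcache_size - 1 = 447 previously processed positions,
// plus 1 position for the current input_ids token, giving 448 attended positions,
// so the static mask is {1, kvcache_size}; {1, kvcache_size + 1} would be 449,
// one more than the model supports under this naming.
const ov::PartialShape attention_mask_shape({1, kvcache_size});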
Fixes #895