Rework inputs preparation for OVModelForCausalLM (huggingface#620)
* refactor OVModelForCausalLM class

* rework prepare_inputs_for_generation for OVModelForCausalLM

* refactoring

* Apply suggestions from code review

* fix position ids and add tests
eaidova authored Apr 2, 2024
1 parent 447ef50 commit a48e0ca
Showing 2 changed files with 45 additions and 56 deletions.
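
As context for the diff below, here is a minimal standalone sketch of the token-trimming rule that the reworked `prepare_inputs_for_generation` applies (following the Llama-style logic in the patch); the helper name and toy tensors are illustrative, not part of the change:

```python
import torch

def trim_unprocessed_tokens(input_ids, attention_mask, past_len):
    # Case 1: the attention mask is longer than input_ids, so part of the prompt
    # lives only in the cache (e.g. when inputs_embeds were passed); keep the tail.
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_len):]
    # Case 2: input_ids still contains already-processed tokens; drop them.
    if past_len < input_ids.shape[1]:
        return input_ids[:, past_len:]
    # Case 3: everything in input_ids is new; pass it through unchanged.
    return input_ids

# Toy example: 5 prompt tokens already cached, 1 newly generated token appended.
ids = torch.arange(6).unsqueeze(0)             # shape (1, 6)
mask = torch.ones(1, 6, dtype=torch.long)
print(trim_unprocessed_tokens(ids, mask, 5))   # tensor([[5]])
```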
96 changes: 40 additions & 56 deletions optimum/intel/openvino/modeling_decoder.py
@@ -120,6 +120,7 @@ def __init__(
self._original_model = self.model.clone() # keep original model for serialization
self._pkv_precision = Type.f32
self.next_beam_idx = None
self._past_length = 0
self.update_pkv_precision()
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
@@ -356,19 +357,14 @@ def prepare_inputs(
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Dict:
if self.use_cache and past_key_values is not None:
input_ids = input_ids[:, -1:]

batch_size = input_ids.shape[0]
if self.config.model_type == "bloom":
batch_size *= self.config.num_attention_heads

inputs = {}
past_len = 0
if not self.stateful:
if past_key_values is not None:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
past_len = past_key_values[0][1].shape[-2]
if self._pkv_precision == Type.bf16:
# numpy does not support bf16, pretending f16, should change to bf16
past_key_values = tuple(
@@ -381,8 +377,6 @@ def prepare_inputs(
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
else:
past_len = past_key_values[0].shape[-2]

# Add the past_key_values to the decoder inputs
inputs = dict(zip(self.key_value_input_names, past_key_values))
@@ -411,6 +405,8 @@ def prepare_inputs(
# Set initial value for the next beam_idx input that will be used at the current iteration
# and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
self.next_beam_idx = np.arange(batch_size, dtype=int)
self._past_length = 0
past_len = self._get_past_length(past_key_values)

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
@@ -432,7 +428,7 @@ def prepare_inputs(
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
position_ids = position_ids[:, -input_ids.shape[1] :]

inputs["position_ids"] = position_ids

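As an aside (not part of the diff), a toy NumPy illustration of the reworked position_ids handling above; the arrays are made up for the example:

```python
import numpy as np

# Attention mask over 2 padding tokens and 4 real tokens (past plus the current step).
attention_mask = np.array([[0, 0, 1, 1, 1, 1]])
position_ids = np.cumsum(attention_mask, axis=1) - 1   # [[-1, -1, 0, 1, 2, 3]]
position_ids[attention_mask == 0] = 1                  # padding positions pinned to 1
# The rework keeps one position per input token instead of only the last one,
# so passing several new tokens at once (input_ids.shape[1] > 1) still works.
input_ids = np.array([[42, 43]])                        # two new tokens this step
print(position_ids[:, -input_ids.shape[1]:])            # [[2 3]]
```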
@@ -470,6 +466,7 @@ def forward(
# the first condition at the function beginning above.
# It should be something that is not None and it should be True when converted to Boolean.
past_key_values = ((),)
self._past_length += input_ids.shape[1]

if not self.stateful:
if self.use_cache:
@@ -485,19 +482,32 @@ def forward(

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
# Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

if past_key_values is not None:
past_len = self._get_past_length(past_key_values)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_len < input_ids.shape[1]:
input_ids = input_ids[:, past_len:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
position_ids = position_ids[:, -input_ids.shape[1] :]

return {
"input_ids": input_ids,
@@ -507,6 +517,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
"attention_mask": attention_mask,
}

def _get_past_length(self, past_key_values=None):
if past_key_values is None:
return 0
if self.stateful:
return self._past_length
if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
return past_key_values[0].shape[-2]
seq_length_dim = -2
if self.config.model_type == "chatglm":
seq_length_dim = 0
elif self.config.model_type == "qwen":
seq_length_dim = 1
# input is tuple of pairs
if isinstance(past_key_values[0], (tuple, list)):
return past_key_values[0][1].shape[seq_length_dim]
# past key values comes after flattening
return past_key_values[1].shape[seq_length_dim]

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
@@ -573,10 +601,6 @@ def _from_pretrained(
model_type = config.model_type.replace("_", "-")
if model_type == "bloom":
init_cls = OVBloomForCausalLM
elif model_type == "mpt":
init_cls = OVMPTForCausalLM
elif model_type == "opt":
init_cls = OVOPTForCausalLM
elif model_type == "gpt-bigcode":
init_cls = OVGPTBigCodeForCausalLM
else:
@@ -630,22 +654,12 @@ def _from_pretrained(
class OVBloomForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

# only last token for input_ids if past is not None
if past_key_values and not self.stateful:
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}
return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
def _reorder_cache(
@@ -712,36 +726,6 @@ def _convert_to_standard_cache(
)


class OVOPTForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}


class OVMPTForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}


class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
def _reorder_cache(
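To make the new `_get_past_length` dispatch easier to follow, here is a rough, hypothetical illustration of the cache layouts it distinguishes (the shapes below are invented for the example):

```python
import numpy as np

# Standard layout: one (key, value) pair per layer; sequence length sits on dim -2.
standard_cache = ((np.zeros((1, 8, 5, 64)), np.zeros((1, 8, 5, 64))),)
print(standard_cache[0][1].shape[-2])   # 5 cached tokens

# Multi-query-attention models keep a single fused tensor per layer.
mqa_cache = (np.zeros((1, 5, 128)),)
print(mqa_cache[0].shape[-2])           # 5 cached tokens

# chatglm / qwen place the sequence length on a different axis (dim 0 or 1),
# which is why the helper switches seq_length_dim per model type.
# Stateful models keep the cache inside the OpenVINO infer request, so the
# Python side simply counts the tokens fed so far via self._past_length.
```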
5 changes: 5 additions & 0 deletions tests/openvino/test_modeling.py
@@ -632,6 +632,11 @@ def test_multiple_inputs(self, model_arch):
outputs = model.generate(**tokens, generation_config=generation_config)
self.assertIsInstance(outputs, torch.Tensor)
self.assertEqual(outputs.shape[0], 3)
# test that generation result is reproducible
outputs2 = model.generate(**tokens, generation_config=generation_config)
self.assertIsInstance(outputs2, torch.Tensor)
self.assertEqual(outputs2.shape[0], 3)
self.assertTrue(torch.allclose(outputs2, outputs))
del model
gc.collect()

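As a usage note, the added test boils down to checking that a second `generate` call with the same inputs reproduces the first; a standalone sketch of the same check (the helper name is illustrative, not from the patch):

```python
import torch

def assert_generation_reproducible(model, tokens, generation_config):
    # Run the same generation twice; with the per-call past-length bookkeeping
    # in place, the outputs should match exactly.
    first = model.generate(**tokens, generation_config=generation_config)
    second = model.generate(**tokens, generation_config=generation_config)
    assert torch.allclose(first, second)
```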
