
Commit 7ec8549
Merge branch 'main' into audio-chat-templates
zucchini-nlp authored Jan 30, 2025
2 parents 46b4915 + 365fecb commit 7ec8549
Showing 87 changed files with 792 additions and 1,070 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -283,7 +283,7 @@ If you'd like to play with the examples or need the bleeding edge of the code an
 ```
 git clone https://github.com/huggingface/transformers.git
 cd transformers
-pip install
+pip install .
 ```
 
 ### With conda
@@ -1,5 +1,5 @@
 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-11.html#rel-23-11
-FROM nvcr.io/nvidia/pytorch:23.04-py3
+FROM nvcr.io/nvidia/pytorch:23.11-py3
 LABEL maintainer="Hugging Face"
 
 ARG DEBIAN_FRONTEND=noninteractive
2 changes: 1 addition & 1 deletion src/transformers/data/processors/squad.py
@@ -249,7 +249,7 @@ def squad_convert_example_to_features(
         else:
             p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
 
-        pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
+        pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
         special_token_indices = np.asarray(
             tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
         ).nonzero()
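Note: a minimal sketch of the 0-d edge case the `np.atleast_1d` wrapper appears to guard against, assuming `span["input_ids"]` can be a plain Python list and the tokenizer may have no pad token; the values below are illustrative, not taken from the processor. `np.where` on a 0-d condition relies on 0-d `nonzero`, which NumPy deprecates, so promoting the mask to 1-d keeps the call valid either way.

```python
import numpy as np

span_input_ids = [101, 2023, 102]      # illustrative token ids (plain Python list)
pad_token_id = None                    # e.g. a tokenizer without a pad token

mask = span_input_ids == pad_token_id  # -> plain Python bool (0-d), not an array
pad_token_indices = np.where(np.atleast_1d(mask))  # 1-d mask, so np.where stays valid
print(pad_token_indices)               # (array([], dtype=int64),)
```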
17 changes: 11 additions & 6 deletions src/transformers/generation/utils.py
@@ -406,23 +406,28 @@ def prepare_inputs_for_generation(
             model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
 
         # 4. Create missing `position_ids` on the fly
+        attention_mask = (
+            kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
+        )
+        attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
+        position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
         if (
             attention_mask is not None
-            and kwargs.get("position_ids") is None
-            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
+            and kwargs.get(position_ids_key) is None
+            and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
         ):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
+            kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)
 
         # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-        for model_input_name in ["position_ids", "token_type_ids"]:
+        for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
             model_input = kwargs.get(model_input_name)
             if model_input is not None:
                 if past_key_values is not None:
                     current_input_length = (
                         model_inputs["inputs_embeds"].shape[1]
-                        if model_inputs["inputs_embeds"] is not None
+                        if model_inputs.get("inputs_embeds") is not None
                         else model_inputs[input_ids_key].shape[1]
                     )
                     model_input = model_input[:, -current_input_length:]
@@ -469,7 +474,7 @@ def prepare_inputs_for_generation(
                 past_key_values=past_key_values,
             )
         if attention_mask is not None:
-            model_inputs["attention_mask"] = attention_mask
+            model_inputs[attention_mask_key] = attention_mask
 
         # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
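Note: a toy illustration (not part of the diff) of the `position_ids` construction used in the hunk above: the cumulative sum over the attention mask numbers the real tokens from 0, and the masked fill gives padding slots a harmless placeholder value.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])      # one left-padded sequence
position_ids = attention_mask.long().cumsum(-1) - 1   # [-1, -1, 0, 1, 2]
position_ids.masked_fill_(attention_mask == 0, 1)     # padding slots get a dummy value
print(position_ids)  # tensor([[1, 1, 0, 1, 2]])
```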
5 changes: 5 additions & 0 deletions src/transformers/integrations/sdpa_attention.py
@@ -45,6 +45,11 @@ def sdpa_attention_forward(
     if is_causal is None:
         is_causal = causal_mask is None and query.shape[2] > 1
 
+    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
+    # We convert it to a bool for the SDPA kernel that only accepts bools.
+    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
+        is_causal = is_causal.item()
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
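Note: a small sketch of the conversion added above, using an illustrative stand-in value. Per the diff's own comment, shape-derived comparisons such as `query.shape[2] > 1` show up as 0-d bool tensors under jit tracing, and `.item()` turns the result back into the plain bool that the SDPA kernel accepts.

```python
import torch

# Stand-in for a shape comparison recorded during tracing.
is_causal = torch.tensor(True)

if isinstance(is_causal, torch.Tensor):
    is_causal = is_causal.item()  # 0-d bool tensor -> plain Python bool

print(type(is_causal), is_causal)  # <class 'bool'> True
```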
6 changes: 5 additions & 1 deletion src/transformers/models/llava/modeling_llava.py
@@ -280,6 +280,7 @@ def get_image_features(
         pixel_values: torch.FloatTensor,
         vision_feature_layer: Union[int, List[int]],
         vision_feature_select_strategy: str,
+        **kwargs,
     ):
         """
         Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -300,8 +301,9 @@
         if vision_feature_select_strategy not in ["default", "full"]:
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
 
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
         # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
-        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
 
         # If we have one vision feature layer, return the corresponding hidden states,
         # otherwise, select the hidden states of each feature layer and concatenate them
@@ -422,6 +424,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        image_sizes: torch.Tensor = None,
     ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
         r"""
         Args:
@@ -492,6 +495,7 @@ def forward(
                 pixel_values=pixel_values,
                 vision_feature_layer=vision_feature_layer,
                 vision_feature_select_strategy=vision_feature_select_strategy,
+                image_sizes=image_sizes,
             )
 
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
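Note: a minimal sketch of the `kwargs` filtering added to `get_image_features`; the helper and stub below are hypothetical, not model code. Optional extras such as `image_sizes` are dropped when they are `None`, so only arguments that were actually provided reach the vision tower.

```python
import torch

def forward_vision_kwargs(pixel_values, vision_tower, **kwargs):
    # Hypothetical helper mirroring the diff: drop unset optional extras
    # before forwarding them to the vision backbone.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return vision_tower(pixel_values, output_hidden_states=True, **kwargs)

# A stub "vision tower" that just reports which keyword arguments it received:
stub_tower = lambda pixel_values, **kw: sorted(kw)
print(forward_vision_kwargs(torch.zeros(1, 3, 336, 336), stub_tower, image_sizes=None))
# ['output_hidden_states']  -- image_sizes=None was filtered out
```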
