Pixtral: vectorize patch embeddings and enable tests (#35122)
* initial POC

* - batch mix feature

* fix tests

* fix tests

* make style

* do not skip and instead fix tests

* update

* return back the test

* correct text with the correct ckpt
zucchini-nlp authored Jan 30, 2025
1 parent 8bc4c89 commit 9725e5b
Showing 10 changed files with 429 additions and 552 deletions.
6 changes: 5 additions & 1 deletion src/transformers/models/llava/modeling_llava.py
@@ -280,6 +280,7 @@ def get_image_features(
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Union[int, List[int]],
        vision_feature_select_strategy: str,
+       **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower and applies the multimodal projection.
@@ -300,8 +301,9 @@
        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

+       kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
-       image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+       image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)

        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
@@ -422,6 +424,7 @@ def forward(
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
+       image_sizes: torch.Tensor = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
Expand Down Expand Up @@ -492,6 +495,7 @@ def forward(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)

n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
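With these changes, optional keywords such as `image_sizes` flow from `forward` through `get_image_features` into the vision tower, with `None` values filtered out first. A minimal sketch of the new call path, using toy tensors and assuming a loaded `LlavaForConditionalGeneration` bound to `model` (the model instance and shapes are illustrative, not part of this diff):

import torch

# Toy inputs; real shapes depend on the checkpoint's vision config.
pixel_values = torch.randn(2, 3, 336, 336)
image_sizes = torch.tensor([[336, 336], [224, 448]])  # (height, width) per image

# `get_image_features` drops kwargs whose value is None, then forwards the
# rest to the vision tower alongside `output_hidden_states=True`.
features = model.get_image_features(
    pixel_values=pixel_values,
    vision_feature_layer=-2,
    vision_feature_select_strategy="default",
    image_sizes=image_sizes,
)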
224 changes: 85 additions & 139 deletions src/transformers/models/pixtral/image_processing_pixtral.py
@@ -15,12 +15,13 @@
"""Image processor class for Pixtral."""

import math
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
+    pad,
    resize,
    to_channel_dimension_format,
)
@@ -31,13 +32,13 @@
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
-    is_valid_image,
+    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
+from ...utils import TensorType, is_vision_available, logging
from ...utils.import_utils import requires_backends
@@ -48,91 +49,6 @@
    import PIL


-class BatchMixFeature(BatchFeature):
-    def to(self, *args, **kwargs) -> "BatchMixFeature":
-        """
-        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
-        different `dtypes` and sending the `BatchFeature` to a different `device`.
-
-        Args:
-            args (`Tuple`):
-                Will be passed to the `to(...)` function of the tensors.
-            kwargs (`Dict`, *optional*):
-                Will be passed to the `to(...)` function of the tensors.
-
-        Returns:
-            [`BatchFeature`]: The same instance after modification.
-        """
-
-        def _recursive_to(obj, device, *args, **kwargs):
-            # Lists can be nested, so keep digging until we hit tensors
-            if isinstance(obj, list):
-                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
-            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
-                # cast and send to device
-                return obj.to(*args, **kwargs)
-            elif isinstance(obj, torch.Tensor) and device is not None:
-                # only send to device, don't cast
-                return obj.to(device=device)
-            else:
-                return obj
-
-        requires_backends(self, ["torch"])
-        import torch  # noqa
-
-        device = kwargs.get("device")
-        # Check if the args are a device or a dtype
-        if device is None and len(args) > 0:
-            # device should be always the first argument
-            arg = args[0]
-            if is_torch_dtype(arg):
-                # The first argument is a dtype
-                pass
-            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
-                device = arg
-            else:
-                # it's something else
-                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-
-        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
-        return self

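For reference, the `to()` override above existed because `pixel_values` used to be a nested list of differently sized tensors, which the stock `BatchFeature.to()` cannot move in one call. A simplified, standalone sketch of the recursion it implemented (device/dtype argument parsing omitted):

import torch

def recursive_to(obj, device):
    # Lists can be nested, so keep digging until we hit tensors.
    if isinstance(obj, list):
        return [recursive_to(o, device) for o in obj]
    if isinstance(obj, torch.Tensor):
        return obj.to(device=device)
    return obj

nested = [[torch.randn(3, 64, 96)], [torch.randn(3, 48, 128)]]
moved = recursive_to(nested, "cpu")  # every tensor is visited individually

Once padding guarantees a single stacked tensor per key, this machinery is unnecessary and the plain `BatchFeature` suffices, which is why the class is removed.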

-# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
-def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
-    """
-    Convert a single image or a list of images to a list of numpy arrays.
-
-    Args:
-        images (`ImageInput`):
-            A single image or a list of images.
-
-    Returns:
-        A list of numpy arrays.
-    """
-    # If it's a single image, convert it to a list of lists
-    if is_valid_image(images):
-        images = [[images]]
-    # If it's a list of images, it's a single batch, so convert it to a list of lists
-    elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
-        images = [images]
-    # If it's a list of batches, it's already in the right format
-    elif (
-        isinstance(images, (list, tuple))
-        and len(images) > 0
-        and isinstance(images[0], (list, tuple))
-        and len(images[0]) > 0
-        and is_valid_image(images[0][0])
-    ):
-        pass
-    else:
-        raise ValueError(
-            "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
-        )
-    return images


# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
def convert_to_rgb(image: ImageInput) -> ImageInput:
"""
@@ -219,18 +135,6 @@ def get_resize_output_image_size(
    return num_height_tokens * patch_height, num_width_tokens * patch_width
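The return statement above snaps the output size up to whole multiples of the patch size. A hedged numeric sketch of just that rounding step, with assumed toy values (the helper's full body, including the rescaling that precedes this step, is collapsed in this diff):

import math

patch_height = patch_width = 16  # assumed patch size; the real value comes from the config
height, width = 500, 640

num_height_tokens = math.ceil(height / patch_height)  # 32 patch rows
num_width_tokens = math.ceil(width / patch_width)     # 40 patch columns

# Matches the return above: a 500x640 image maps to a 512x640 output.
print(num_height_tokens * patch_height, num_width_tokens * patch_width)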

-# Hack to get tensor conversion used in BatchFeature without batching the images
-def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
-    return BatchFeature()._get_is_as_tensor_fns(tensor_type)
-
-
-def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
-    is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
-    if is_tensor(array):
-        return array
-    return as_tensor(array)


class PixtralImageProcessor(BaseImageProcessor):
r"""
Constructs a Pixtral image processor.
@@ -368,6 +272,49 @@ def resize(
            **kwargs,
        )

+    def _pad_for_batching(
+        self,
+        pixel_values: List[np.ndarray],
+        image_sizes: List[List[int]],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        """
+        Pads images with zeros on the height and width dimensions so that every image in the batch has the same
+        shape and can be stacked into a single tensor.
+
+        Args:
+            pixel_values (`List[np.ndarray]`):
+                A list of image arrays, one per image in the batch.
+            image_sizes (`List[List[int]]`):
+                A list of sizes for each image in `pixel_values` in (height, width) format.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                If unset, will use same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                If unset, will use the inferred format of the input image.
+
+        Returns:
+            List[`np.ndarray`]: The padded images.
+        """
+        max_shape = (
+            max([size[0] for size in image_sizes]),
+            max([size[1] for size in image_sizes]),
+        )
+        pixel_values = [
+            pad(
+                image,
+                padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])),
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
+            for image, size in zip(pixel_values, image_sizes)
+        ]
+        return pixel_values

    def preprocess(
        self,
        images: ImageInput,
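A quick sketch of what the new `_pad_for_batching` does, using toy channels-first arrays; `processor` is assumed to be an already constructed `PixtralImageProcessor` instance:

import numpy as np
from transformers.image_utils import ChannelDimension

images = [np.ones((3, 64, 96)), np.ones((3, 48, 128))]
sizes = [(64, 96), (48, 128)]

padded = processor._pad_for_batching(  # `processor` assumed to exist
    pixel_values=images,
    image_sizes=sizes,
    data_format=ChannelDimension.FIRST,
    input_data_format=ChannelDimension.FIRST,
)

# Each image is zero-padded on the bottom/right up to the per-batch maximum,
# here (64, 128), so the list can be stacked into one 4D tensor.
assert all(p.shape == (3, 64, 128) for p in padded)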
Expand Down Expand Up @@ -449,9 +396,9 @@ def preprocess(

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images_list = make_list_of_images(images)
images = make_list_of_images(images)

if not valid_images(images_list[0]):
if not valid_images(images[0]):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
@@ -469,57 +416,56 @@
            )

        if do_convert_rgb:
-            images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
+            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
-        images_list = [[to_numpy_array(image) for image in images] for images in images_list]
+        images = [to_numpy_array(image) for image in images]

-        if do_rescale and is_scaled_image(images_list[0][0]):
+        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
-            input_data_format = infer_channel_dimension_format(images_list[0][0])
+            input_data_format = infer_channel_dimension_format(images[0])

        batch_images = []
        batch_image_sizes = []
-        for sample_images in images_list:
-            images = []
-            image_sizes = []
-            for image in sample_images:
-                if do_resize:
-                    image = self.resize(
-                        image=image,
-                        size=size,
-                        patch_size=patch_size,
-                        resample=resample,
-                        input_data_format=input_data_format,
-                    )
-
-                if do_rescale:
-                    image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-
-                if do_normalize:
-                    image = self.normalize(
-                        image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
-                    )
-
-                images.append(image)
-                image_sizes.append(get_image_size(image, input_data_format))
-            batch_images.append(images)
-            batch_image_sizes.append(image_sizes)
-
-        images_list = [
-            [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
-            for images in batch_images
-        ]
+        for image in images:
+            if do_resize:
+                image = self.resize(
+                    image=image,
+                    size=size,
+                    patch_size=patch_size,
+                    resample=resample,
+                    input_data_format=input_data_format,
+                )
+
+            if do_rescale:
+                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+            if do_normalize:
+                image = self.normalize(
+                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+                )
+
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+            batch_images.append(image)
+            batch_image_sizes.append(get_image_size(image, data_format))
+
+        pixel_values = self._pad_for_batching(
+            pixel_values=batch_images,
+            image_sizes=batch_image_sizes,
+            input_data_format=data_format,
+            data_format=data_format,
+        )
+
-        # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
-        images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
-        return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)
+        return BatchFeature(
+            data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
+        )


__all__ = ["PixtralImageProcessor"]
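End to end, `preprocess` now returns a plain `BatchFeature` whose `pixel_values` is a single stacked, zero-padded tensor, while `image_sizes` keeps each image's true (height, width) so the model can ignore the padding. A hedged usage sketch (the checkpoint name is assumed, not taken from this diff):

from PIL import Image
import numpy as np
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b")  # assumed checkpoint
images = [
    Image.fromarray(np.zeros((512, 640, 3), dtype=np.uint8)),
    Image.fromarray(np.zeros((384, 768, 3), dtype=np.uint8)),
]

out = processor(images=images, return_tensors="pt")
print(out["pixel_values"].shape)  # (2, 3, H_max, W_max) after resize + padding
print(out["image_sizes"])         # per-image sizes before padding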