Granite Vision Support #35579

Merged · 12 commits · Jan 23, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -452,6 +452,8 @@
title: Granite
- local: model_doc/granitemoe
title: GraniteMoe
- local: model_doc/granitevision
title: GraniteVision
- local: model_doc/helium
title: Helium
- local: model_doc/herbert
90 changes: 90 additions & 0 deletions docs/source/en/model_doc/granitevision.md
@@ -0,0 +1,90 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.

-->

# Granite Vision

## Overview

The Granite Vision model is a variant of [LLaVA-NeXT](llava_next) that pairs a [Granite](granite) language model with a [SigLIP](siglip) visual encoder. Like [VipLlava](vipllava), it concatenates hidden states from several vision encoder layers to form its image features. It also uses a larger set of image grid pinpoints than the original LLaVA-NeXT models, so it can handle additional aspect ratios.
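
For illustration, these two aspects map onto `LlavaNextConfig` roughly as follows; the layer indices and pinpoints below are placeholders rather than the released checkpoint's settings:

```python
from transformers import LlavaNextConfig

# Hypothetical values for illustration only; the released checkpoint ships its own config.
config = LlavaNextConfig(
    vision_feature_layer=[-24, -20, -12, -1],  # several encoder layers, concatenated into one feature
    image_grid_pinpoints=[[384, 768], [768, 384], [768, 768], [1152, 384], [384, 1152]],
)
print(config.vision_feature_layer)
```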

Tips:
- This model is loaded into Transformers as an instance of LLaVA-NeXT, so the usage and tips from [LLaVA-NeXT](llava_next) apply here as well.

- You can apply the chat template on the tokenizer / processor in the same way. Example chat format:
```bash
"<|user|>\nWhat’s shown in this image?\n<|assistant|>\nThis image shows a red stop sign.<|end_of_text|><|user|>\nDescribe the image in more details.\n<|assistant|>\n"
```

Sample inference:
```python
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Note: These docs were written prior to the public model release,
# and this path is subject to change.
# Please see https://huggingface.co/ibm-granite for the current model list.
model_path = "ibm-granite/granite-3.1-2b-instruct-vision"
processor = LlavaNextProcessor.from_pretrained(model_path)

model = LlavaNextForConditionalGeneration.from_pretrained(model_path).to("cuda")

# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"

conversation = [
{
"role": "user",
"content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to("cuda")


# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)

print(processor.decode(output[0], skip_special_tokens=True))
```
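
For larger checkpoints, you can likely reduce memory by loading in half precision; these are standard `from_pretrained` options rather than anything specific to Granite Vision:

```python
import torch

# Optional: continues from the snippet above, loading the model in half precision.
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # or torch.bfloat16 on supported hardware
    device_map="auto",
)
```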

This model was contributed by [Alexander Brooks](https://huggingface.co/abrooks9944).

## LlavaNextConfig

[[autodoc]] LlavaNextConfig

## LlavaNextImageProcessor

[[autodoc]] LlavaNextImageProcessor
- preprocess

## LlavaNextProcessor

[[autodoc]] LlavaNextProcessor

## LlavaNextForConditionalGeneration

[[autodoc]] LlavaNextForConditionalGeneration
- forward
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -134,6 +134,7 @@
("gptsan-japanese", "GPTSanJapaneseConfig"),
("granite", "GraniteConfig"),
("granitemoe", "GraniteMoeConfig"),
("granitevision", "LlavaNextConfig"),
("graphormer", "GraphormerConfig"),
("grounding-dino", "GroundingDinoConfig"),
("groupvit", "GroupViTConfig"),
@@ -456,6 +457,7 @@
("gptsan-japanese", "GPTSAN-japanese"),
("granite", "Granite"),
("granitemoe", "GraniteMoeMoe"),
("granitevision", "LLaVA-NeXT"),
("graphormer", "Graphormer"),
("grounding-dino", "Grounding DINO"),
("groupvit", "GroupViT"),
@@ -725,6 +727,7 @@
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
]
)
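
As a quick sanity check of what these mappings provide (a sketch, assuming an installed `transformers` build that includes this change), the `granitevision` model type resolves to the LLaVA-NeXT config class:

```python
from transformers import AutoConfig

# "granitevision" is registered as a model type backed by LlavaNextConfig / the llava_next module.
config = AutoConfig.for_model("granitevision")
print(type(config).__name__)  # LlavaNextConfig
```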

6 changes: 4 additions & 2 deletions src/transformers/models/llava/configuration_llava.py
@@ -46,8 +46,10 @@ class LlavaConfig(PretrainedConfig):
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
49 changes: 35 additions & 14 deletions src/transformers/models/llava/modeling_llava.py
@@ -86,8 +86,12 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
self.linear_1 = nn.Linear(
    config.vision_config.hidden_size * num_feature_layers,
    config.text_config.hidden_size,
    bias=config.multimodal_projector_bias,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
@@ -207,8 +211,10 @@ def _init_weights(self, module):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
@@ -275,31 +281,46 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
return model_embeds

def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.

Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
if vision_feature_select_strategy not in ["default", "full"]:
    raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
    selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
    if vision_feature_select_strategy == "default":
        selected_image_feature = selected_image_feature[:, 1:]
else:
    hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
    # For default; crop CLS from each hidden state in the hidden state pool
    if vision_feature_select_strategy == "default":
        hs_pool = [hs[:, 1:] for hs in hs_pool]
    selected_image_feature = torch.cat(hs_pool, dim=-1)

image_features = self.multi_modal_projector(selected_image_feature)
return image_features
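
A minimal, self-contained sketch of the selection logic above; the hidden-state shapes and layer indices are made up for illustration and are not taken from the released checkpoint:

```python
import torch

# Fake vision-tower hidden states: 25 layers of (batch, 1 CLS + 576 patches, hidden_size).
hidden_states = [torch.randn(2, 577, 1024) for _ in range(25)]
vision_feature_layer = [-24, -20, -12, -1]  # illustrative indices

if isinstance(vision_feature_layer, int):
    selected = hidden_states[vision_feature_layer][:, 1:]  # "default": drop the CLS token
else:
    # Crop CLS from each chosen layer, then concatenate along the hidden dimension.
    hs_pool = [hidden_states[idx][:, 1:] for idx in vision_feature_layer]
    selected = torch.cat(hs_pool, dim=-1)

print(selected.shape)  # torch.Size([2, 576, 4096]), i.e. hidden_size * num_feature_layers
```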

@@ -396,7 +417,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
src/transformers/models/llava_next/configuration_llava_next.py
@@ -46,8 +46,10 @@ class LlavaNextConfig(PretrainedConfig):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.
50 changes: 35 additions & 15 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -194,8 +194,12 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput):
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextConfig):
super().__init__()
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
self.linear_1 = nn.Linear(
    config.vision_config.hidden_size * num_feature_layers,
    config.text_config.hidden_size,
    bias=config.multimodal_projector_bias,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
@@ -318,8 +322,10 @@ def _init_weights(self, module):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
@@ -672,18 +678,22 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
    image_sizes[image_idx],
    self.config.image_grid_pinpoints,
    self.config.vision_config.image_size,
)

if (
    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
    and vision_feature_select_strategy == "default"
):
    logger.warning_once(
        "Image feature shape does not line up with the provided patch size. "
        "You may be using the `default` vision_feature_select_strategy with a"
        " visual encoder that does not have CLS."
    )

image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
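
For reference, a small sketch of how the tile grid used above is derived for one image; the image size and pinpoints are illustrative, and it assumes `get_anyres_image_grid_shape` remains importable from this module:

```python
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape

# Pick the best pinpoint resolution for a 600x900 image and convert it to a tile grid,
# using a 336px vision encoder input size (illustrative numbers).
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
print(get_anyres_image_grid_shape((600, 900), grid_pinpoints, 336))  # (2, 2) for these values
```
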
@@ -714,7 +724,7 @@ def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
@@ -725,8 +735,10 @@ def get_image_features(
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
@@ -752,11 +764,19 @@
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
    selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
    hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
    selected_image_feature = torch.cat(hs_pool, dim=-1)

if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature

image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
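
A toy sketch of the final `torch.split` step: projected features for all patch crops in the batch sit in one tensor and are split back into per-image chunks (sizes made up):

```python
import torch

image_features = torch.randn(7, 576, 4096)  # 7 patch crops across the whole batch
image_num_patches = [3, 4]                  # image 0 has 3 crops, image 1 has 4
per_image = torch.split(image_features, image_num_patches, dim=0)
print([chunk.shape[0] for chunk in per_image])  # [3, 4]
```
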
@@ -772,7 +792,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
src/transformers/models/llava_next_video/configuration_llava_next_video.py
@@ -50,8 +50,10 @@ class LlavaNextVideoConfig(PretrainedConfig):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.