
Granite Vision Support #35579

Merged: 12 commits into huggingface:main on Jan 23, 2025

Conversation

@alex-jw-brooks (Contributor) commented Jan 9, 2025

What does this PR do?

This PR adds compatibility for IBM's upcoming Granite Vision models (which are based on LLava Next). The main changes here are:

  • The vision feature layer, which is currently expected to be an integer, can now also be a list of integers; if a list is provided, the image features from those layers are concatenated before applying the feature selection strategy (see the sketch after this list)
  • The validation that breaks visual encoders with no CLS token (discussed a bit here) is removed. I also added a warning for when the feature packing fails with the default strategy.
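
A minimal sketch of the int-vs-list handling described above (the helper name pool_vision_features is hypothetical; the concatenation mirrors the modeling diff quoted further down in this thread):

import torch

def pool_vision_features(image_outputs, vision_feature_layer):
    # Previous behavior: a single hidden-state layer is selected by its index.
    if isinstance(vision_feature_layer, int):
        return image_outputs.hidden_states[vision_feature_layer]
    # New behavior: gather every requested layer and concatenate along the hidden
    # dimension before the vision feature selection strategy is applied.
    hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
    return torch.cat(hs_pool, dim=-1)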

These changes were applied in several places to keep the repository-consistency checks and the configs consistent, even though the multimodal Granite models themselves are instances of LlavaNextForConditionalGeneration. I added a test for each changed model to ensure that nothing blows up when a list of vision feature layers is provided, but if there is another preferred path forward besides changing several models at once to add compatibility with LLaVA Next, I'm happy to revise this PR as needed.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts, @qubvel, @zucchini-nlp

@zucchini-nlp (Member) left a comment

Great, thanks for propagating the changes to all llava models. I think we can also modify vipllava for the sake of consistency.

Also, I left one question about the case where we have a list of vision feature layers together with the default vision_feature_select_strategy. WDYT about cropping the CLS token from each layer, since that seems to be the most intuitive behavior? Or is Multimodal Granite not cropping CLS tokens?

Comment on lines 313 to 322
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
@zucchini-nlp (Member):

hmm, I think in this case, when one has several layer indices and the "default" feature selection strategy, one wants to crop the CLS token of each layer. I realize this is not a feature used in any of the official checkpoints, but if we want to standardize vision_feature_layer to be a Union[int, List[int]], I think that is the expected behavior

WDYT?
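
For reference, a rough sketch of the per-layer cropping being suggested here (not the final diff; hs_pool follows the snippet quoted above):

hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
if vision_feature_select_strategy == "default":
    # drop the CLS token of every selected layer before concatenating
    hs_pool = [hs[:, 1:] for hs in hs_pool]
selected_image_feature = torch.cat(hs_pool, dim=-1)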

@alex-jw-brooks (Contributor Author):

That's a great point! The IBM granite multimodal models use full as the vision feature selection strategy with siglip (no CLS) as the vision encoder, so I hadn't thought about what the behavior of default should be, but I completely agree with you.

I took a look at vipllava's feature selection (this) based on your other comment, and it seems that its CLS cropping already behaves this way - I'll make the change!

Comment on lines 686 to 698
try:
    image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
except RuntimeError as e:
    if vision_feature_select_strategy == "default":
        logger.warning_once(
            "Image feature shape does not line up with the provided patch size. "
            "You may be using the `default` vision_feature_select_strategy with a"
            " visual encoder that does not have CLS."
        )
    raise e
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
@zucchini-nlp (Member):

Personally, I don't think try/except is something we want in modeling code. Can we frame it as `if condition: raise Error` to catch unwanted behavior?

@alex-jw-brooks (Contributor Author):

Sure! I'll rewrite it to warn conditionally if the vision feature select strategy is default and the product of the shape dims isn't divisible by num_patch_height * num_patch_width * height * width
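
For illustration, a sketch of what that conditional warning might look like (based on the hunk quoted above; the exact check in the final diff may differ):

# Warn instead of wrapping the view in try/except; the view below still raises
# naturally if the shapes truly do not line up.
expected_numel = num_patch_height * num_patch_width * height * width
if image_feature.numel() % expected_numel != 0 and vision_feature_select_strategy == "default":
    logger.warning_once(
        "Image feature shape does not line up with the provided patch size. "
        "You may be using the `default` vision_feature_select_strategy with a"
        " visual encoder that does not have CLS."
    )
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()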

@alex-jw-brooks (Contributor Author):
Thanks a lot for the quick review @zucchini-nlp 😄 For vipllava, do you mean allowing the vision_feature_layers to be of type Union[int, List[int]], allowing a strategy, or something else?

@zucchini-nlp (Member):
@alex-jw-brooks

> allowing the vision_feature_layers to be of type Union[int, List[int]]

exactly, this is what I meant so we can be consistent for all llava models :)

@alex-jw-brooks (Contributor Author):
Awesome, thank you for the clarification @zucchini-nlp 😄 I made the change in vipllava and rebased this PR, it should be ready for another look when you have a moment!
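
For context, a rough sketch of the vipllava side of that change (hypothetical helper name, not the exact diff; vipllava already crops the CLS token of each layer and concatenates, so the list form is the existing behavior and the int form is the new single-layer path):

from typing import List, Union

import torch

def select_vipllava_features(image_outputs, vision_feature_layers: Union[int, List[int]]):
    if isinstance(vision_feature_layers, int):
        # Single layer: index directly and crop the CLS token.
        return image_outputs.hidden_states[vision_feature_layers][:, 1:]
    # List of layers: crop the CLS token of each layer, then concatenate on the hidden dim.
    hs_pool = [image_outputs.hidden_states[idx][:, 1:] for idx in vision_feature_layers]
    return torch.cat(hs_pool, dim=-1)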

@zucchini-nlp (Member) left a comment

Thanks for iterating, LGTM! Left a tiny comment, seems to be a typo :)

src/transformers/models/llava/modeling_llava.py (outdated, resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member):
Cool, thanks a lot! Let's get one more review from the core maintainer and then we'll merge

@ArthurZucker (Collaborator) left a comment

LGTM, but we need to add a granite_multimodal.md with model details!

@ArthurZucker (Collaborator):
We would also want to add an integration test with expected values!

@alex-jw-brooks (Contributor Author):
Awesome, thank you for the feedback @ArthurZucker!

> LGTM, but we need to add a granite_multimodal.md with model details!

Given that these models will be run in transformers as an instance of llava next and not their own class, would you prefer that we add a separate granite_multimodal.md, or add a note in the llava next docs? I'm happy to do either, and will update this PR tomorrow, once I've heard back from my colleagues on which paper would be best to cite for these models 😄

> We would also want to add an integration test with expected values!

I've added an integration test that works for the version of the 2b model that I have locally. However, since the models are still being finalized (experimental release planned for the next week or so), I've marked it with @unittest.skip - would it be possible to merge this PR so that there is day zero support in transformers, and submit a follow-up PR enabling the integration test with the correct model ID + value once the 2b model is available?
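
For context, a skeleton of how such a temporarily skipped integration test might look (the model ID and expected values below are placeholders, not the real ones from the PR):

import unittest

from transformers.testing_utils import require_torch, slow

@require_torch
class GraniteVisionIntegrationTest(unittest.TestCase):
    @slow
    @unittest.skip("Granite Vision checkpoint is not publicly available yet")
    def test_small_model_integration(self):
        # Placeholder checkpoint and expected text, to be filled in by the follow-up PR
        # once the 2b model is released.
        model_id = "ibm-granite/granite-vision-placeholder"
        expected_text = "..."
        ...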

@ArthurZucker (Collaborator):
A separate granite_multimodal.md would be nice IMO!

@ArthurZucker (Collaborator):
yep completely possible!

@alex-jw-brooks changed the title from "Multimodal Granite Support" to "Granite Vision Support" on Jan 22, 2025
Replace multimodal granite refs with granite vision

Add granite vision / llava next alias

Signed-off-by: Alex-Brooks <[email protected]>
@alex-jw-brooks (Contributor Author):

Great, thanks so much @ArthurZucker! It looks like the models are actually going to be released as granite vision - I went ahead and updated references from granite multimodal -> granite vision and added a short description of the model + example to docs/source/en/model_doc/granitevision.md. It should be ready for another look when you have a moment!

It seems like there will be a publication, but it's not available quite yet - I'll be sure to open follow-up PRs for the integration test and doc updates (e.g., paper citation and model ID fix if the name ends up incorrect in the example) in the coming weeks as things are published 😄

@ArthurZucker (Collaborator) left a comment

1 nit and let's merge! 🤗


# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
@ArthurZucker (Collaborator):

Suggested change (remove this line):
image = Image.open(requests.get(url, stream=True).raw)

you no longer need this! image is processed directly!

@zucchini-nlp (Member):

Btw, for that you also need to change the chat template as below:

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to("cuda")

@alex-jw-brooks (Contributor Author):

Perfect, thanks for all your help @ArthurZucker @zucchini-nlp - I've updated the example to use the url directly instead! 😄
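
For completeness, a rough sketch of how the updated example ties together end to end (the checkpoint name below is a placeholder; the real ID lives in the merged granitevision.md, and Granite Vision loads as a LlavaNext model):

import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration

model_id = "ibm-granite/granite-vision-placeholder"  # placeholder checkpoint name
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda"
)

url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]
# The processor downloads and preprocesses the image from the url directly.
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

output = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(output[0], skip_special_tokens=True))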

@ArthurZucker merged commit 71cc816 into huggingface:main on Jan 23, 2025 (23 checks passed)
@ArthurZucker (Collaborator):
Congrats on the merge! 🚀

bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
* Add multimodal granite support

Signed-off-by: Alex-Brooks <[email protected]>

Support multiple image feature layers

Signed-off-by: Alex-Brooks <[email protected]>

* Remove failing validation for visual encoders with no cls

Signed-off-by: Alex-Brooks <[email protected]>

* Update llava based models / configs to support list of feature layers

Signed-off-by: Alex-Brooks <[email protected]>

* Add tests for multiple feature layers

Signed-off-by: Alex-Brooks <[email protected]>

* Use conditional instead of except for misaligned feature shapes

Signed-off-by: Alex-Brooks <[email protected]>

* crop cls from each hidden state

Signed-off-by: Alex-Brooks <[email protected]>

* Fix formatting

Signed-off-by: Alex-Brooks <[email protected]>

* Support single vision feature int in vipllava

Signed-off-by: Alex-Brooks <[email protected]>

* Fix typo in vision feature selection strategy validation

Signed-off-by: Alex-Brooks <[email protected]>

* Add tentative integration test for granite vision models

Signed-off-by: Alex-Brooks <[email protected]>

* Add granite vision docs

Replace multimodal granite refs with granite vision

Add granite vision / llava next alias

Signed-off-by: Alex-Brooks <[email protected]>

* Use image url in granitevision example

Signed-off-by: Alex-Brooks <[email protected]>

---------

Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: Alex-Brooks <[email protected]>