
🚨All attention refactor🚨 #35235

Merged: 99 commits, Dec 18, 2024

Conversation

@ArthurZucker (Collaborator) commented Dec 12, 2024

What does this PR do?

Todo in this PR:

  • Cohere
  • Chameleon
  • DBRX
  • Gemma
  • Gemma2
  • GLM (modular, so nothing to do I think)
  • gpt_neox and GPT2
  • Granite
  • Jamba
  • JetMoe
  • Mimi
  • Mistral
  • Mixtral
  • Mllama
  • Moshi
  • Nemotron
  • OPT
  • Phi
  • Phi3
  • PhiMoe
  • Qwen2
  • Qwen2Moe
  • Qwen2VL
  • StableLM
  • StarCoder2 -> modular, normally OK
  • Idefics1,2,3
  • Olmo
  • Olmo2
  • Siglip
  • Whisper

@ArthurZucker force-pushed the all-attention-refactor branch from 0dc9253 to d1aa9ce on December 12, 2024 at 13:49
@ArthurZucker (Collaborator, Author) commented on the new class GradientCheckpointLayer(torch.nn.Module):

This should help with kwargs as well
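
A minimal sketch of that idea (not the code from this PR; names and details are assumed): a wrapper module can route its call through non-reentrant torch checkpointing, which forwards keyword arguments to the wrapped layer.

import torch
from torch.utils.checkpoint import checkpoint


class GradientCheckpointLayer(torch.nn.Module):
    # Hypothetical sketch: wrap a decoder layer so that gradient checkpointing
    # also forwards keyword arguments (attention_mask, position_ids, ...).
    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        if self.training and torch.is_grad_enabled():
            # the non-reentrant variant of checkpoint accepts **kwargs
            return checkpoint(self.layer, *args, use_reentrant=False, **kwargs)
        return self.layer(*args, **kwargs)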

@Cyrilvallez force-pushed the all-attention-refactor branch from 8b56823 to ecd814b on December 16, 2024 at 11:28
@Cyrilvallez (Member)

BTW, maybe it would be safer to check {p.device for p in model.parameters()} | {p.device for p in model.buffers()}?

@BenjaminBossan (Member)

Thanks a lot @Cyrilvallez, I did not notice that. Indeed, your change fixes the failure.

BTW, maybe it would be safer to check {p.device for p in model.parameters()} | {p.device for p in model.buffers()}?

Thanks for the suggestion. In the original unit test this is just a sanity check; the proper testing comes further down. I removed that part for clarity.
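
For reference, the suggested sanity check as a minimal sketch (model stands for whatever model the test loads); it covers buffers such as rotary-embedding inv_freq tensors, not only parameters:

devices = {p.device for p in model.parameters()} | {b.device for b in model.buffers()}
assert len(devices) == 1, f"model is spread across devices: {devices}"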

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jan 7, 2025
See
huggingface/transformers#35235 (comment)
for context.

There has been a refactor in transformers that moved the rotary
embedding of Mistral (and probably other models) to the model level.
This caused the device map used in one of the tests to become incorrect.
This PR fixes the device map.

Note that this fix doesn't really have anything to do with prefix
tuning; the error occurred even before prefix tuning was used.
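
Purely as an illustration of the fix described above (module names assumed for a Mistral-style model after this refactor), a hand-written device map now has to place the model-level rotary embedding explicitly:

from transformers import AutoModelForCausalLM

device_map = {
    "model.embed_tokens": 0,
    "model.rotary_emb": 0,  # the rotary embedding now lives on the model, not inside each attention layer
    "model.layers": 0,
    "model.norm": 0,
    "lm_head": 1,
}
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map=device_map)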
@ArthurZucker (Collaborator, Author)

Yeah, sorry @BenjaminBossan, it's just that we don't use that attribute, so I'm not sure we're gonna add it back! The PR is breaking, as the name indicates! But I hope it was not too much trouble!

@BenjaminBossan (Member)

Thanks for letting me know, it should be an easy fix on the PEFT side.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jan 8, 2025
The changes in huggingface/transformers#35235
caused a couple of adaption prompt tests to fail. This PR fixes
these failures while maintaining compatibility with older transformers
versions.

Required changes:

- hidden_size attribute removed from the model, now config.hidden_size
- num_heads attribute removed from the model, now config.num_attention_heads
- forward now returns 2 outputs instead of 3; rewritten to be agnostic
  to the number of outputs (a compatibility shim of this kind is sketched below)
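
A hedged sketch of what such a compatibility shim can look like (helper names are placeholders, not the actual PEFT code):

def get_attention_geometry(attn_module, config):
    # Older transformers exposed hidden_size / num_heads on the attention
    # module itself; after this refactor they only live on the config.
    hidden_size = getattr(attn_module, "hidden_size", config.hidden_size)
    num_heads = getattr(attn_module, "num_heads", config.num_attention_heads)
    head_dim = getattr(attn_module, "head_dim", hidden_size // num_heads)
    return hidden_size, num_heads, head_dim


def first_output(outputs):
    # New attention forwards return (attn_output, attn_weights); older ones
    # also returned past_key_value. Only rely on the first element.
    return outputs[0] if isinstance(outputs, tuple) else outputs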
githubnemo pushed a commit to huggingface/peft that referenced this pull request Jan 10, 2025
BenjaminBossan added a commit to huggingface/peft that referenced this pull request Jan 10, 2025
@foreverpiano

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'MistralAttention' object has no attribute 'num_heads'

How can I fix this?

@ArthurZucker (Collaborator, Author)

Hey! You should try using the latest release of transformers! query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) is what's used now.
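
For anyone hitting the same AttributeError, a minimal sketch of the new pattern (the helper name here is illustrative, not part of the library): shapes are derived from head_dim alone, and the head count is read from the config.

import torch

def split_heads(q_proj_out: torch.Tensor, head_dim: int) -> torch.Tensor:
    # After the refactor the attention modules build the shape from head_dim
    # instead of a stored num_heads attribute:
    #   hidden_shape = (*input_shape, -1, self.head_dim)
    input_shape = q_proj_out.shape[:-1]
    hidden_shape = (*input_shape, -1, head_dim)
    return q_proj_out.view(hidden_shape).transpose(1, 2)

# external code that needs the head count should now read it from the config:
# num_heads = model.config.num_attention_heads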

@ArthurZucker (Collaborator, Author)

Is this by any chance related to AWQ or another package?

@foreverpiano commented Jan 13, 2025

Is there any doc on how to migrate from the previous version to this one, e.g. the variable definitions and alias changes?

@foreverpiano commented Jan 13, 2025

Have you tested performance on several benchmarks? I know the LongBench score on transformers v4.47 vs v4.36 varies a lot for Llama-3. Is it stable in this version?
I suggest adding some simple, small dataset tests.

@Cyrilvallez (Member)

Hey! Everything stays the same in terms of user experience and benchmark scores. If you used to hack into the different layer classes, however, things may have changed a bit. In that case you can simply go and check out the modeling code (which you presumably did to hack into it in the first place!)
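
If the hack was about customizing attention itself, one option after this refactor is the attention-function registry. A rough sketch, assuming the ALL_ATTENTION_FUNCTIONS registry and the (module, query, key, value, attention_mask, **kwargs) calling convention introduced here; double-check against the current modeling code before relying on it:

import torch
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import eager_attention_forward


def logged_attention(module, query, key, value, attention_mask, **kwargs):
    # wrap the stock eager implementation instead of patching a Layer class
    print(f"{module.__class__.__name__}: query shape {tuple(query.shape)}")
    return eager_attention_forward(module, query, key, value, attention_mask, **kwargs)


ALL_ATTENTION_FUNCTIONS["logged_eager"] = logged_attention

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # any Llama-architecture checkpoint
model.config._attn_implementation = "logged_eager"  # looked up at forward time
out = model(torch.ones(1, 4, dtype=torch.long))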

loadams added a commit to microsoft/DeepSpeed that referenced this pull request Jan 13, 2025
The breaking change in transformers is
huggingface/transformers#35235. Changes are needed
to unpin the nv-a6000 workflow.
@poedator (Contributor) commented Jan 14, 2025

My friends use a GPT2Model in production and want to compile it with StaticCache. With the maintainers' blessing, I would like to create a PR adding DynamicCache / StaticCache support to GPT2Model.
I am quite familiar with the Cache class; I have already coded some of it and made DynamicCache work.

Please let me know if there are any hidden obstacles in a Cache implementation for GPT2. Which tests should I run or add?
@ArthurZucker
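
To make the target concrete, a sketch of the usage this would enable, assuming GPT2 gains the same Cache support as the decoder-only models refactored here (it does not have it yet at this point in the thread):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("The attention refactor", return_tensors="pt")

# StaticCache pre-allocates the key/value tensors to a fixed length, which is
# what makes the decoding step friendly to torch.compile
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tok.decode(out[0], skip_special_tokens=True))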

@Rocketknight1 (Member)

cc @gante on that question!

@gante (Member) commented Jan 15, 2025

I've chatted with @poedator offline -- I couldn't think of any obstacle in particular, and suggested a) ensuring we leave a deprecation warning regarding the old cache format and b) using RUN_SLOW=1 py.test tests/models/gpt2/test_modeling_gpt2.py as a correctness check (gpt2 is fairly well tested, especially wrt text generation).
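
For point a), a rough sketch of the usual deprecation pattern (the helper name and wording are placeholders): keep accepting the old tuple-of-tuples format, warn, and convert it with DynamicCache.from_legacy_cache.

import warnings
from transformers import DynamicCache

def _normalize_cache(past_key_values):
    # Old format: a tuple of (key, value) tuples per layer; new format: a Cache object.
    if isinstance(past_key_values, tuple):
        warnings.warn(
            "Passing a tuple of tensors as past_key_values is deprecated; "
            "pass a Cache instance (e.g. DynamicCache) instead.",
            FutureWarning,
        )
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values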

@poedator (Contributor)

It looks like test_flash_attn_2_from_config is broken. It expects the attention layer to have FlashAttention in its class name:

if "FlashAttention" in module.__class__.__name__:

but after this refactoring the attention classes are named differently. Please fix or suspend the test.
@ArthurZucker
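
One possible fix, sketched under the assumption that the test only needs to confirm that Flash Attention 2 was actually dispatched: check the configured attention implementation instead of class names.

# instead of: if "FlashAttention" in module.__class__.__name__: ...
# where model is the model instantiated by the test:
assert model.config._attn_implementation == "flash_attention_2"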

@ArthurZucker (Collaborator, Author)

indeed gimme a min!
