
[Mamba2] Fix caching, slow path, and multi-gpu #35154

Merged · 16 commits merged into huggingface:main from fix-mamba2-caching on Dec 20, 2024

Conversation

@vasqu (Contributor) commented Dec 8, 2024

What does this PR do?

Kind of a follow-up to #34901 as there are some issues in the current code:

Fixes #33567
Fixes #34817

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@molbap @ArthurZucker

@vasqu (Contributor, Author) left a comment

Just some comments for clarification

Review threads on src/transformers/models/mamba2/modeling_mamba2.py (3 outdated, resolved)
Comment on lines +106 to +108
# Only left padding is valid
attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
attention_mask[0, :1] = 0
@vasqu (Contributor, Author):

Added a mask; maybe useful for some other tests as well.

Contributor:

Alright, is it intended that it is only masked out for the first element of the batch?

@vasqu (Contributor, Author):

Tbh, that was pretty willy-nilly on my part; it could definitely be changed, I just wanted to debug and see if things work.
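
For illustration, a tiny standalone sketch of masking left padding for every batch element rather than only the first one; the pad lengths are made up:

```python
import torch

batch_size, seq_length = 2, 6
# Only left padding is valid for Mamba2, so pad at the start of each sequence.
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
pad_lengths = [1, 2]  # made-up per-sample padding, one entry per batch element
for i, pad in enumerate(pad_lengths):
    attention_mask[i, :pad] = 0  # 0 marks padded positions
```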

Review thread on tests/models/mamba2/test_modeling_mamba2.py (resolved)
@vasqu (Contributor, Author) left a comment

Some more comments for the cache

@vasqu (Contributor, Author) commented Dec 8, 2024

Integration tests will probably need an update, but I don't have a GPU for the 7B model atm.

Edit: If you could update these integration tests, that'd be great :D especially since I'm on vacay very soon

@molbap (Contributor) commented Dec 9, 2024

Hey @vasqu thanks! Taking a look in a min

@molbap (Contributor) left a comment

Hey @vasqu, thanks a bunch! Left a couple of questions/comments, but it looks good.

Review thread on src/transformers/models/mamba2/modeling_mamba2.py (outdated, resolved)
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
# 2. Convolution sequence transformation
if cache_params is not None and cache_position is not None and cache_position[0] > 0:
Contributor:

Currently this will break torch compile, FWIW.

@vasqu (Contributor, Author):

I think the triton kernels themselves are not easy to compile atm either way, but this should definitely be handled properly in the future. FYI, you would need to register fake ops for torch to make it work properly, which would entail some separate mamba2 utils for the kernel - see https://github.com/facebookresearch/lingua/tree/main/apps/mamba/component
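
For reference, a minimal sketch of that fake-op registration pattern with torch.library (PyTorch >= 2.4); the op name and the stand-in body below are hypothetical, and a real implementation would dispatch to the Triton kernel instead:

```python
import torch
import torch.nn.functional as F

# Register a custom op whose real body would launch the Triton kernel.
@torch.library.custom_op("mamba2::chunk_scan", mutates_args=())
def chunk_scan(hidden_states: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
    # Stand-in computation; in practice this would call the Triton kernel.
    return hidden_states * F.softplus(dt)

# Fake (meta) implementation: only describes output shape/dtype so torch.compile can trace through it.
@chunk_scan.register_fake
def _(hidden_states, dt):
    return torch.empty_like(hidden_states)

compiled = torch.compile(lambda x, dt: chunk_scan(x, dt))
out = compiled(torch.randn(2, 4), torch.randn(2, 4))
```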

Contributor:

Yeah, not actionable immediately, but it would be nice to have in the near future! Thanks.

@vasqu (Contributor, Author):

Definitely! Would love to see it :)

        out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # Gated MLP's linear projection
        projected_states = self.in_proj(input_states.squeeze(1))
Contributor:

I'm not sure about this - it seems improvable, yes, but the squeeze is a no-op unless seq_len == 1, so indeed only in the caching situation. In that case we end up with a [batch_size, H] tensor instead of a [batch_size, seq_len, H] tensor. Then, since we're splitting it on the last dimension, it should be fine.
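
A quick standalone illustration of that point (tensor sizes are made up):

```python
import torch

x = torch.randn(2, 5, 8)       # [batch_size, seq_len, H] with seq_len > 1
print(x.squeeze(1).shape)      # torch.Size([2, 5, 8]) -> squeeze(1) is a no-op here

y = torch.randn(2, 1, 8)       # cached decoding step, seq_len == 1
print(y.squeeze(1).shape)      # torch.Size([2, 8]) -> the seq dimension is dropped

# Splitting on the last dimension behaves the same in both cases.
a, b = y.squeeze(1).split([3, 5], dim=-1)
print(a.shape, b.shape)        # torch.Size([2, 3]) torch.Size([2, 5])
```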

Review thread on src/transformers/models/mamba2/modeling_mamba2.py (outdated, resolved)
Review thread on tests/models/mamba2/test_modeling_mamba2.py (resolved)
input_states = remove_padding_influence(input_states, attention_mask)
projected_states = self.in_proj(input_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
Contributor:

Nice, now it's aligned with the cuda kernel forward in naming. TBH the whole split is the same for cuda and torch, so it could be factored out?

@vasqu (Contributor, Author):

Sounds like a refactor :D I'd leave this to a separate PR and focus on making things work first.

Contributor:

yeah for sure!

@molbap added the State space models and run-slow labels on Dec 9, 2024
@molbap (Contributor) commented Dec 9, 2024

Also, I added the slow label - feel free to push a commit with the message "[run-slow] mamba2" so we can trigger the slow CI! That way we make sure multi-gpu is indeed fixed.

@vasqu (Contributor, Author) commented Dec 9, 2024

@molbap Yup, added an empty commit - will get to the comments/review a bit later 🫡

(I expect there could be some failures on the integration tests; not sure, let's see.)

@vasqu (Contributor, Author) commented Dec 9, 2024

Attempt 2 at multi-GPU, at least a different error :p

@vasqu (Contributor, Author) commented Dec 9, 2024

Things that remain:

  • Multi-GPU, sigh
  • Supersede the slow path fix (with the same tests included)?
  • Compile compatibility (future)
  • Refactor some stuff (future)

Otherwise, ready to go @molbap

Edit: Hub seems to have some unrelated issues

@vasqu changed the title from "[Mamba2] Fix Cache and several other small issues" to "[Mamba2] Fix caching, slow path, and multi-gpu" on Dec 10, 2024
@vasqu mentioned this pull request on Dec 10, 2024
@vasqu (Contributor, Author) commented Dec 10, 2024

Could you trigger the slow runs, @molbap? I added the slow path fix here and closed the old PR. The only thing that didn't work before was the multi-GPU caching, so let's see how it goes this time.

@molbap (Contributor) left a comment

Seems to run on multi-GPU! Thanks a bunch. @ArthurZucker, it's alright for me once all tests are green (hub test failures are unrelated).

@vasqu (Contributor, Author) commented Dec 11, 2024

I'm on vacay for a while, so feel free to update / correct some stuff if you want to. Will be back after xmas, so I guess early merry xmas :)

@molbap (Contributor) commented Dec 11, 2024

Got it, will update if needed! Enjoy the holidays @vasqu !

@HanGuo97 commented:

Thanks for the effort @vasqu! Just wanted to quickly check in to understand the status of the current HF implementation of Mamba-2.

cuda_kernels_forward

  • The current implementation is correct, since most of the proposed changes are style changes.

torch_forward

  • The current implementation is incorrect, and the PR fixes two things:
  1. Reduction dimension.
  2. dt clamping (illustrated after this list).
  • There are some updates to masking, but I'm not sure when/whether these are used.

Are these statements correct? Thanks in advance!
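
(For context on the second point, a rough standalone illustration of what dt clamping refers to; the names, shapes, and bounds below are assumptions, not the exact code in this PR:)

```python
import torch

dt = torch.rand(2, 8)                   # per-head time steps, made-up shape
time_step_limit = (0.0, float("inf"))   # assumed (min, max) bounds
dt = torch.clamp(dt, *time_step_limit)  # bound dt before it enters the discretization
```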

@vasqu (Contributor, Author) commented Dec 11, 2024

Hey @HanGuo97, you understood most of the changes imo. I guess the only things I'd add would be the caching issues specifically:

  • the cuda path wrongly initializes the conv cache
  • the torch path had some issues with dimensions when caching / inferring (I think you mean this as well)
  • overall cleanup of how inference is prepared, as there was a mix-up in regards to non-cached inference
  • multi-gpu device management caused issues (when interacting with the cache)

The masks were there before but got a bit of a refactor tbh. They're important to ensure that batched inference works correctly - that's something that is incorrect in the original mamba(2) implementation (unless you use varlen implementations, i.e. those that use cu_seq_lens).
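
For illustration, a minimal sketch of the masking idea (the helper name here is made up and not necessarily what the PR implements):

```python
import torch

def mask_padded_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] with 0 on left padding.
    # Zero out padded positions so they cannot leak into the convolution / SSM recurrence.
    if attention_mask is not None and attention_mask.shape[0] > 1 and attention_mask.shape[1] > 1:
        hidden_states = hidden_states * attention_mask[:, :, None].to(hidden_states.dtype)
    return hidden_states

x = torch.randn(2, 4, 8)
mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])  # first sequence is left-padded by one token
x = mask_padded_states(x, mask)
```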

Lmk if something's unclear!

@HanGuo97 commented:

Thank you for the clarifications!

Is the current PR mostly complete, or can we expect more significant fixes? I’m asking because I’m refactoring the implementation for other reasons and prefer to base it on a more stable version.

On a separate note, do you have any suggestions for testing the implementation’s correctness? I was considering comparing it with the original Mamba-2 implementation, but I’m unsure if I might overlook something.

@vasqu (Contributor, Author) commented Dec 11, 2024

No problem! I don't think that there will be any more significant changes. So, this should be a good ground to work on.

The original implementation is well made as is. I only have my gripes when using batches. Depending on your refactor, it might be easier to look at earlier versions without cu_seq as it adds a bit of complexity. Otherwise, kind of self-promo, but I've also written a mamba2 version back before it came to transformers - maybe that can help (https://github.com/vasqu/mamba2-torch but look into the open PR if you want correct batched inference).

@HanGuo97 commented:

Got it, thanks again for the explanations (and enjoy the holidays)!

@vasqu (Contributor, Author) commented Dec 18, 2024

Any reason to withhold this PR? It seems like Bamba (#34982) has been merged and silently merged this PR's fixes (856cb3a) without a (commit) reference / with dubious credit (i.e. it adopted this "refactor" but claims to have fixed the bugs itself?).

Failing tests are unrelated (internal hub failures).

@molbap @ArthurZucker

@molbap (Contributor) commented Dec 19, 2024

Hi @vasqu, this should be merged soon! In terms of precedence, I'll add a "From PR ... by @vasqu" note in a comment so future users understand where the fixes come from, not a worry.
ping @ArthurZucker for merge!

@molbap requested a review from ArthurZucker on December 19, 2024 at 08:30
@vasqu (Contributor, Author) commented Dec 19, 2024

Hey 👋 I don't need direct credit, I just think that the list given in the docstring is misleading:

The are a few differences between this and Mamba2Mixer:
- The variable use_precomputed_states is slightly different due to the HybridCache structure
- There's a few non-obvious bugs fixed with batching in the slow path that exist in main
- Some extra variables that our layer doesn't need have been removed
- We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged

The changes are mainly because of the cache + dropping some attributes.

@ArthurZucker (Collaborator) left a comment

Thanks a lot @vasqu 😉
Your last comment is completely aligned with our philosophy: if indeed bamba is now the same, we shall add bamba with modular, isolating the differences if there are any left! cc @molbap on this! Merry Christmas as well!

Comment on lines +316 to +322
d_mlp = (
projected_states.shape[-1]
- 2 * self.intermediate_size
- 2 * self.n_groups * self.ssm_state_size
- self.num_heads
) // 2

Collaborator:

this is less readable, but I mean mamba in general is hard to read 😄
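
For anyone puzzling over that expression, a quick standalone sanity check of the arithmetic; the sizes are made up and the split layout is read off the quoted hunks above:

```python
# Made-up sizes, just to sanity-check the d_mlp arithmetic.
intermediate_size = 64
n_groups = 2
ssm_state_size = 16
num_heads = 4
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size  # part that goes through the causal conv

proj_dim = 2 * intermediate_size + 2 * n_groups * ssm_state_size + num_heads  # in_proj output width
d_mlp = (proj_dim - 2 * intermediate_size - 2 * n_groups * ssm_state_size - num_heads) // 2
print(d_mlp)  # 0 with these sizes, matching the `_, _` discard in the split

# the split consumes the whole projection: [d_mlp, d_mlp, intermediate_size, conv_dim, num_heads]
assert 2 * d_mlp + intermediate_size + conv_dim + num_heads == proj_dim
```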

@ArthurZucker merged commit 5a2aedc into huggingface:main on Dec 20, 2024
16 checks passed
@vasqu deleted the fix-mamba2-caching branch on December 20, 2024 at 15:17
Labels: run-slow, State space models