[Mamba2] Fix caching, slow path, and multi-gpu #35154
Conversation
Just some comments for clarification
# Only left padding is valid
attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
attention_mask[0, :1] = 0
Added a mask, maybe for some other tests as well.
Alright, is it intended that it is only masked out for the first element of the batch?
Tbh, that was pretty willy-nilly on my part; it could definitely be changed, I just wanted to debug and see if stuff works
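For illustration, a left-padding mask that zeroes out positions for every row of the batch could look like this minimal sketch (hypothetical sizes, not the test code itself):

```python
import torch

batch_size, seq_length, pad_len = 2, 8, 3

# Left padding only: mask out the first `pad_len` positions for every batch element,
# not just the first row.
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[:, :pad_len] = 0
```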
…generate (gives total ids + mask at each step)
Some more comments for the cache
Integration tests will probably need an update but I don't have a GPU for the 7B atm. Edit: If you could update these integration tests, then gladly :D especially since I'm on vacay very soon
Hey @vasqu thanks! Taking a look in a min
Hey @vasqu thanks a bunch! Left a couple questions/comments but looks good
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
# 2. Convolution sequence transformation
if cache_params is not None and cache_position is not None and cache_position[0] > 0:
This will currently break torch.compile, FWIW
I think the triton kernels themselves are not easy to compile atm either way, but this should definitely be handled properly in the future. FYI, you would need to register fake ops for torch to make it work properly, which would entail some separate mamba2 utils for the kernel - see https://github.com/facebookresearch/lingua/tree/main/apps/mamba/component
Yeah, not actionable immediately but would be nice to have in the near future! Thanks
Definitely! Would love to see it :)
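For reference, a minimal sketch of the fake-op registration idea mentioned above (assumes torch >= 2.4; the namespace, op name, and body are hypothetical placeholders, not part of this PR):

```python
import torch

# Wrap the (Triton) kernel call in a custom op so torch.compile can trace through it.
@torch.library.custom_op("mamba2_utils::ssm_scan", mutates_args=())
def ssm_scan(hidden_states: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real version would dispatch to the Triton selective-scan kernel.
    return hidden_states.clone()

@ssm_scan.register_fake
def _(hidden_states: torch.Tensor) -> torch.Tensor:
    # Fake/meta implementation: torch.compile only needs the output shape and dtype here.
    return torch.empty_like(hidden_states)
```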
out = self.out_proj(scan_output)
return out

# fmt: off
def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
    batch_size, seq_len, _ = input_states.shape
    dtype = input_states.dtype
    # Gated MLP's linear projection
    projected_states = self.in_proj(input_states.squeeze(1))
I'm not sure about this - seems improvable, yes, but the squeeze is a no-op unless seq_len == 1, so indeed only in the caching situation. In that case we end up with a [batch_size, H] tensor instead of a [batch_size, seq_len, H] tensor. Then we split it on the last dimension, so it should be fine.
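A tiny standalone check of the shape behaviour discussed here (illustrative only):

```python
import torch

x = torch.randn(4, 1, 32)   # (batch_size, seq_len=1, H) during cached decoding
print(x.squeeze(1).shape)   # torch.Size([4, 32]): the length-1 dim is dropped

y = torch.randn(4, 7, 32)   # seq_len > 1 (prefill)
print(y.squeeze(1).shape)   # torch.Size([4, 7, 32]): squeeze is a no-op here

# Splitting on the last dimension behaves the same in both cases.
```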
input_states = remove_padding_influence(input_states, attention_mask)
projected_states = self.in_proj(input_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
Nice, now it's aligned with the cuda kernel forward in naming. TBH the whole split is the same for cuda and torch, so it could be factored out?
Sounds like a refactor :D I'd leave this to a separate PR and focus on making things work first
yeah for sure!
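If that refactor ever happens, a shared helper could look roughly like this (hypothetical name and signature, not code from this PR):

```python
def split_projection(projected_states, intermediate_size, n_groups, ssm_state_size, num_heads):
    # Hypothetical helper shared between the cuda-kernel and torch forward paths.
    d_mlp = (
        projected_states.shape[-1]
        - 2 * intermediate_size
        - 2 * n_groups * ssm_state_size
        - num_heads
    ) // 2
    conv_dim = intermediate_size + 2 * n_groups * ssm_state_size
    # Returns (extra_mlp_1, extra_mlp_2, gate, hidden_states_B_C, dt)
    return projected_states.split([d_mlp, d_mlp, intermediate_size, conv_dim, num_heads], dim=-1)
```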
also I added the slow label - feel free to launch a commit with message "[run-slow] mamba2" so we can trigger the slow CI! That way we make sure multi-gpu is indeed fixed
@molbap Yup, added an empty commit - will get to the comments/review a bit later 🫡 (I'd expect some failures on the integration tests, not sure, let's see)
Attempt 2 at multi gpu, at least a different error :p
Things that remain:
Otherwise, ready to go @molbap. Edit: Hub seems to have some unrelated issues
changed the title from "[Mamba2] Fix Cache and several other small issues" to "[Mamba2] Fix caching, slow path, and multi-gpu"
Could you trigger the slow runs @molbap? I added the slow path fix here and closed the old PR. The only thing that didn't work before was the multi-GPU caching, so let's see how it goes this time.
Seems to run on multi-GPU! Thanks a bunch @ArthurZucker, alright for me once all tests are green (hub tests unrelated)
I'm on vacay for a while, so feel free to update / correct some stuff if you want to. Will be back after xmas, so I guess early merry xmas :)
Got it, will update if needed! Enjoy the holidays @vasqu!
Thanks for the effort @vasqu! Just wanted to quickly check in to understand the status of the current HF implementation of Mamba-2.
Hey @HanGuo97, you understood most of the changes imo. I guess the only things I'd add would be the caching issues specifically:
The masks were there before, but got a bit of a refactor tbh. They're important to ensure that batched inference works correctly - that's something that is incorrect in the original mamba(2) implementation (unless you use the varlen implementations, i.e. those that use cu_seq_lens). Lmk if something's unclear!
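To make the padding point concrete, a minimal sketch of the idea (hypothetical helper, not the exact code in the modeling file): padded positions are zeroed out so they cannot leak into the convolution / SSM state during batched inference.

```python
import torch

def mask_padded_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) with 0 = padding
    if attention_mask is not None and (attention_mask == 0).any():
        hidden_states = hidden_states * attention_mask[..., None].to(hidden_states.dtype)
    return hidden_states
```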
Thank you for the clarifications! Is the current PR mostly complete, or can we expect more significant fixes? I'm asking because I'm refactoring the implementation for other reasons and would prefer to base it on a more stable version. On a separate note, do you have any suggestions for testing the implementation's correctness? I was considering comparing it with the original Mamba-2 implementation, but I'm unsure if I might overlook something.
No problem! I don't think there will be any more significant changes, so this should be good ground to work on. The original implementation is well made as is; I only have my gripes when using batches. Depending on your refactor, it might be easier to look at earlier versions without cu_seq as it adds a bit of complexity. Otherwise, kind of self-promo, but I've also written a mamba2 version back before it came to transformers - maybe that can help (https://github.com/vasqu/mamba2-torch but look into the open PR if you want correct batched inference).
Got it, thanks again for the explanations (and enjoy the holidays)!
Hi @vasqu, this should be merged soon! In terms of precedence I'll add an
Hey 👋 I don't need direct credit, I just think that the list given in the docstring is misleading: transformers/src/transformers/models/bamba/modular_bamba.py, lines 214 to 218 (at 667ed56)
The changes are mainly because of the cache + dropping some attributes.
d_mlp = (
    projected_states.shape[-1]
    - 2 * self.intermediate_size
    - 2 * self.n_groups * self.ssm_state_size
    - self.num_heads
) // 2
this is less readable, but I mean mamba in general is hard to read 😄
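As a rough illustration of what the expression computes (hypothetical config values, not taken from the PR):

```python
# Hypothetical Mamba2-style dimensions
intermediate_size = 4096
n_groups = 1
ssm_state_size = 128
num_heads = 64

# in_proj packs [extra_mlp, extra_mlp, gate, hidden_states_B_C, dt] along the last dim.
proj_width = 2 * intermediate_size + 2 * n_groups * ssm_state_size + num_heads
d_mlp = (proj_width - 2 * intermediate_size - 2 * n_groups * ssm_state_size - num_heads) // 2
print(d_mlp)  # 0 for this config: the two "extra MLP" chunks of the split are empty
```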
What does this PR do?
Kind of a follow-up to [Mamba2] Fix slow path (#34901) as there are some issues in the current code:

Fixes #33567
Fixes #34817
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@molbap @ArthurZucker