Add Zamba2 #34517
Conversation
Rebase zamba2
Hey @Arthur, thank you again for your help in getting Zamba2 into transformers. A few remarks, mostly related to the rebase: I carefully compared the outputs before and after the rebase. Looking forward to your feedback. Thanks so much!
rebase on upstream
Hello @Cyrilvallez, I ran all model tests on two GPUs and, after a couple of minor fixes, everything appears to work now. I'm skipping this test as it gives an error related to mamba2 kernels; I verified that mamba2 also skips that test here. Separately, when running the tests I now get an error which I was not getting before, even though this part was identical.
LGTM! Let's just wait for #35795 which will get rid of the CI failure for modular conversion! Sorry about that, and thanks for being so patient with us 🙏🙏🤗
Great work!
Awesome, sounds good!
Thanks! A few comments about the code paths and the regex init, and it should be good!
"shared_transformer.pre_ff_layernorm.weight", | ||
] | ||
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] | ||
if self.config.use_shared_mlp_adapter: |
Same comment about the code path: which models have this set to true / false?
- tied keys support regex patterns, we should never have to add all of them manually like this
All checkpoints have config.use_shared_mlp_adapter set to True. We have internal checkpoints with this flag set to False, which might be released in the future.
I'd rather we add a new model when they are released than have two code paths 😉 it's two different models for us!
replaced tied keys with regex patterns here
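For illustration, a minimal sketch of how a regex pattern can stand in for the manually enumerated tied keys; the pattern and the helper function below are assumptions for exposition, not the exact code in this PR.

import re

# Illustrative only: a single pattern can replace a hand-written list of tied keys
# such as "shared_transformer.pre_ff_layernorm.weight".
TIED_WEIGHTS_PATTERN = r"shared_transformer\..*\.weight"

def is_tied_weight(param_name: str) -> bool:
    # A parameter counts as tied if its full name matches the pattern.
    return re.fullmatch(TIED_WEIGHTS_PATTERN, param_name) is not None

print(is_tied_weight("shared_transformer.pre_ff_layernorm.weight"))  # True
print(is_tied_weight("lm_head.weight"))                              # False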
I'd rather we add a new model when they are released than have two code paths 😉 it's two different models for us!

Sounds good, got rid of config.use_shared_mlp_adapter here.
, dtype=torch.float32)  # fmt: skip

torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
It's missing a test on CPU with the slow forward!
could you please say more about this?
Done here for both test_simple_generate and test_simple_batched_generate_with_padding. test_simple_generate passes the CPU test straightforwardly. test_simple_batched_generate_with_padding marginally doesn't pass on CPU for one of the two output logits in the batch (disagreement on 2 out of 40 logits):
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
> torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 2 / 40 (5.0%)
E Greatest absolute difference: 0.009563922882080078 at index (12,) (up to 0.001 allowed)
E Greatest relative difference: 0.030748309567570686 at index (0,) (up to 0.001 allowed)
Given this is a 1.2B parameter model, it's not so surprising to find occasional small discrepancies in the forward pass when running a model of this size on CPU instead of GPU. I updated the absolute tolerance when the test is run on CPU: atol=1e-3 -> atol=6e-3 if torch_device == "cpu" else 1e-3.
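For concreteness, a small self-contained sketch of the device-dependent tolerance; the tensors and the torch_device variable here are placeholders, not the actual test values.

import torch

# Placeholder device flag; in the real test suite this comes from the testing utilities.
torch_device = "cpu"

# Loosen the absolute tolerance only on CPU, where the slow forward accumulates
# slightly larger numerical error on a model of this size.
atol = 6e-3 if torch_device == "cpu" else 1e-3

expected = torch.tensor([0.1234, -0.5678])  # placeholder values, not the real logits
actual = expected + 0.004                   # within 6e-3, but outside 1e-3

torch.testing.assert_close(actual, expected, rtol=1e-3, atol=atol)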
Yep, sounds good. Thanks a lot for checking this!
Co-authored-by: Arthur <[email protected]>
This reverts commit 9007a52.
Thank you @ArthurZucker! I think all your comments have been addressed. All zamba-related tests appear to pass!
Thanks! Only a small comment left regarding code paths and good to go!
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if self.config.use_shared_attention_adapter:
I don't know if I asked already, but similarly, is this true / false for the released checkpoint?
This is true for some of the released checkpoints and false for others.
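For readers following along, a rough sketch of how a flag like use_shared_attention_adapter can gate an optional low-rank adapter on top of the projections; the module, the rank, and the adapter form below are illustrative assumptions, not the exact Zamba2 implementation.

import torch
from torch import nn

class ProjectionWithOptionalAdapter(nn.Module):
    # Hypothetical module: a base projection plus an optional low-rank adapter,
    # enabled by a config-style flag (mirroring use_shared_attention_adapter).
    def __init__(self, hidden_size: int, use_adapter: bool, adapter_rank: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.use_adapter = use_adapter
        if use_adapter:
            self.adapter_down = nn.Linear(hidden_size, adapter_rank, bias=False)
            self.adapter_up = nn.Linear(adapter_rank, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.proj(hidden_states)
        if self.use_adapter:
            # The adapter output is added on top of the base projection.
            out = out + self.adapter_up(self.adapter_down(hidden_states))
        return out

x = torch.randn(1, 4, 64)
print(ProjectionWithOptionalAdapter(64, use_adapter=True)(x).shape)  # torch.Size([1, 4, 64])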
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)

if self.config.use_mem_rope:
same comment, let's weed out the final bits that are not part of the released checkpoint!
This is also true for some of the released checkpoints and false for others.
if config.use_mem_rope:
    if config.use_long_context:
        logger.warning_once(
            "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`."
        )
same comment here!
Same for this: this flag changes theta and increases the model's performance at long-context tasks. It is specific to the 7B checkpoint, so it can be either true or false depending on the checkpoint.
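As a rough illustration of what such a flag can do (the values and the scaling rule below are assumptions for exposition, not the ones used by the 7B checkpoint):

# Hypothetical config values for illustration only.
rope_theta = 10000.0
max_position_embeddings = 4096
use_long_context = True

if use_long_context:
    # Assumed behaviour: extend the context window and rescale rope_theta so that
    # rotary embeddings stay well-behaved at the longer sequence lengths.
    scaling_factor = 4  # e.g. 4096 -> 16384; illustrative, not the checkpoint value
    max_position_embeddings *= scaling_factor
    rope_theta *= scaling_factor ** 2  # simple power-law rescaling (assumption)

print(rope_theta, max_position_embeddings)  # 160000.0 16384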
Co-authored-by: Arthur <[email protected]>
Thanks @ArthurZucker, I replied to your comments above.
What does this PR do?
Please include support for the Zamba2 architecture created by Zyphra Technologies.
Who can review?
@ArthurZucker