Add Zamba2 #34517

Merged: 90 commits into huggingface:main · Jan 27, 2025

Conversation

@pglorio (Contributor) commented Oct 30, 2024

What does this PR do?

This PR adds support for the Zamba2 architecture created by Zyphra Technologies.

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

@pglorio marked this pull request as draft on October 30, 2024 at 17:57
@pglorio (Contributor Author) commented Nov 11, 2024

Hey @Arthur,

Thank you again for your help in getting Zamba2 into transformers! The PR is now finally ready to be reviewed. I added the documentation and all unit tests pass, including slow tests.

A few remarks, mostly related to modular transformers:

  1. To generate the modeling and configuration files, I used utils/modular_model_converter.py from a previous commit, because the most recent version of the script (which followed a large refactoring) produces an error that I was not able to fix:
Converting src/transformers/models/zamba2/modular_zamba2.py to a single model single file format
Traceback (most recent call last):
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1510, in <module>
    converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1447, in convert_modular_file
    for file, module in create_modules(cst_transformers).items():
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1387, in create_modules
    nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1337, in get_class_node_and_dependencies
    new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1283, in check_dependencies_and_create_import_node
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1283, in <setcomp>
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
KeyError: 'Zamba2Config'

I carefully compared Zamba2Config with the classes of other models that also use modular (such as Gemma2Config) and they appear to have a consistent format. Relatedly, the utils/modular_model_converter.py in the current PR (path) is the version from the previous commit mentioned above.

  2. After running utils/modular_model_converter.py, the generated modeling and configuration files contain unintended code that I had to update. All these modifications are in this commit. In particular, the generated modeling file contains Zamba2DynamicCache, which is the correct cache for Zamba2, as well as HybridMambaAttentionDynamicCache, which is the cache of Zamba and is not relevant to Zamba2, so I deleted HybridMambaAttentionDynamicCache and the related references.

  3. I ran make fixup and all zamba-related tests pass, with the exception of python utils/check_modular_conversion.py. That check fails because of the modifications mentioned in the previous point.

  4. I slightly edited Zamba2MambaMixer compared to the original Mamba2Mixer of mamba2. The main difference is that I added these lines, which was necessary to appropriately process the mamba2 cache (note that this step already existed in the torch forward in these lines).

Looking forward to your feedback. Thanks so much!

@pglorio (Contributor Author) commented Jan 17, 2025

Hello @Cyrilvallez, I ran all model tests on two GPUs and after a couple of minor fixes everything appears to work now. I'm skipping this test as it gives an error related to mamba2 kernels. I indeed verified that mamba2 skips that test here.
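
For reference, a minimal sketch of what such a skip typically looks like in a transformers test file; the class and test names below are placeholders, not the actual skipped test:

import unittest

class Zamba2ModelTestSketch(unittest.TestCase):  # hypothetical stand-in for the real test class
    @unittest.skip(reason="Fails due to mamba2 custom-kernel requirements; mamba2 skips this test as well.")
    def test_requiring_mamba2_kernels(self):  # hypothetical test name
        ...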

Separately, when running utils/check_modular_conversion.py I get the following error:

Differences found between the generated code and src/transformers/models/zamba2/modeling_zamba2.py:

   1 --- src/transformers/models/zamba2/modeling_zamba2.py_generated
   2 +++ src/transformers/models/zamba2/modeling_zamba2.py
   3 @@ -313,6 +313,13 @@
   4      return attn_output, attn_weights
   5  
   6  
   7 +def rotate_half(x):
   8 +    """Rotates half the hidden dims of the input."""
   9 +    x1 = x[..., : x.shape[-1] // 2]
  10 +    x2 = x[..., x.shape[-1] // 2 :]
  11 +    return torch.cat((-x2, x1), dim=-1)
  12 +
  13 +
  14  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  15      """Applies Rotary Position Embedding to the query and key tensors.
  16  
  17 @@ -338,13 +345,6 @@
  18      q_embed = (q * cos) + (rotate_half(q) * sin)
  19      k_embed = (k * cos) + (rotate_half(k) * sin)
  20      return q_embed, k_embed
  21 -
  22 -
  23 -def rotate_half(x):
  24 -    """Rotates half the hidden dims of the input."""
  25 -    x1 = x[..., : x.shape[-1] // 2]
  26 -    x2 = x[..., x.shape[-1] // 2 :]
  27 -    return torch.cat((-x2, x1), dim=-1)

which I was not getting before, even though this part is identical.

@Cyrilvallez (Member) left a comment:

LGTM! Let's just wait for #35795 which will get rid of the CI failure for modular conversion! Sorry about that, and thanks for being so patient with us 🙏🙏🤗
Great work!

@pglorio (Contributor Author) commented Jan 21, 2025

Awesome, sounds good!

@ArthurZucker (Collaborator) left a comment:

Thanks! A few comments about the code paths and regex init, and it should be good!

Resolved review threads (outdated):
  • docs/source/en/model_doc/zamba2.md
  • src/transformers/models/zamba2/modular_zamba2.py (5 threads)

"shared_transformer.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
if self.config.use_shared_mlp_adapter:
Collaborator:

Same comment about the code path: which models have this set to true / false?

Collaborator:

  • tied keys support regex patterns, so we should never have to add all of them manually like this

Contributor Author:

All checkpoints have config.use_shared_mlp_adapter set to True. We have internal checkpoints with this flag set to False, which might be released in the future.

Collaborator:

I'd rather we add a new model when they are released than have two code paths 😉 it's two different models for us!

Contributor Author:

Replaced tied keys with regex patterns here.
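
For illustration, a hedged sketch of how a regex entry can replace the manually built list of tied keys; the class name and module path below are assumptions for the example, not necessarily the exact Zamba2 naming:

# Illustrative only: one regex pattern matches the tied norm weight of every
# shared block, instead of appending a "prefix + key" string per block.
class Zamba2ModelSketch:  # hypothetical stand-in for the real model class
    _tied_weights_keys = [r"shared_transformer\.pre_ff_layernorm\.weight"]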

Contributor Author:

> I'd rather we add a new model when they are released than have two code paths 😉 it's two different models for us!

Sounds good, I got rid of config.use_shared_mlp_adapter here.

, dtype=torch.float32) # fmt: skip

torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
Collaborator:

It's missing a test on CPU with the slow forward!

Contributor Author:

Could you please say more about this?

Contributor Author:

Done here for both test_simple_generate and test_simple_batched_generate_with_padding.

test_simple_generate passes on CPU straightforwardly. test_simple_batched_generate_with_padding marginally fails on CPU for one of the two output logits in the batch (disagreement on 2 out of 40 logits):

        torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
>       torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
E       AssertionError: Tensor-likes are not close!
E       
E       Mismatched elements: 2 / 40 (5.0%)
E       Greatest absolute difference: 0.009563922882080078 at index (12,) (up to 0.001 allowed)
E       Greatest relative difference: 0.030748309567570686 at index (0,) (up to 0.001 allowed)

Given that this is a 1.2B-parameter model, it's not so surprising to find occasional small discrepancies in the forward pass when running a model of this size on CPU instead of GPU. I relaxed the absolute tolerance when the test runs on CPU here: atol=1e-3 -> atol=6e-3 if torch_device == "cpu" else 1e-3.
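
Concretely, the updated assertion for the failing check reads roughly as follows (torch_device is imported from transformers.testing_utils in the test file):

# Relax the absolute tolerance on CPU only; on GPU it stays at 1e-3.
atol = 6e-3 if torch_device == "cpu" else 1e-3
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=atol)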

Collaborator:

Yep, sounds good, thanks a lot for checking this!

@pglorio (Contributor Author) commented Jan 24, 2025

Thank you @ArthurZucker! I think all your comments have been addressed. All zamba-related tests appear to pass!

@ArthurZucker (Collaborator) left a comment:

Thanks! Only a small comment left regarding code paths and then it's good to go!

Resolved review thread (outdated): docs/source/en/model_doc/zamba2.md

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if self.config.use_shared_attention_adapter:
Collaborator:

I don't know if I asked already, but similarly, is this true / false for the released checkpoints?

Contributor Author:

This is true for some of the released checkpoints and false for other released checkpoints.

key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)

if self.config.use_mem_rope:
Collaborator:

Same comment, let's weed out the final bits that are not part of the released checkpoints!

Contributor Author:

This is also true for some of the released checkpoints and false for others.

Comment on lines +963 to +967
if config.use_mem_rope:
    if config.use_long_context:
        logger.warning_once(
            "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`."
        )
Collaborator:

same comment here!

@pglorio (Contributor Author) commented Jan 24, 2025:

Same for this: it is a flag that rescales rope_theta and improves the model's performance on long-context tasks. It is specific to the 7B checkpoint, so it can be either true or false depending on the checkpoint.
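
For context, a purely hypothetical sketch of the kind of rescaling the use_long_context warning above refers to; the scaling factor and extended context length are placeholders, not the actual Zamba2 values:

# Hypothetical illustration of use_long_context: rescale rope_theta and extend
# max_position_embeddings. The numbers below are placeholders only.
if config.use_mem_rope and config.use_long_context:
    config.rope_theta = config.rope_theta * 8.0        # placeholder scaling factor
    config.max_position_embeddings = 16384             # placeholder extended length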

@pglorio (Contributor Author) commented Jan 24, 2025

Thanks @ArthurZucker, I replied to your comments above.

@ArthurZucker marked this pull request as ready for review on January 27, 2025 at 09:26
@ArthurZucker merged commit 33cb1f7 into huggingface:main on Jan 27, 2025
23 checks passed