Add Zamba2 (#34517)
* First commit

* Finish model implementation

* First commit

* Finish model implementation

* Register zamba2

* generated modeling and configuration

* generated modeling and configuration

* added hybrid cache

* fix attention_mask in mamba

* dropped unused loras

* fix flash2

* config docstrings

* fix config and fwd pass

* make fixup fixes

* text_modeling_zamba2

* small fixes

* make fixup fixes

* Fix modular model converter

* added inheritances in modular, renamed zamba cache

* modular rebase

* new modular conversion

* fix generated modeling file

* fixed import for Zamba2RMSNormGated

* modular file cleanup

* make fixup and model tests

* dropped inheritance for Zamba2PreTrainedModel

* make fixup and unit tests

* Add inheritance of rope from GemmaRotaryEmbedding

* moved rope to model init

* drop del self.self_attn and del self.feed_forward

* fix tests

* renamed lora -> adapter

* rewrote adapter implementation

* fixed tests

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Dropped adapter in-place sum

* removed rope from attention init

* updated rope

* created get_layers method

* make fixup fix

* make fixup fixes

* make fixup fixes

* update to new attention standard

* update to new attention standard

* make fixup fixes

* minor fixes

* cache_position

* removed cache_position postion_ids use_cache

* remove config from modular

* removed config from modular (2)

* import apply_rotary_pos_emb from llama

* fixed rope_kwargs

* Instantiate cache in Zamba2Model

* fix cache

* fix @slow decorator

* small fix in modular file

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <[email protected]>

* several minor fixes

* inherit mamba2decoder fwd and drop position_ids in mamba

* removed docstrings from modular

* reinstate zamba2 attention decoder fwd

* use regex for tied keys

* Revert "use regex for tied keys"

This reverts commit 9007a52.

* use regex for tied keys

* add cpu to slow forward tests

* dropped config.use_shared_mlp_adapter

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <[email protected]>

* re-convert from modular

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Arthur <[email protected]>
3 people authored Jan 27, 2025
1 parent 14a9bb5 commit 33cb1f7
Showing 18 changed files with 4,148 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -385,6 +385,7 @@ Flax), PyTorch, and/or TensorFlow.
| [YOLOS](model_doc/yolos) ||||
| [YOSO](model_doc/yoso) ||||
| [Zamba](model_doc/zamba) ||||
| [Zamba2](model_doc/zamba2) ||||
| [ZoeDepth](model_doc/zoedepth) ||||

<!-- End table-->
91 changes: 91 additions & 0 deletions docs/source/en/model_doc/zamba2.md
@@ -0,0 +1,91 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Zamba2

Zamba2 is a large language model (LLM) trained by Zyphra, and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights.

This model was contributed by [pglo](https://huggingface.co/pglo).


## Model details

Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models that combine state-space layers (specifically [Mamba](https://github.com/state-spaces/mamba)) with transformer layers, and were trained using next-token prediction. Zamba2 applies shared transformer layers after every 6 Mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We arrived at this architecture after a series of ablations at small scales. The models were pre-trained on 2T or 3T tokens, depending on the model.
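
The interleaving described above can be pictured with a short sketch. It is only an illustration of the statement "shared transformer layers after every 6 Mamba blocks": the function name `hybrid_layout` and the block labels are hypothetical, and the actual layer placement in a checkpoint is determined by its configuration, not by this code.

```python
from typing import List


def hybrid_layout(num_mamba_blocks: int, period: int = 6) -> List[str]:
    """Return an illustrative block sequence for a Mamba/transformer hybrid.

    Assumption for illustration only: one shared transformer block is
    applied after every `period` Mamba blocks.
    """
    layout: List[str] = []
    for index in range(1, num_mamba_blocks + 1):
        layout.append("mamba")
        # After every `period` Mamba blocks, reuse the shared transformer block.
        if index % period == 0:
            layout.append("shared_transformer")
    return layout


if __name__ == "__main__":
    print(hybrid_layout(12))
```

Because the transformer block is shared, every `"shared_transformer"` entry in the list would refer to the same set of weights rather than to a new layer.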

<img src=https://github.com/user-attachments/assets/c2cff209-b901-483c-87aa-774b82a0769f width=30% height=40% />

## Quick start


### Prerequisites

Zamba2 requires `transformers` version 4.48.0 or higher:
```bash
pip install "transformers>=4.48.0"
```

## Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and the model in bfloat16 on the GPU
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)

# Tokenize the prompt and move the tensors to the same device as the model
input_text = "What factors contributed to the fall of the Roman Empire?"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

# Generate up to 100 new tokens and decode the result
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```
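
The minimum `transformers` version stated in the quick start can also be checked from Python before loading a checkpoint. The following is a minimal sketch using only the standard library. The helper names `parse_release`, `meets_minimum` and `transformers_is_recent_enough` are hypothetical, and the comparison deliberately ignores pre-release suffixes such as `.dev0` or `rc1`.

```python
import re
from importlib import metadata
from typing import Tuple


def parse_release(version: str) -> Tuple[int, ...]:
    """Extract up to three leading numeric components, e.g. '4.48.0.dev0' -> (4, 48, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad short versions such as '4.48' to three components.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "4.48.0") -> bool:
    """Simplified check: compares numeric release components only."""
    return parse_release(installed) >= parse_release(minimum)


def transformers_is_recent_enough(minimum: str = "4.48.0") -> bool:
    """Return False when transformers is missing or older than `minimum`."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return False
    return meets_minimum(installed, minimum)


if __name__ == "__main__":
    print(transformers_is_recent_enough())
```

For strict PEP 440 ordering (where a `.dev0` build sorts before the final release), a dedicated version parser would be the safer choice.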


## Model card

The model cards can be found at:
* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)


## Issues
For issues with model output, or for community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions).


## License

The model weights are open-sourced via an Apache 2.0 license.


## Zamba2Config

[[autodoc]] Zamba2Config


## Zamba2Model

[[autodoc]] Zamba2Model
- forward


## Zamba2ForCausalLM

[[autodoc]] Zamba2ForCausalLM
- forward


## Zamba2ForSequenceClassification

[[autodoc]] transformers.Zamba2ForSequenceClassification
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/heliumtransformers.HeliumModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

@@ -328,6 +329,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/heliumtransformers.HeliumModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)

<Tip>

16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -889,6 +889,7 @@
"models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"],
"models.zamba": ["ZambaConfig"],
"models.zamba2": ["Zamba2Config"],
"models.zoedepth": ["ZoeDepthConfig"],
"onnx": [],
"pipelines": [
@@ -3989,6 +3990,14 @@
"ZambaPreTrainedModel",
]
)
_import_structure["models.zamba2"].extend(
[
"Zamba2ForCausalLM",
"Zamba2ForSequenceClassification",
"Zamba2Model",
"Zamba2PreTrainedModel",
]
)
_import_structure["models.zoedepth"].extend(
[
"ZoeDepthForDepthEstimation",
@@ -6004,6 +6013,7 @@
from .models.yolos import YolosConfig
from .models.yoso import YosoConfig
from .models.zamba import ZambaConfig
from .models.zamba2 import Zamba2Config
from .models.zoedepth import ZoeDepthConfig

# Pipelines
@@ -8542,6 +8552,12 @@
ZambaModel,
ZambaPreTrainedModel,
)
from .models.zamba2 import (
Zamba2ForCausalLM,
Zamba2ForSequenceClassification,
Zamba2Model,
Zamba2PreTrainedModel,
)
from .models.zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -303,5 +303,6 @@
yolos,
yoso,
zamba,
zamba2,
zoedepth,
)
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -335,6 +335,7 @@
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
("zamba", "ZambaConfig"),
("zamba2", "Zamba2Config"),
("zoedepth", "ZoeDepthConfig"),
]
)
@@ -680,6 +681,7 @@
("yolos", "YOLOS"),
("yoso", "YOSO"),
("zamba", "Zamba"),
("zamba2", "Zamba2"),
("zoedepth", "ZoeDepth"),
]
)
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -303,6 +303,7 @@
("yolos", "YolosModel"),
("yoso", "YosoModel"),
("zamba", "ZambaModel"),
("zamba2", "Zamba2Model"),
]
)

@@ -577,6 +578,7 @@
("xlnet", "XLNetLMHeadModel"),
("xmod", "XmodForCausalLM"),
("zamba", "ZambaForCausalLM"),
("zamba2", "Zamba2ForCausalLM"),
]
)

@@ -1055,6 +1057,7 @@
("xmod", "XmodForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
("zamba", "ZambaForSequenceClassification"),
("zamba2", "Zamba2ForSequenceClassification"),
]
)
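
The hunks above register the `zamba2` model type with the auto classes by pairing a model-type key with a class name. As a rough illustration of that idea (not the library's actual mapping code), a key can be resolved to a class name as sketched below. The dictionary names and `resolve_class_name` are hypothetical; only the key/value pairs are taken from the diff.

```python
from collections import OrderedDict

# Hypothetical tables that mirror the (model_type, class_name) pairs shown in the diff.
BASE_MODEL_NAMES = OrderedDict([("zamba", "ZambaModel"), ("zamba2", "Zamba2Model")])
CAUSAL_LM_NAMES = OrderedDict([("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM")])
SEQ_CLS_NAMES = OrderedDict(
    [("zamba", "ZambaForSequenceClassification"), ("zamba2", "Zamba2ForSequenceClassification")]
)


def resolve_class_name(model_type: str, table: "OrderedDict[str, str]") -> str:
    """Look up the class name registered for `model_type`, failing with a clear message."""
    try:
        return table[model_type]
    except KeyError:
        known = ", ".join(table)
        raise ValueError(f"Unknown model type {model_type!r}; known types: {known}") from None


if __name__ == "__main__":
    print(resolve_class_name("zamba2", CAUSAL_LM_NAMES))
```

Keeping only class names (strings) in the table means the model module itself does not have to be imported until a class is actually requested.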

7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -583,6 +583,13 @@
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"zamba2",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
]
)
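
The tokenizer entry above pairs a slow (SentencePiece-based) class name with a fast one, and either slot becomes `None` when its backend is not installed. The sketch below shows one plausible way such a pair could be consumed. `tokenizer_entry` and `pick_tokenizer` are hypothetical helpers written for illustration, and the selection rule is an assumption, not a copy of the library's logic.

```python
from typing import Optional, Tuple

Entry = Tuple[Optional[str], Optional[str]]


def tokenizer_entry(sentencepiece_available: bool, tokenizers_available: bool) -> Entry:
    """Build a (slow, fast) pair the same way the diff does, with availability passed in."""
    return (
        "LlamaTokenizer" if sentencepiece_available else None,
        "LlamaTokenizerFast" if tokenizers_available else None,
    )


def pick_tokenizer(entry: Entry, use_fast: bool = True) -> str:
    """Assumed rule: prefer the fast class when requested, or when no slow class exists."""
    slow, fast = entry
    if fast is not None and (use_fast or slow is None):
        return fast
    if slow is not None:
        return slow
    raise ValueError("Neither a slow nor a fast tokenizer class is available.")


if __name__ == "__main__":
    print(pick_tokenizer(tokenizer_entry(True, True)))
```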

13 changes: 1 addition & 12 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -272,7 +272,6 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[ZambaHybridDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
@@ -621,11 +620,9 @@ def forward(
original_hidden_states: torch.Tensor,
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -638,7 +635,6 @@ def forward(
layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings.
past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -655,11 +651,9 @@ def forward(
hidden_states=hidden_states,
layer_idx=layer_idx,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
# feed-forward (MLP)
@@ -688,12 +682,12 @@ def forward(
layer_idx: int = None,
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
transformer_hidden_states: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -756,7 +750,6 @@ def forward(
layer_idx: int = None,
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -786,7 +779,6 @@ def forward(
original_hidden_states=original_hidden_states,
layer_idx=layer_idx,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -804,7 +796,6 @@ def forward(
hidden_states,
transformer_hidden_states=transformer_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -1108,7 +1099,6 @@ def forward(
layer_idx,
attention_mask,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
@@ -1121,7 +1111,6 @@ def forward(
layer_idx=layer_idx,
attention_mask=attention_mask,
causal_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
27 changes: 27 additions & 0 deletions src/transformers/models/zamba2/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_zamba2 import *
    from .modeling_zamba2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
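
The `__init__.py` above replaces the package module with a lazy module so that the heavy modeling code is only imported when one of its names is first accessed. The sketch below illustrates that general pattern with the standard library only. The `LazyModule` class here is a simplified, hypothetical stand-in and does not reproduce the behaviour or signature of the library's `_LazyModule`.

```python
import importlib
import types
from typing import Dict


class LazyModule(types.ModuleType):
    """Module whose attributes are imported from other modules on first access."""

    def __init__(self, name: str, attr_to_module: Dict[str, str]):
        super().__init__(name)
        # Maps an exported attribute name to the module that defines it.
        self._attr_to_module = dict(attr_to_module)

    def __getattr__(self, attr: str):
        # Called only when normal attribute lookup fails, i.e. before the first import.
        try:
            module_name = self._attr_to_module[attr]
        except KeyError:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}") from None
        value = getattr(importlib.import_module(module_name), attr)
        # Cache the value so later lookups bypass __getattr__ entirely.
        setattr(self, attr, value)
        return value


if __name__ == "__main__":
    lazy = LazyModule("demo_lazy", {"sqrt": "math", "dumps": "json"})
    print(lazy.sqrt(9.0))
```

In this sketch nothing is imported when the object is created; `math` is imported only when `lazy.sqrt` is first read.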