* First commit
* Finish model implementation
* First commit
* Finish model implementation
* Register zamba2
* generated modeling and configuration
* generated modeling and configuration
* added hybrid cache
* fix attention_mask in mamba
* dropped unused loras
* fix flash2
* config docstrings
* fix config and fwd pass
* make fixup fixes
* text_modeling_zamba2
* small fixes
* make fixup fixes
* Fix modular model converter
* added inheritances in modular, renamed zamba cache
* modular rebase
* new modular conversion
* fix generated modeling file
* fixed import for Zamba2RMSNormGated
* modular file cleanup
* make fixup and model tests
* dropped inheritance for Zamba2PreTrainedModel
* make fixup and unit tests
* Add inheritance of rope from GemmaRotaryEmbedding
* moved rope to model init
* drop del self.self_attn and del self.feed_forward
* fix tests
* renamed lora -> adapter
* rewrote adapter implementation
* fixed tests
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Dropped adapter in-place sum
* removed rope from attention init
* updated rope
* created get_layers method
* make fixup fix
* make fixup fixes
* make fixup fixes
* update to new attention standard
* update to new attention standard
* make fixup fixes
* minor fixes
* cache_position
* removed cache_position postion_ids use_cache
* remove config from modular
* removed config from modular (2)
* import apply_rotary_pos_emb from llama
* fixed rope_kwargs
* Instantiate cache in Zamba2Model
* fix cache
* fix @slow decorator
* small fix in modular file
* Update docs/source/en/model_doc/zamba2.md (co-authored by Arthur <[email protected]>)
* several minor fixes
* inherit mamba2decoder fwd and drop position_ids in mamba
* removed docstrings from modular
* reinstate zamba2 attention decoder fwd
* use regex for tied keys
* Revert "use regex for tied keys" (reverts commit 9007a52)
* use regex for tied keys
* add cpu to slow forward tests
* dropped config.use_shared_mlp_adapter
* Update docs/source/en/model_doc/zamba2.md (co-authored by Arthur <[email protected]>)
* re-convert from modular

Co-authored-by: root <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent 14a9bb5 · commit 33cb1f7 · 18 changed files with 4,148 additions and 12 deletions
@@ -0,0 +1,91 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Zamba2

Zamba2 is a large language model (LLM) trained by Zyphra, and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights.

This model was contributed by [pglo](https://huggingface.co/pglo).

## Model details

Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space models (specifically [Mamba](https://github.com/state-spaces/mamba)) and transformer blocks, and were trained using next-token prediction. Zamba2 uses shared transformer layers after every 6 Mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 2T and 3T tokens, respectively.

<img src="https://github.com/user-attachments/assets/c2cff209-b901-483c-87aa-774b82a0769f" width="30%" height="40%" />
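
The interleaving described above (a shared transformer layer after every 6 Mamba blocks) amounts to a simple layer-type schedule. The helper below is purely illustrative — it is not the library's actual `get_layers` logic, and the function name and labels are made up for this sketch:

```python
def hybrid_layer_schedule(num_mamba_blocks: int, attention_period: int = 6) -> list:
    """Illustrative schedule: interleave a shared attention (transformer)
    layer after every `attention_period` Mamba blocks."""
    schedule = []
    for i in range(1, num_mamba_blocks + 1):
        schedule.append("mamba")
        if i % attention_period == 0:
            schedule.append("shared_attention")
    return schedule

print(hybrid_layer_schedule(12))
```

With 12 Mamba blocks and the default period of 6, the schedule contains two shared-attention entries, one after the 6th and one after the 12th Mamba block.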
|
||
## Quick start

### Prerequisites

Zamba2 requires `transformers` version 4.48.0 or higher:

```bash
pip install "transformers>=4.48.0"
```

## Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)

input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```
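
With default settings, `generate` performs greedy decoding: at each step it appends the single highest-scoring next token. The toy loop below sketches that idea only — the scoring function is a hypothetical stand-in, not the real model:

```python
def greedy_decode(next_token_scores, input_ids, max_new_tokens, eos_id=None):
    """Toy greedy decoding loop: repeatedly pick the highest-scoring next
    token and append it, as model.generate does by default."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        scores = next_token_scores(ids)  # stand-in for a model forward pass
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids

# Dummy "model" over a 5-token vocabulary: always scores (last_id + 1) % 5 highest.
dummy = lambda ids: [1.0 if t == (ids[-1] + 1) % 5 else 0.0 for t in range(5)]
print(greedy_decode(dummy, [0], max_new_tokens=3))  # [0, 1, 2, 3]
```

The real `generate` adds caching, sampling strategies, and stopping criteria on top of this basic loop.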

## Model card

The model cards can be found at:
* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)

## Issues

For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions).

## License

The model weights are open-sourced via an Apache 2.0 license.

## Zamba2Config

[[autodoc]] Zamba2Config

## Zamba2Model

[[autodoc]] Zamba2Model
    - forward

## Zamba2ForCausalLM

[[autodoc]] Zamba2ForCausalLM
    - forward

## Zamba2ForSequenceClassification

[[autodoc]] Zamba2ForSequenceClassification
    - forward
@@ -303,5 +303,6 @@
    yolos,
    yoso,
    zamba,
    zamba2,
    zoedepth,
)
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_zamba2 import *
    from .modeling_zamba2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
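
The `_LazyModule` machinery above defers the heavy submodule imports until an attribute is actually accessed, which keeps `import transformers` fast. A minimal self-contained sketch of the same on-demand idea — using a plain lookup helper over stdlib modules, not the real `_LazyModule` class:

```python
import importlib

# Map of attribute name -> owning module (hypothetical registry for this sketch).
_LAZY_ATTRS = {"sqrt": "math", "dumps": "json"}

def lazy_getattr(name, registry=_LAZY_ATTRS):
    """Resolve `name` by importing its owning module only on first access,
    mirroring how _LazyModule resolves attributes on demand."""
    try:
        module = importlib.import_module(registry[name])
    except KeyError:
        raise AttributeError(name) from None
    return getattr(module, name)

print(lazy_getattr("sqrt")(16.0))  # 4.0
```

In the real file, the same dispatch happens via the module's `__getattr__` (PEP 562): the import structure generated by `define_import_structure` tells `_LazyModule` which submodule defines each public name.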