Add Molmo (7B-D, 7B-O, 70B) #33962

Open · wants to merge 145 commits into base: main

Changes from 1 commit

Commits (145)
dc6fcac
add base convert keys + chat template
molbap Oct 1, 2024
574e01f
Merge branch 'main' into add_molmo
molbap Oct 2, 2024
0bd413b
draft: add up modular files for molmo
molbap Oct 4, 2024
9e454e4
Squashed commit of the following:
molbap Oct 8, 2024
d82c471
sync changes
molbap Oct 8, 2024
339a8d3
push a simple fix
ArthurZucker Oct 8, 2024
c0c25d6
finish fixing
ArthurZucker Oct 8, 2024
5ee6a44
Merge branch 'main' into add_molmo
molbap Oct 8, 2024
33e43ec
suppress diff
molbap Oct 8, 2024
d23e1c1
Merge branch 'main' into add_molmo
molbap Oct 10, 2024
c8c12fe
fix
ArthurZucker Oct 10, 2024
0909c02
style
ArthurZucker Oct 10, 2024
1799d20
add config + 2d pooling
molbap Oct 10, 2024
fb133d4
suppress changes
molbap Oct 10, 2024
5ba4105
Merge branch 'add_molmo' of github.com:molbap/transformers into add_m…
molbap Oct 10, 2024
a2a6a9b
fix
ArthurZucker Oct 10, 2024
8fe7a9f
Merge branch 'add_molmo' of github.com:molbap/transformers into add_m…
ArthurZucker Oct 10, 2024
20681f5
conversion works :raised_hands:
molbap Oct 11, 2024
c85af98
fixup
molbap Oct 11, 2024
35ea3cc
handle missing MOLMO_VISION_ATTENTION_CLASSES
molbap Oct 11, 2024
ab79d0e
fix
molbap Oct 11, 2024
b9bdf99
fix fused keys mismatch
molbap Oct 15, 2024
98d5ccd
fix
molbap Oct 15, 2024
3bca742
[Modular-breaking] add manually vision attention classes list
molbap Oct 15, 2024
a13fe05
finish weight conversion script
molbap Oct 15, 2024
fac8dfd
add more keys
molbap Oct 16, 2024
c1e5f19
flipped the linear layers
molbap Oct 16, 2024
a68e5f5
add pooling forward + draft general forward
molbap Oct 16, 2024
8298b80
modeling file with swiglu, forward(input_ids) passing
molbap Oct 16, 2024
9f69c6b
BIG push of image processor
molbap Oct 23, 2024
0711e08
add missing objects to init
molbap Oct 23, 2024
7efe22e
Merge branch 'main' into add_molmo
molbap Nov 5, 2024
f5bd3b0
fix up wrong channel dimension
molbap Nov 7, 2024
3ae884f
fix typo
molbap Nov 7, 2024
3ef60c0
add missing image token indices used in forward
molbap Nov 19, 2024
cf9d4ab
pad patch orderings
molbap Nov 19, 2024
91a2d3c
clean up conversion script
molbap Nov 19, 2024
0f7904f
remind that tests are TODO
molbap Nov 19, 2024
577e347
merge main
zucchini-nlp Nov 21, 2024
b514041
at least it runs like this
zucchini-nlp Nov 24, 2024
cf6cb5d
add bos token
molbap Nov 27, 2024
26c517d
add bos token in prompt
molbap Nov 27, 2024
35c168d
fix processor, missing batching img_mask
molbap Nov 27, 2024
e7275c7
fix image masks + batching
molbap Nov 27, 2024
3e7530d
working version
zucchini-nlp Nov 27, 2024
4bbc89b
+1 only on non masked indices
zucchini-nlp Nov 27, 2024
54e072b
attempt 1 to make modular work
zucchini-nlp Nov 27, 2024
1e99752
update conversion to fit all ckpt + chat template + clean up a bit
zucchini-nlp Nov 27, 2024
92a1f31
fix processing tests
zucchini-nlp Nov 27, 2024
42330e0
add more tests (failing for now)
zucchini-nlp Nov 27, 2024
932f6d1
fix the conversion
zucchini-nlp Nov 27, 2024
aafb827
done!
zucchini-nlp Nov 27, 2024
36cc6dd
nit
zucchini-nlp Nov 27, 2024
f399c3a
some tests are failing, coming back tomorrow
zucchini-nlp Nov 27, 2024
7322227
adapt to any image format
molbap Nov 27, 2024
e4db50a
Merge branch 'add_molmo' of github.com:molbap/transformers into add_m…
molbap Nov 27, 2024
205a755
try to get batched generation working
molbap Nov 28, 2024
eb61617
fix other tests, should work now
zucchini-nlp Nov 28, 2024
b77d947
adjust test for batching
zucchini-nlp Nov 28, 2024
ba4dd50
little bit of style
zucchini-nlp Nov 28, 2024
0e2d184
docs + imports + automapping
zucchini-nlp Nov 28, 2024
9a83706
remove images kwargs
zucchini-nlp Nov 28, 2024
171eb8e
some unused config attributes
zucchini-nlp Nov 28, 2024
35b517a
remove additional vocab size and pad lm head
zucchini-nlp Nov 28, 2024
6a0cbc5
remove einops dependency
molbap Nov 28, 2024
5c7b141
Merge branch 'add_molmo' of github.com:molbap/transformers into add_m…
molbap Nov 28, 2024
434d4b1
dont skip these tests
zucchini-nlp Nov 28, 2024
4645f97
format + add integration testing
molbap Nov 28, 2024
48f2e21
Merge branch 'add_molmo' of github.com:molbap/transformers into add_m…
molbap Nov 28, 2024
4bb4e48
fix tests + fix 72B conversion
molbap Nov 29, 2024
e676782
fix format
molbap Nov 29, 2024
a74bda2
modular kinda works but adds extra classes like `VisionVisionModel` :(
zucchini-nlp Nov 29, 2024
2c428ae
accommodate 7B-O version as well (broken)
molbap Nov 29, 2024
d338153
merge, fix conflicts and clean up modular extra code
molbap Nov 29, 2024
00376c4
fix 7B-O
zucchini-nlp Dec 2, 2024
48354fe
remove unused code path
zucchini-nlp Dec 2, 2024
d738493
nit
zucchini-nlp Dec 3, 2024
d0e90d4
make modular work mostly
zucchini-nlp Dec 3, 2024
f06b6d9
fix imports
zucchini-nlp Dec 3, 2024
9fc25c0
update modular last time
zucchini-nlp Dec 3, 2024
38dc9e8
fix copies
zucchini-nlp Dec 3, 2024
eb77f3c
fix copies
zucchini-nlp Dec 4, 2024
190cc35
fix tests
zucchini-nlp Dec 4, 2024
84ed244
initial push of fast processor
molbap Dec 6, 2024
b4d48d5
Merge branch 'add_molmo' of github.com:molbap/transformers into add_m…
molbap Dec 6, 2024
1298d08
Merge branch 'main' into add_molmo
molbap Dec 10, 2024
6687d43
fix various issues + tests
molbap Dec 10, 2024
5f79577
add Molmo submodules as private
molbap Dec 10, 2024
9e72758
do not test submodules
molbap Dec 10, 2024
439aed6
[run-slow] molmo
molbap Dec 10, 2024
5a6a965
underscore prefixed method is not public
molbap Dec 10, 2024
b9746a8
fix tests
molbap Dec 10, 2024
2090ed6
fix docs
molbap Dec 10, 2024
8ad3a25
[run-slow] molmo
molbap Dec 10, 2024
0d10ee4
Merge branch 'main' into add_molmo
molbap Dec 10, 2024
9bd96f5
fix cache shape
molbap Dec 10, 2024
af5468b
[run-slow] molmo
molbap Dec 10, 2024
c02c6de
trigger CI
molbap Dec 10, 2024
5f35055
mark flaky test
molbap Dec 10, 2024
2b7af87
add missing objects
molbap Dec 10, 2024
9f0f09d
add config to init
molbap Dec 10, 2024
74ebb24
more init fixes
molbap Dec 10, 2024
8b00c44
fix style
molbap Dec 10, 2024
d6403ad
fix?
molbap Dec 10, 2024
eb43cb9
fix
molbap Dec 10, 2024
33f0624
what is this again
molbap Dec 10, 2024
cc59007
Merge branch 'main' into add_molmo
molbap Dec 10, 2024
23ae692
is this real life
molbap Dec 10, 2024
4c456e7
it was real life, fix broken eager
molbap Dec 10, 2024
91f2820
fix attributes
molbap Dec 10, 2024
e2df6bc
this attention should be fixed
molbap Dec 10, 2024
ae77cc6
set 7b test to bf16
molbap Dec 11, 2024
166b28a
[run-slow] molmo
molbap Dec 11, 2024
50bcb7c
Merge branch 'main' into add_molmo
molbap Dec 11, 2024
bf012d8
[run-slow] molmo
molbap Dec 11, 2024
6e0634b
fix text (variability T4/A100)
molbap Dec 11, 2024
8569fd0
push clean Fast (x3!) image processor
molbap Dec 12, 2024
fd401bc
Merge branch 'main' into add_molmo
molbap Dec 12, 2024
86acf22
fix modular changes from main
molbap Dec 12, 2024
1ebea3c
Merge branch 'main' into add_molmo
molbap Dec 16, 2024
5ebc6f0
push fast image proc with device check
molbap Dec 23, 2024
19d2689
push fast image proc with device check
molbap Dec 23, 2024
c652bb9
format
molbap Dec 23, 2024
50c21e5
images kwargs were missing
molbap Dec 23, 2024
092da76
merge and fix conflicts
molbap Dec 23, 2024
1254eac
style
molbap Dec 23, 2024
bd39143
update with modular conversion
molbap Dec 23, 2024
3efcb13
add torch import
molbap Dec 23, 2024
56ae76f
style
molbap Dec 23, 2024
9417ff7
protect import
molbap Dec 23, 2024
51f9336
fix modular
molbap Dec 23, 2024
3719481
Merge branch 'main' into add_molmo
molbap Jan 7, 2025
f394b02
cherry-pick: cohere (from 67c3fcd4f32c64e07f302f00243be7d54914d78b)
molbap Jan 8, 2025
e418aa3
fix modular with cohere interface
molbap Jan 8, 2025
5af0b57
fixup cohere all imports
molbap Jan 8, 2025
a574b93
fix bf16 test output
molbap Jan 8, 2025
9f3018d
fix
molbap Jan 8, 2025
e2d1ba8
style
molbap Jan 8, 2025
c872095
Merge branch 'main' into add_molmo
molbap Jan 9, 2025
41ab3a7
uniformize fast image processor
molbap Jan 9, 2025
dd74b78
Merge branch 'main' into add_molmo
molbap Jan 9, 2025
d052666
fix merge
molbap Jan 9, 2025
0a822f4
unbloat modular a tad
molbap Jan 9, 2025
8ebf44f
fix import
molbap Jan 9, 2025
4e6070f
fix modular
molbap Jan 9, 2025
fix 7B-O
zucchini-nlp committed Dec 2, 2024
commit 00376c4d9914af4bbba787874044f1a938fc0a57
12 changes: 12 additions & 0 deletions src/transformers/models/molmo/configuration_molmo.py
@@ -300,6 +300,12 @@ class MolmoTextConfig(PretrainedConfig):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        use_postnorm (`bool`, *optional*, defaults to `True`):
+            Whether to apply post layer normalization (rather than pre layer normalization) in each decoder layer.
+        use_attention_layer_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply norm to keys and queries in the attention layer.

    ```python
    >>> from transformers import MolmoTextModel, MolmoTextConfig
@@ -338,6 +344,9 @@ def __init__(
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
+        attention_bias=False,
+        use_postnorm=True,
+        use_attention_layer_norm=False,
        **kwargs,
    ):
        super().__init__(
@@ -354,6 +363,9 @@ def __init__(
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers
+        self.attention_bias = attention_bias
+        self.use_postnorm = use_postnorm
+        self.use_attention_layer_norm = use_attention_layer_norm

        # for backward compatibility
        if num_key_value_heads is None:
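
As a quick illustration of the new flags, a minimal sketch of constructing the text config. It assumes `MolmoTextConfig` from this branch is importable, as in the docstring example above; the non-default values in the second config are illustrative, not taken from any released checkpoint.

```python
from transformers import MolmoTextConfig  # importable on this PR branch only, not in a released version

# Defaults follow the docstring above: no attention bias, post-norm decoder layers, no QK norm.
default_config = MolmoTextConfig()
assert default_config.attention_bias is False
assert default_config.use_postnorm is True
assert default_config.use_attention_layer_norm is False

# A pre-norm variant with attention bias and QK norm enabled (illustrative values only).
prenorm_config = MolmoTextConfig(
    attention_bias=True,
    use_postnorm=False,
    use_attention_layer_norm=True,
)
```
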
30 changes: 8 additions & 22 deletions src/transformers/models/molmo/convert_molmo_weights_to_hf.py
@@ -67,7 +67,7 @@
    r"transformer.blocks.(\d+).(q|k)_norm.weight": r"language_model.model.layers.\1.self_attn.\2_norm.layer.weight",
    r"transformer.blocks.(\d+).attn_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
    r"transformer.blocks.(\d+).attn_out.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight",
-    r"transformer.blocks.(\d+).ff_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.layer.weight",
+    r"transformer.blocks.(\d+).ff_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
    r"transformer.blocks.(\d+).ff_out.weight": r"language_model.model.layers.\1.mlp.fc2.weight",
    r"transformer.blocks.(\d+).ff_proj.weight": r"language_model.model.layers.\1.mlp.fc1.weight",
    r"transformer.ff_out.weight": r"language_model.lm_head.weight",
@@ -176,24 +176,13 @@ def write_model(
    if variant == "72B":
        pooling_config.text_intermediate_size = 59136
        pooling_config.text_hidden_size = 8192
-        text_config.qkv_bias = True
-        text_config.use_attention_layer_norm = False
-        text_config.use_post_attention_layernorm = True
-        text_config.use_post_mlp_layernorm = False
    elif variant == "7B-O":
        pooling_config.text_intermediate_size = 22016
        pooling_config.text_hidden_size = 4096
-        text_config.qkv_bias = original_config["qkv_bias"]
-        text_config.use_attention_layer_norm = original_config["attention_layer_norm"]
-        text_config.use_post_attention_layernorm = False
-        text_config.use_post_mlp_layernorm = True
-    elif variant == "7B-D":
-        text_config.qkv_bias = True
-        text_config.use_attention_layer_norm = False
-        text_config.use_post_attention_layernorm = True
-        text_config.use_post_mlp_layernorm = False
-
-    text_config.o_proj_bias = False

+    text_config.attention_bias = original_config["qkv_bias"]
+    text_config.use_postnorm = original_config["norm_after"]
+    text_config.use_attention_layer_norm = original_config["attention_layer_norm"]
+
    config = MolmoConfig(
        text_config=text_config.to_dict(),
@@ -221,9 +210,6 @@ def write_model(
    # Some post-processing of specific params.
    for old_key, new_key in new_keys.items():
        new_key = new_key.removeprefix("model.")
-        # remap keys
-        if "post_attention_layernorm" in new_key and variant == "7B-O":
-            new_key = new_key.replace("post_attention_layernorm", "post_mlp_layernorm")
        state_dict[new_key] = state_dict.pop(old_key)
    # Post-process the current_parameter.

@@ -293,9 +279,9 @@ def write_model(
    # ------------------------------------------------------------
    extra_special_tokens = {
        "image_token": "<image>",
-        "boi_token": "<im_patch>",
-        "eoi_token": "<im_start>",
-        "im_patch_token": "<im_end>",
+        "boi_token": "<im_start>",
+        "eoi_token": "<im_end>",
+        "im_patch_token": "<im_patch>",
        "im_col_token": "<im_col>",
    }
    if variant in ["7B-D", "72B"]:
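
For context on the hunks above: the conversion script maps original checkpoint keys to Hugging Face keys via regex substitution and then strips a leading `model.` prefix before moving tensors into the new state dict. A self-contained sketch of that pattern, using a two-entry excerpt of the mapping; the helper name `convert_key` is made up for illustration.

```python
import re

# Excerpt of the regex mapping used by the conversion script (original checkpoint key -> HF key).
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"transformer.blocks.(\d+).attn_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
    r"transformer.blocks.(\d+).ff_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
}


def convert_key(old_key: str) -> str:
    """Return the converted key, or the original key if no pattern matches (illustrative helper)."""
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        new_key, n_subs = re.subn(pattern, replacement, old_key)
        if n_subs:
            # Mirrors the script's removeprefix("model.") step (a no-op for these two keys).
            return new_key.removeprefix("model.")
    return old_key


print(convert_key("transformer.blocks.3.ff_norm.weight"))
# -> language_model.model.layers.3.post_attention_layernorm.weight
```
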
106 changes: 91 additions & 15 deletions src/transformers/models/molmo/modeling_molmo.py
@@ -331,10 +331,10 @@ def __init__(self, config: MolmoTextConfig, layer_idx: Optional[int] = None):
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.o_proj_bias)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm = ConditionalMolmoRMSNorm(
            hidden_size=self.hidden_size,
@@ -646,7 +646,7 @@ def forward(
        }


-class MolmoDecoderLayer(nn.Module):
+class MolmoPrenormDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
@@ -659,12 +659,7 @@ def __init__(self, config, layer_idx: int):
        self.self_attn = MOLMO_TEXT_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        self.mlp = MolmoMLP(config)
        self.input_layernorm = MolmoTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = ConditionalMolmoRMSNorm(
-            config.hidden_size, use_layer_norm=config.use_post_attention_layernorm, eps=config.rms_norm_eps
-        )
-        self.post_mlp_layernorm = ConditionalMolmoRMSNorm(
-            config.hidden_size, use_layer_norm=config.use_post_mlp_layernorm, eps=config.rms_norm_eps
-        )
+        self.post_attention_layernorm = MolmoTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
@@ -701,7 +696,6 @@ def forward(
        """

        residual = hidden_states
-
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
@@ -721,7 +715,88 @@ def forward(
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
-        hidden_states = self.post_mlp_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class MolmoDecoderLayer(nn.Module):
+    def __init__(self, config, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        if config.sliding_window and config._attn_implementation != "flash_attention_2":
+            logger.warning_once(
+                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+                "unexpected results may be encountered."
+            )
+        self.self_attn = MOLMO_TEXT_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.mlp = MolmoMLP(config)
+        self.input_layernorm = MolmoTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = MolmoTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+
+        residual = hidden_states
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
@@ -807,7 +882,7 @@ class MolmoTextPreTrainedModel(PreTrainedModel):
    config_class = MolmoTextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
-    _no_split_modules = ["MolmoTextDecoderLayer"]
+    _no_split_modules = ["MolmoDecoderLayer", "MolmoPrenormDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
@@ -923,8 +998,9 @@ def __init__(self, config):
            config.hidden_size,
        )

+        decoder_layer = MolmoDecoderLayer if self.config.use_postnorm else MolmoPrenormDecoderLayer
        self.layers = nn.ModuleList(
-            [MolmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = MolmoTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
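
To make the `use_postnorm` switch concrete: `MolmoPrenormDecoderLayer` normalizes the input of each sub-layer and adds the raw residual, while the post-norm `MolmoDecoderLayer` added above runs the sub-layer first and normalizes its output before the residual add. A stripped-down sketch of the two orderings; a plain `nn.LayerNorm` and a single generic sub-layer stand in for Molmo's RMSNorm, attention and MLP, and the class names are illustrative, not from the PR.

```python
import torch
from torch import nn


class PreNormBlock(nn.Module):
    """x + sublayer(norm(x)) -- the ordering used by MolmoPrenormDecoderLayer."""

    def __init__(self, hidden_size: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)  # Molmo uses RMSNorm; LayerNorm keeps the sketch short
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))


class PostNormBlock(nn.Module):
    """x + norm(sublayer(x)) -- the ordering selected when config.use_postnorm is True."""

    def __init__(self, hidden_size: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.norm(self.sublayer(x))


# Tiny smoke test: both orderings preserve the hidden shape.
hidden = torch.randn(1, 4, 16)
mlp = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
print(PreNormBlock(16, mlp)(hidden).shape, PostNormBlock(16, mlp)(hidden).shape)
```
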