diff --git a/README.md b/README.md
index 2b45e89f0..eeedd1445 100644
--- a/README.md
+++ b/README.md
@@ -224,6 +224,7 @@ loss.backward()
 | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Pixtral | `liger_kernel.transformers.apply_liger_kernel_to_pixtral` | RoPE, RMSNorm, SwiGLU |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
 | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index 147948b18..40d8de7c0 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -14,6 +14,7 @@
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
     apply_liger_kernel_to_phi3,
+    apply_liger_kernel_to_pixtral,
     apply_liger_kernel_to_qwen2,
     apply_liger_kernel_to_qwen2_vl,
 )
diff --git a/src/liger_kernel/transformers/model/pixtral.py b/src/liger_kernel/transformers/model/pixtral.py
new file mode 100644
index 000000000..6bac95da8
--- /dev/null
+++ b/src/liger_kernel/transformers/model/pixtral.py
@@ -0,0 +1,105 @@
+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.models.pixtral.modeling_pixtral import (
+    _CONFIG_FOR_DOC,
+    PIXTRAL_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+
+@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+def lce_forward(
+    self,
+    inputs_embeds,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_embeddings: Optional[torch.Tensor] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutput]:
+    r"""
+    Copy paste of PixtralTransformer's forward from transformers v4.45.0. The Pixtral vision encoder has no LM head, so there is no cross entropy to replace here; Liger's RoPE, RMSNorm and SwiGLU kernels are applied through module-level patches instead.
+
+    Args:
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+            returned tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+            for more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+    Returns:
+    """
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    encoder_states = () if output_hidden_states else None
+    all_attentions = () if output_attentions else None
+    hidden_states = inputs_embeds
+    for encoder_layer in self.layers:
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                encoder_layer.__call__,
+                hidden_states,
+                attention_mask,
+                position_embeddings,
+                output_attentions,
+            )
+        else:
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                position_embeddings=position_embeddings,
+                output_attentions=output_attentions,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_attentions = all_attentions + (layer_outputs[1],)
+
+    if output_hidden_states:
+        encoder_states = encoder_states + (hidden_states,)
+
+    if not return_dict:
+        return tuple(
+            v for v in [hidden_states, encoder_states, all_attentions] if v is not None
+        )
+
+    return BaseModelOutput(
+        last_hidden_state=hidden_states,
+        hidden_states=encoder_states,
+        attentions=all_attentions,
+    )
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 1cca9753c..cdb0756b5 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -10,6 +10,7 @@
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+from liger_kernel.transformers.model.pixtral import lce_forward as pixtral_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -139,6 +140,35 @@ def apply_liger_kernel_to_mixtral(
     modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
 
+def apply_liger_kernel_to_pixtral(
+    rope: bool = True,
+    rms_norm: bool = True,
+    fused_linear_cross_entropy: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Pixtral models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        fused_linear_cross_entropy (bool): Whether to patch `PixtralTransformer.forward` with the
+            Liger forward. The Pixtral vision encoder computes no loss, so this flag is kept only
+            for API consistency with the other `apply_liger_kernel_to_*` functions. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+    from transformers.models.pixtral import modeling_pixtral
+
+    if rope:
+        modeling_pixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_pixtral.PixtralRMSNorm = LigerRMSNorm
+    if fused_linear_cross_entropy:
+        modeling_pixtral.PixtralTransformer.forward = pixtral_lce_forward
+    if swiglu:
+        modeling_pixtral.PixtralMLP = LigerSwiGLUMLP
+
+
 def apply_liger_kernel_to_gemma(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -339,6 +369,7 @@ def apply_liger_kernel_to_phi3(
     "llama": apply_liger_kernel_to_llama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
+    "pixtral": apply_liger_kernel_to_pixtral,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "phi3": apply_liger_kernel_to_phi3,
diff --git a/src/liger_kernel/triton/monkey_patch.py b/src/liger_kernel/triton/monkey_patch.py
index 590842a83..70863f4e3 100644
--- a/src/liger_kernel/triton/monkey_patch.py
+++ b/src/liger_kernel/triton/monkey_patch.py
@@ -37,6 +37,6 @@ def apply_liger_triton_cache_manager():
     Experimental feature to get around transient FileNotFoundError in triton compilation.
     For more details please see https://github.com/triton-lang/triton/pull/4295
     """
-    os.environ[
-        "TRITON_CACHE_MANAGER"
-    ] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
+    os.environ["TRITON_CACHE_MANAGER"] = (
+        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
+    )
diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py
index 540468849..563efbeac 100644
--- a/test/convergence/test_mini_models_no_logits.py
+++ b/test/convergence/test_mini_models_no_logits.py
@@ -17,6 +17,7 @@
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
+from transformers.models.pixtral.modeling_pixtral import PixtralTransformer, PixtralVisionConfig
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 
 from liger_kernel.transformers import (
@@ -26,6 +27,7 @@
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
     apply_liger_kernel_to_phi3,
+    apply_liger_kernel_to_pixtral,
     apply_liger_kernel_to_qwen2,
     apply_liger_kernel_to_qwen2_vl,
 )
@@ -174,6 +176,24 @@
             attn_implementation="sdpa",
         ),
     ),
+    "mini_pixtral": MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_pixtral,
+        model_class=PixtralTransformer,
+        mini_model_config=PixtralVisionConfig(
+            hidden_size=1024,
+            intermediate_size=4096,
+            num_hidden_layers=24,
+            num_attention_heads=16,
+            num_channels=3,
+            image_size=1024,
+            patch_size=16,
+            hidden_act="gelu",
+            layer_norm_eps=1e-5,
+            attention_dropout=0.0,
+            rope_theta=10000.0,
+            tie_word_embeddings=False,
+        ),
+    ),
     "mini_gemma1": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_gemma,
         model_class=GemmaForCausalLM,
@@ -498,6 +518,22 @@ def run_mini_model(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
         ),
+        ("mini_pixtral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
+        pytest.param(
+            "mini_pixtral",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-8,
+            1e-5,
+            1e-2,
+            1e-5,
+            1e-2,
+            1e-5,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
         # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
         ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         pytest.param(
diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index 4116287fd..e0fe6fd5c 100644
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -21,6 +21,7 @@ def test_import_from_root():
             apply_liger_kernel_to_mistral,
             apply_liger_kernel_to_mixtral,
             apply_liger_kernel_to_phi3,
+            apply_liger_kernel_to_pixtral,
             apply_liger_kernel_to_qwen2,
             apply_liger_kernel_to_qwen2_vl,
         )
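A minimal usage sketch of the new entry point, under the assumption that transformers >= 4.45.0 exposes `PixtralVisionConfig` and `PixtralVisionModel` and that a CUDA device is available for the Triton kernels; the tiny config values below are illustrative, not tied to any real checkpoint:

```python
import torch
from transformers import PixtralVisionConfig, PixtralVisionModel

from liger_kernel.transformers import apply_liger_kernel_to_pixtral

# Patch the module-level Pixtral classes *before* instantiating the model so the
# constructed encoder picks up Liger's RoPE, RMSNorm, and SwiGLU implementations.
apply_liger_kernel_to_pixtral(rope=True, rms_norm=True, swiglu=True)

# Illustrative tiny config (assumed PixtralVisionConfig kwargs), not a real checkpoint.
config = PixtralVisionConfig(
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = PixtralVisionModel(config).to("cuda", dtype=torch.bfloat16)
```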