## What does this PR do?

This PR introduces support for embedding scaling in the `MistralModel`. The feature is inspired by the Scaled Embed method, which demonstrates that applying a scaling factor to the embeddings significantly improves the stability of large language model (LLM) training, effectively mitigating gradient spikes.

**Key Changes:**
- Adds a new configuration parameter, `embedding_scale` (`float`, optional, defaults to `1.0`): a scaling factor applied to the model's embeddings.
- Updates the `MistralModel` implementation to apply the scaling factor to the embeddings during training and inference (a minimal sketch of the idea is shown below).

This implementation currently supports the PyTorch backend. Support for TensorFlow and Flax backends can be added in the future.
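To make the intent concrete, here is a minimal, self-contained sketch of the approach. It is not the actual diff: `ScaledEmbedding` is an illustrative stand-alone module (the PR instead applies the factor to the output of `MistralModel`'s existing token embedding), and only `embedding_scale` corresponds to the parameter proposed here.

```python
import torch
import torch.nn as nn


class ScaledEmbedding(nn.Module):
    """Token embedding with a constant scaling factor (illustrative only).

    With embedding_scale=1.0 (the default) this reduces to a plain
    embedding lookup, which is what keeps the change backward compatible.
    """

    def __init__(self, vocab_size: int, hidden_size: int, embedding_scale: float = 1.0):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.embedding_scale = embedding_scale

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # Scale the embedding output before it enters the decoder stack;
        # the Scaled Embed method uses this to damp gradient spikes early
        # in pre-training.
        return self.embed_tokens(input_ids) * self.embedding_scale


# Usage sketch (sizes are placeholders, not taken from the PR):
embed = ScaledEmbedding(vocab_size=32000, hidden_size=4096, embedding_scale=2.0)
hidden_states = embed(torch.tensor([[1, 2, 3]]))  # shape: (1, 3, 4096)
```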
## Motivation

The Scaled Embed method improves training stability and helps mitigate gradient spikes, as shown in the referenced paper. By implementing this feature, we aim to bring these benefits to the `MistralModel` while maintaining backward compatibility.

## Open Questions for Discussion
- **Relevance:** Do you see this feature as relevant for integration into the library? Would it make sense to extend this functionality to other models, such as `LlamaModel`?
- **Implementation Scope:** Should embedding scaling be integrated more broadly across models in the library, or should it remain model-specific?
Let me know if further adjustments or refinements are needed!