
Add embedding scaling #34980

Open · wants to merge 1 commit into main
Conversation

@literid commented Nov 27, 2024

What does this PR do?

This PR introduces support for embedding scaling in the MistralModel. The feature is inspired by the Scaled Embed method, which demonstrates that applying a scaling factor to the embeddings significantly improves the stability of large language model (LLM) training, effectively mitigating gradient spikes.

Key Changes:

  • Adds a new configuration parameter:
    embedding_scale (float, optional, defaults to 1.0): A scaling factor applied to the model's embeddings.
  • Updates the MistralModel implementation to apply the scaling factor to the embeddings during training and inference.

This implementation currently supports the PyTorch backend. Support for TensorFlow and Flax backends can be added in the future.
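To make the proposed change concrete, here is a minimal sketch of how embedding scaling could be wired into a decoder's forward pass. The `ToyDecoderModel` class, its sizes, and everything except the `embedding_scale` parameter name are illustrative stand-ins, not the actual MistralModel/MistralConfig code from this PR.

```python
# Minimal sketch (not the actual PR diff) of applying a scaling factor to
# token embeddings before they enter the decoder layers. Only the
# `embedding_scale` name is taken from the PR description; the rest is a
# simplified, hypothetical stand-in for the real model code.
import torch
import torch.nn as nn


class ToyDecoderModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, embedding_scale: float = 1.0):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Scaling factor applied to the token embeddings. The default of 1.0
        # is a no-op, which preserves backward compatibility.
        self.embedding_scale = embedding_scale

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # Scaled Embed: multiply the embeddings by a constant before the
        # decoder layers (omitted here) process them.
        hidden_states = hidden_states * self.embedding_scale
        return hidden_states


# Usage: with embedding_scale=1.0 the output equals the unscaled embeddings.
model = ToyDecoderModel(vocab_size=32000, hidden_size=64, embedding_scale=4.0)
out = model(torch.tensor([[1, 2, 3]]))
print(out.shape)  # torch.Size([1, 3, 64])
```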


Motivation

The Scaled Embed method improves training stability and helps mitigate gradient spikes, as shown in the referenced paper. By implementing this feature, we aim to bring these benefits to the MistralModel while maintaining backward compatibility.


Open Questions for Discussion

  1. Relevance:
    Do you see this feature as relevant for integration into the library? Would it make sense to extend this functionality to other models, such as LlamaModel?

  2. Implementation Scope:
    Should embedding scaling be integrated more broadly across models in the library, or should it remain model-specific?


Let me know if further adjustments or refinements are needed!

@ruidazeng left a comment

This makes sense to me.

@Rocketknight1 (Member) commented

Hi @literid, this is an interesting paper! However, I'm not sure it makes sense to add it to an existing model class that wasn't trained with it - it's very likely that everyone will just leave it at 1.0.

I think we should only add this feature when a model was trained with it, to ensure it'll actually get used. This might change in the future if this becomes a very widespread and popular method; then we might go back and enable it for every model.
