[RFC] Liger FlexChunkLoss: Alignment and Distillation loss #371
Comments
take DPO
I can take fused linear KL div. BTW, really nice illustration of the chunked linear op fusion from the paper. Very clear to new contributors 😄
@shivam15s @ByronHsu I think we should also consider including some of the loss functions commonly used for training embedding models, especially the popular ones supported in Sentence Transformers. Embedding models commonly require large batch sizes to train well. Since their batch/input structure is also similar to RLHF, where we have positive and negative pairs, I believe this could prove useful. I'd recommend supporting
@pramodith that is a good idea! Do you know whether embedding models also have large vocabularies and suffer from the same memory bottleneck?
@ByronHsu most embedding models have a final Linear layer of shape (hidden_dim, hidden_dim), so vocab size doesn't really come into the picture for them; you're right to point that out. However, it is common to have an effective batch size of 65k.
Then I think chunked loss is still helpful given the large batch size.
Yes, I think so too. I can give this a try after we wrap up all the important RLHF and distillation losses. I'll also get Tom Aarsen's perspective since he's the lead of Sentence Transformers.
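To illustrate why chunking can still help without a large vocabulary: with in-batch negatives (the idea behind Sentence Transformers' `MultipleNegativesRankingLoss`), the similarity matrix is B × B, so at B = 65,536 it alone has about 4.3 billion entries. Below is a minimal sketch, assuming L2-normalized embeddings, positives on the diagonal, and a hypothetical default `scale` of 20.0, that only materializes a `chunk_size × B` slice of scores at a time. As written, autograd still keeps each chunk's intermediates for the backward pass, so the saving applies to the forward computation; bounding training memory would also need the FLCE-style per-chunk gradient computation.

```python
import torch
import torch.nn.functional as F


def chunked_in_batch_contrastive_loss(
    queries: torch.Tensor,   # (B, D) query embeddings
    docs: torch.Tensor,      # (B, D) positive doc embeddings; row i pairs with query i
    scale: float = 20.0,     # hypothetical temperature-like scale factor
    chunk_size: int = 1024,
) -> torch.Tensor:
    """InfoNCE with in-batch negatives, computed over chunks of queries."""
    queries = F.normalize(queries, dim=-1)
    docs = F.normalize(docs, dim=-1)
    batch_size = queries.shape[0]
    total = queries.new_zeros(())
    for start in range(0, batch_size, chunk_size):
        q = queries[start:start + chunk_size]     # (c, D)
        scores = scale * q @ docs.T               # (c, B): only c rows of scores exist
        # The positive for query i sits at column i of the full batch
        targets = torch.arange(start, start + q.shape[0], device=q.device)
        total = total + F.cross_entropy(scores, targets, reduction="sum")
    return total / batch_size


def full_in_batch_contrastive_loss(queries, docs, scale=20.0):
    """Reference version that materializes the full (B, B) score matrix."""
    queries = F.normalize(queries, dim=-1)
    docs = F.normalize(docs, dim=-1)
    scores = scale * queries @ docs.T
    targets = torch.arange(queries.shape[0], device=queries.device)
    return F.cross_entropy(scores, targets)
```

Both functions compute the same mean loss; the chunked one just trades one large matmul for a loop of smaller ones.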
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss kernel, as requested in #371. This implementation is largely based on the excellent work done on ORPO (#362) by @shivam15s.

### DPO Loss Formulation

In the reference setting (not reference-free), the reward difference between the chosen response $y_c$ and the rejected response $y_r$ is:

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \beta\left[\log\pi_\theta(y_c|x) - \log\pi_{\theta_{\text{ref}}}(y_c|x) - \log\pi_\theta(y_r|x) + \log\pi_{\theta_{\text{ref}}}(y_r|x)\right]$$

and the loss is:

$$-\log\sigma\left(\beta\left[\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x) - \log\pi_{\theta_{\text{ref}}}(y_c|x) + \log\pi_{\theta_{\text{ref}}}(y_r|x)\right]\right)$$

This corresponds to:

```python
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta):
    # Each *_logps tensor holds sequence-level log-probabilities log π(y|x), shape (B,)
    # Log-ratios between the policy model and the frozen reference model
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # DPO loss: -log σ(β · (chosen log-ratio - rejected log-ratio))
    logits_diff = beta * (chosen_logratios - rejected_logratios)
    return -F.logsigmoid(logits_diff)
```

In this PR:

1. The equations above show that the loss is driven by maximizing the reward difference $r_\theta(x,y_c) - r_\theta(x,y_r)$.
2. Dropping the reference terms (the reference-free case), this simplifies to: $$-\log\sigma\left(\beta\left[\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x)\right]\right)$$
3. So, the code implements:
    ```python
    logits_diff = beta * (chosen_logps - rejected_logps)  # β(log π_θ(y_c|x) - log π_θ(y_r|x))
    losses = -F.logsigmoid(logits_diff)                   # -log σ(logits_diff)
    ```
4. Sum up DPO and NLL: $$L_{DPO+NLL} = L_{DPO} + \alpha L_{NLL}$$

## Testing Done

![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)
![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)

- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
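The formulas above operate on sequence-level log-probabilities, and $L_{DPO+NLL}$ adds a cross-entropy term on the chosen responses. A minimal unchunked sketch of both pieces, assuming logits of shape `(B, T, V)` already aligned with labels of shape `(B, T)`, `-100` as the ignore value for prompt/padding tokens, and a batch whose first half holds chosen responses and second half the matching rejected ones (all assumptions for illustration, not this PR's code):

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed label value for prompt/padding tokens


def sequence_logps(logits, labels):
    """Sum of per-token log-probs over non-ignored positions -> shape (B,)."""
    log_probs = F.log_softmax(logits.float(), dim=-1)          # (B, T, V)
    mask = labels != IGNORE_INDEX                               # (B, T)
    safe_labels = labels.masked_fill(~mask, 0)                  # any valid index
    token_logps = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask).sum(-1)


def dpo_nll_loss(policy_logits, ref_logits, labels, beta=0.1, alpha=1.0):
    """Assumed batch layout: first half chosen, second half rejected."""
    half = labels.shape[0] // 2
    policy_logps = sequence_logps(policy_logits, labels)
    with torch.no_grad():  # the reference model is frozen
        ref_logps = sequence_logps(ref_logits, labels)
    chosen_logratios = policy_logps[:half] - ref_logps[:half]
    rejected_logratios = policy_logps[half:] - ref_logps[half:]
    dpo = -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()
    # NLL (cross-entropy) on the chosen responses only, averaged per token
    nll = F.cross_entropy(
        policy_logits[:half].flatten(0, 1).float(),
        labels[:half].flatten(),
        ignore_index=IGNORE_INDEX,
    )
    return dpo + alpha * nll
```

A quick sanity check: when the policy equals the reference, every log-ratio is zero, so the DPO term is exactly $-\log\sigma(0) = \log 2$.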
#take SimPO and IRPO since they are just extensions of CPO.
I will #take KTO next.
A little update on KTO: I am now working on the tests.
@Chillee FYI, we are working on a set of post-training losses based on your compiled chunked loss implementation for CE. Thanks for the reference!
Update on the KTO loss: I am done with the loss, but I have a problem with the assertions. I am working on it.
I was following this thread and working on a chunked, fused linear KL-divergence implementation for distillation use cases. Since distillation losses differ from preference losses, introducing a

In general, the distillation pipeline involves three key inputs:

To leverage chunked, linear-fused optimizations, we could design the solution to accept inputs as

cc @ByronHsu, @shivam15s, @pramodith: What are your thoughts on this? Do you think it makes sense to include the cross-entropy loss as part of the
@hongpeng-guo yes! I like your approach; it's cleaner to create a new Base class for distillation losses. We're kind of doing the same for the alignment losses too, by computing the NLL (cross-entropy loss of the accepted responses) inside the Base class.
+1 on @hongpeng-guo's proposal. @shivam15s can help polish the base class.
Sounds good @hongpeng-guo, a separate base class for distillation is absolutely needed! |
Please review and comment on my KTO PR here: #410
There is an update on #410.
Is CPO-SimPO planned? This can be implemented in SimPO. Reference: https://github.com/fe1ixxu/CPO_SIMPO

Quote:

> CPO and SimPO share similar objectives but have different goals. CPO adds a BC-regularizer to prevent the model from deviating too much from the preferred data distribution. SimPO incorporates length normalization and target reward margin to improve model performance and prevent the generation of long but low-quality sequences: These two objectives can be jointly used, which we call CPO-SimPO:
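The combination described in the quote can be sketched as SimPO's length-normalized, reference-free preference term plus CPO's behavior-cloning (NLL) term on the chosen response. This is a minimal illustration, assuming summed per-sequence log-probabilities and response lengths are already available, and that the BC term is the length-normalized NLL of the chosen response weighted by `alpha`; the defaults for `beta` (reward scale), `gamma` (target reward margin), and `alpha` are hypothetical:

```python
import torch
import torch.nn.functional as F


def cpo_simpo_loss(chosen_logps, rejected_logps, chosen_lens, rejected_lens,
                   beta=2.0, gamma=0.5, alpha=1.0):
    """
    chosen_logps / rejected_logps: summed token log-probs per sequence, shape (B,)
    chosen_lens / rejected_lens: number of response tokens per sequence, shape (B,)
    """
    # SimPO: length-normalized implicit rewards, no reference model
    chosen_rewards = beta * chosen_logps / chosen_lens
    rejected_rewards = beta * rejected_logps / rejected_lens
    simpo = -F.logsigmoid(chosen_rewards - rejected_rewards - gamma)
    # CPO's behavior-cloning regularizer: NLL of the chosen response
    nll = -chosen_logps / chosen_lens
    return (simpo + alpha * nll).mean()
```

Setting `alpha=0` recovers plain SimPO, which is why the combination fits naturally into an existing SimPO loss as an extra weighted term.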
@ccdv-ai I think this can be done via the existing set of hyperparams of setting
## Summary

Made #417 from the main repo. Thanks to the helpful suggestions from @Tcc0403 and @hongpeng-guo. This PR is the first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.

#### Code Changes

1. Refactor `beta` into two weights, `weight_hard_loss` and `weight_soft_loss`, as the coefficients of `hard_loss` and `soft_loss`. @Tcc0403 also pointed out that we could use `torch.lerp` where applicable.
2. Pass `teacher_logits` and `student_logits` directly to the divergence loss function. This avoids redundantly converting logits to log probabilities and then back to raw logits. Note, however, that we are not reusing the `student_log_probs` value calculated during `ce_loss` in the distillation base.
    1. Remove the unnecessary `get_batch_logps` in `test/utils.py`.
3. Change the `chunking` dimension from `B` to `B * T`, thanks to @hongpeng-guo's great advice.
    1. Fix the loss calculation to use per-token values instead of averaging across the sequence-length dimension.
4. Normalize the `distillation_loss` using `(full_target != ignore_index).sum()`.

#### TODO

1. [X] Although a slight slowdown is reasonable, we need to investigate why this PR's implementation is **significantly slower** than the naive approach. Thanks to @Tcc0403's clarification: the issue arises because we were not properly configuring `chunk_size` for the `B * T` dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks. In contrast, this problem does not occur with the preference losses, since chunking is performed on the `B` dimension; that produces fewer than 10 chunks, which is efficient and works as expected. In conclusion, setting `chunk_size` to `1024` works well, as shown in the new benchmark results in #425.
2. [ ] #417 (comment)

#### Knowledge Distillation

Knowledge Distillation (KD; [Hinton et al. 2015](https://arxiv.org/abs/1503.02531), [Gou et al. 2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to build a smaller, cheaper model ("student model") to speed up inference by transferring skills from a pre-trained expensive model ("teacher model") into the student.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let `z_t` and `z_s` represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature `T`. When ground-truth labels `y` are available, this approach can be combined with a supervised learning objective, such as cross-entropy, that compares the student's outputs with the ground truth.

The combined loss function is defined as:

```math
\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}),
```

Here, we pass in `logits` directly rather than `logprobs`. @Tcc0403

#### Shared `DistillationBase`

To support various distillation learning objectives, this PR adds a `LigerFusedLinearDistillationBase`, which is essentially the same as the one proposed by @hongpeng-guo in this discussion: #371 (comment). Thank you, @hongpeng-guo, for thinking this through.

## Testing Done

I'll post JSD tests and benchmark results in the next PR: #425

- Hardware Type: L40S
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
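The structure described above (chunking over `B * T`, weighted hard and soft losses, normalization by the number of non-ignored tokens) can be sketched without the fused kernels. This is a minimal illustration, not `LigerFusedLinearDistillationBase` itself: it assumes flattened inputs of `N = B * T` rows, a shared vocabulary between student and teacher, and uses forward KL(teacher ‖ student) at temperature `T` as a stand-in soft loss (the follow-up PR uses JSD). The real base class also computes per-chunk gradients and uses `torch.compile`, both omitted here.

```python
import torch
import torch.nn.functional as F


def chunked_distillation_loss(
    student_input, student_weight,   # (N, H_s), (V, H_s); N = B * T
    teacher_input, teacher_weight,   # (N, H_t), (V, H_t)
    target,                          # (N,) token ids, ignore_index for padding
    weight_hard_loss=0.5, weight_soft_loss=0.5,
    temperature=1.0, ignore_index=-100, chunk_size=1024,
):
    n_valid = (target != ignore_index).sum().clamp(min=1)
    hard_total = student_input.new_zeros(())
    soft_total = student_input.new_zeros(())
    for start in range(0, student_input.shape[0], chunk_size):
        end = start + chunk_size
        # Only a (chunk_size, V) slice of logits is materialized per step
        s_logits = student_input[start:end] @ student_weight.T
        with torch.no_grad():  # the teacher is frozen
            t_logits = teacher_input[start:end] @ teacher_weight.T
        tgt = target[start:end]
        mask = tgt != ignore_index
        hard_total = hard_total + F.cross_entropy(
            s_logits, tgt, ignore_index=ignore_index, reduction="sum")
        # Soft loss: per-token KL(teacher || student) at temperature T
        # (no T**2 rescaling in this sketch)
        s_logp = F.log_softmax(s_logits / temperature, dim=-1)
        t_logp = F.log_softmax(t_logits / temperature, dim=-1)
        kl = F.kl_div(s_logp, t_logp, log_target=True, reduction="none").sum(-1)
        soft_total = soft_total + (kl * mask).sum()
    # Sum-reduce per chunk, then normalize once by the non-ignored token count
    return (weight_hard_loss * hard_total + weight_soft_loss * soft_total) / n_valid
```

Summing inside each chunk and dividing once at the end keeps the result independent of `chunk_size`, which is what makes a large `chunk_size` such as 1024 a pure performance knob.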
There is an update on KTO in #410.
🚀 The feature, motivation and pitch
We want to support various alignment and distillation loss functions.
Refer to this PR on ORPO: #362
Progress
Alignment
Distillation
Design
Approach Overview:
The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:
By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
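For a sense of scale, here is a back-of-the-envelope estimate of the logits buffer that FLCE-style chunking avoids materializing all at once. The shapes are illustrative assumptions (8 × 4096 tokens, a 128,256-entry vocabulary, fp32 logits), not numbers measured in this RFC:

```python
def logits_bytes(rows: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Size in bytes of a dense (rows, vocab_size) logits tensor."""
    return rows * vocab_size * bytes_per_elem


tokens = 8 * 4096   # assumed batch size x sequence length
vocab = 128_256     # assumed vocabulary size
chunk = 1024        # rows of logits materialized per chunk

full = logits_bytes(tokens, vocab)
per_chunk = logits_bytes(chunk, vocab)
print(f"full logits: {full / 1e9:.1f} GB")       # full logits: 16.8 GB
print(f"one chunk:   {per_chunk / 1e9:.2f} GB")  # one chunk:   0.53 GB
print(f"chunks:      {tokens // chunk}")         # chunks:      32
```

Under these assumptions, the unchunked logits alone take about 16.8 GB, and during training the gradient with respect to the logits has the same shape. Chunking caps the live slice at roughly 0.5 GB, at the cost of a loop over 32 chunks.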
Key Findings
By leveraging torch.compile alongside optimization techniques such as chunking and online softmax, we observed performance close to that of custom Triton kernels while reducing development time. This is why we want to introduce torch.compile as a key component of Liger.
References:
Interface
Have a base class `FlexChunkLoss` that handles chunking, accumulation, and compiling strategies. A custom loss class wraps `FlexChunkLoss` and implements the loss fn that operates on a given chunk.
Alternatives
No response
Additional context
No response