Add TVD (Total variation distance) Kernel #281

Closed
qingquansong opened this issue Sep 28, 2024 · 3 comments · Fixed by #324

@qingquansong
Collaborator

qingquansong commented Sep 28, 2024

🚀 The feature, motivation and pitch

TVD is a good distance metric (ref), and a kernel for it is easy to implement. Its gradient is more stable than those of KL divergence and JS divergence.
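For reference, the total variation distance between two discrete distributions P and Q is commonly defined as half the L1 distance, TVD(P, Q) = 0.5 * Σ_i |p_i - q_i|. The snippet below is a minimal plain-Python sketch of that definition; the function name `total_variation_distance` is made up for illustration and is not part of this project's API.

```python
def total_variation_distance(p, q):
    """Return 0.5 * sum_i |p_i - q_i| for two discrete distributions.

    Illustrative helper only (hypothetical name); p and q are sequences of
    probabilities over the same support.
    """
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]
print(total_variation_distance(p, q))  # roughly 0.2, up to float rounding
```

Because each term is an absolute difference, the per-element gradient with respect to p_i is at most 0.5 in magnitude, whereas the KL gradient contains ratio or log terms that grow without bound as a probability approaches zero. That is one way to read the stability claim above.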

Alternatives

No response

Additional context

No response

@S1ro1
Contributor

S1ro1 commented Sep 30, 2024

I'll look into it over the week if no one else takes it.

@saurabhkoshatwar
Contributor

saurabhkoshatwar commented Oct 22, 2024

#take @ByronHsu @qingquansong, I’d like to make an attempt. Could you please assign it to me?

@ByronHsu
Collaborator

assigned to you. Thanks!

shivam15s pushed a commit that referenced this issue Feb 21, 2025
## Summary
Resolves [#281](#281).
Implements the TVD (Total Variation Distance) kernel by computing both
the loss and gradient in the forward pass.
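
The idea of producing the gradient during the forward pass can be sketched in plain Python as follows. This is not the kernel from the commit: the names `tvd_forward` and `tvd_backward` are hypothetical, there is no batching or reduction option, and the choice of 0 as the subgradient where `p_i == q_i` is an assumption made for this example.

```python
def tvd_forward(p, q):
    """Return (loss, grad), where grad[i] = d loss / d p[i].

    Hypothetical sketch: the loss and its gradient are produced in a single
    pass over the inputs, so the backward step has nothing left to recompute.
    """
    loss = 0.0
    grad = []
    for pi, qi in zip(p, q):
        diff = pi - qi
        loss += 0.5 * abs(diff)
        # Subgradient of 0.5 * |diff| w.r.t. pi; 0 is chosen at diff == 0.
        grad.append(0.5 if diff > 0 else -0.5 if diff < 0 else 0.0)
    return loss, grad


def tvd_backward(grad, grad_output=1.0):
    """Scale the gradient cached by the forward pass by the upstream gradient."""
    return [g * grad_output for g in grad]


loss, grad = tvd_forward([0.7, 0.2, 0.1], [0.5, 0.3, 0.2])
print(loss, tvd_backward(grad))
```

With this layout the backward step reduces to an elementwise scale of values that already exist, which is the usual motivation for fusing the two computations.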

## Testing Done
Implemented tests to verify that the results of the forward and backward
passes match the Torch implementation. Additionally, added a script to
benchmark the memory usage and speed of the Liger implementation
compared to Torch, with the results shown below.


![tvd_speed](https://github.com/user-attachments/assets/05080030-81ae-4751-aba2-f001a5144072)


![tvd_memory](https://github.com/user-attachments/assets/d7dbb9d9-5fd5-4ffe-aa12-8ba785e09857)

- Hardware Type: Nvidia H100 (80GB PCIe)
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
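
The kind of forward/backward comparison described under Testing Done can be illustrated without Torch by checking an analytic TVD gradient against a central finite difference. This is only a sketch of the checking idea, not the project's test suite; all function names here are hypothetical, and the tolerance and step size are arbitrary choices.

```python
def tvd(p, q):
    """Total variation distance: 0.5 * sum_i |p_i - q_i|."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def analytic_grad(p, q):
    """Closed-form (sub)gradient of tvd with respect to each p_i."""
    return [0.5 if a > b else -0.5 if a < b else 0.0 for a, b in zip(p, q)]


def numeric_grad(p, q, eps=1e-6):
    """Central finite-difference estimate of the gradient w.r.t. each p_i."""
    grads = []
    for i in range(len(p)):
        hi = list(p)
        lo = list(p)
        hi[i] += eps
        lo[i] -= eps
        grads.append((tvd(hi, q) - tvd(lo, q)) / (2 * eps))
    return grads


def grads_match(p, q, tol=1e-6):
    """True if the analytic and numeric gradients agree within tol."""
    pairs = zip(analytic_grad(p, q), numeric_grad(p, q))
    return all(abs(a - n) <= tol for a, n in pairs)


print(grads_match([0.7, 0.2, 0.1], [0.5, 0.3, 0.2]))
```

The check is only meaningful away from points where `p_i` and `q_i` are within `eps` of each other, since the absolute value is not differentiable there.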

---------

Co-authored-by: Shao Tang <[email protected]>