Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Faster logit_bias_logits_processor #13334

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

xu-song
Copy link
Contributor

@xu-song xu-song commented Feb 15, 2025

This PR changes python ops to tensor ops, which reduce time cost from 106ms to 0.01ms.

Before

    for token_id, bias in logit_bias.items():
        logits[token_id] += bias

The above approach is time consuming especially when len(logit_bias) is very large.

After

    logits.index_add_(0, logit_bias["index"], logit_bias["value"]) 

Time Cost

before -> v1 -> v2

len(logit_bias) time cost (ms)
1 4.5 -> 0.3 -> 0.01
20 4.5 -> 0.3 -> 0.01
100 5.3 -> 0.3 -> 0.01
1000 14.4 -> 0.3 -> 0.01
10000 106 -> 0.4 -> 0.01

experiment settings:

GPU: A100 
model: Llama-3.2-1B-Instruct

impl history

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Xu Song <[email protected]>
Signed-off-by: Xu Song <[email protected]>
@xu-song xu-song changed the title Optimize logit_bias_logits_processor [Core] Optimize logit_bias_logits_processor Feb 15, 2025
@imkero
Copy link
Contributor

imkero commented Feb 15, 2025

If len(logit_bias) is large, maybe we can keep the copy of logit_bias["index"] and logit_bias["value"] in the device memory ahead of time (or in the first sample step), and re-use it in the following sample steps, to avoid duplicated tensor copy?

@xu-song
Copy link
Contributor Author

xu-song commented Feb 17, 2025

If len(logit_bias) is large, maybe we can keep the copy of logit_bias["index"] and logit_bias["value"] in the device memory ahead of time (or in the first sample step), and re-use it in the following sample steps, to avoid duplicated tensor copy?

@imkero Thanks for your suggestion, a new commit has been added, which avoid duplicated tensor copy.

After this change, the time_cost is reduced to 0.01ms

len(logit_bias) time cost (ms)
1 4.5 -> 0.01
20 4.5 -> 0.01
100 5.3 -> 0.01
1000 14.4 -> 0.01
10000 106 -> 0.01

@xu-song xu-song changed the title [Core] Optimize logit_bias_logits_processor [Core] Faster logit_bias_logits_processor Feb 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants