speed up PrefixConstrainedLogitsProcessor #35275
Conversation
@gante @zucchini-nlp Hi, is there anything else you need me to do before this can be reviewed?
Hey @vukasin! Thanks for your PR.
If we want to create the mask on CPU, we can do that with torch.full. Any reason for choosing numpy?
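As a minimal sketch (with made-up shapes and token ids, not the actual processor code), the suggestion amounts to building the -inf mask on CPU with torch.full and moving it to the scores' device in a single transfer:

import math
import torch

# Illustrative stand-ins for the real scores tensor and allowed token ids.
scores = torch.randn(8, 50256)   # (batch_size * num_beams, vocab_size)
allowed_tokens = [11, 42, 7]     # hypothetical allowed ids for row 0

# Build the -inf mask on CPU with torch.full instead of numpy ...
mask = torch.full(scores.shape, -math.inf, device="cpu")
mask[0, allowed_tokens] = 0      # zero out the allowed positions
# ... then move it to the scores' device with a single copy.
scores = scores + mask.to(scores.device)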
No particular reason, I can change it to use torch instead.
Updated the code so it sticks to torch functions and only manipulates the device.
@zucchini-nlp good call, this just became a pretty trivial change :-) but it still gives the same speedup (in my tests at least).
@zucchini-nlp anything else I need to do to get this moving forward?
Sorry, missed this one! Thanks for iterating, LGTM! I will request one more review from @gante, the generation code owner, before merging.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@gante anything else I need to do?
Hi @vukasin, thank you for your PR! For future reference, please share a stand-alone benchmarking script so that everyone can immediately assess the impact of the change :)
In this particular case, the proposed change actually results in a slowdown if everything is done as efficiently as possible. The best speed is achievable when the mask is created and filled directly on the scores' device and prefix_allowed_tokens_fn returns a tensor on that device (the gpu_mask variant below). I would accept two changes:
For full reference, here's a benchmarking script:

import torch
import math
import time
from tqdm import tqdm

NUM_BEAMS = 4
BATCH_SIZE = 2
VOCAB_SIZE = 50256


def prefix_allowed_tokens_fn_cpu(batch_id, input_ids):
    # Returns the allowed tokens as a Python list (forces a device-to-host copy).
    return input_ids.tolist()


def prefix_allowed_tokens_fn_gpu(batch_id, input_ids):
    # Returns the allowed tokens as a tensor, leaving them on the current device.
    return input_ids


def cpu_mask(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    # Build the mask on CPU, then move it to the scores' device in one transfer.
    mask = torch.full_like(scores, -math.inf, device=torch.cpu.current_device())
    for batch_id, beam_sent in enumerate(input_ids.view(-1, NUM_BEAMS, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            prefix_allowed_tokens = prefix_allowed_tokens_fn_cpu(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}. "
                    f"This means that the constraint is unsatisfiable. Please check your implementation "
                    f"of `prefix_allowed_tokens_fn`."
                )
            mask[batch_id * NUM_BEAMS + beam_id, prefix_allowed_tokens] = 0
    scores_processed = scores + mask.to(scores.device)
    return scores_processed


def gpu_mask(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    # Build and fill the mask directly on the same device as the scores.
    mask = torch.full_like(scores, -math.inf)
    for batch_id, beam_sent in enumerate(input_ids.view(-1, NUM_BEAMS, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            prefix_allowed_tokens = prefix_allowed_tokens_fn_gpu(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}. "
                    f"This means that the constraint is unsatisfiable. Please check your implementation "
                    f"of `prefix_allowed_tokens_fn`."
                )
            mask[batch_id * NUM_BEAMS + beam_id, prefix_allowed_tokens] = 0
    scores_processed = scores + mask
    return scores_processed


cpu_times = []
gpu_times = []
for _ in tqdm(range(1000)):
    input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE * NUM_BEAMS, 5)).to("cuda")
    scores = torch.randn(BATCH_SIZE * NUM_BEAMS, VOCAB_SIZE).to("cuda")

    start = time.time()
    cpu_mask(input_ids, scores)
    end = time.time()
    cpu_times.append(end - start)

    start = time.time()
    gpu_mask(input_ids, scores)
    end = time.time()
    gpu_times.append(end - start)

print("CPU time:", sum(cpu_times) / len(cpu_times))
print("GPU time:", sum(gpu_times) / len(gpu_times))
Speed up PrefixConstrainedLogitsProcessor
The __call__ method creates an empty tensor and then updates it repeatedly. This causes a significant slowdown when executing on a GPU, since it requires updating GPU memory over and over. This pull request creates the mask in CPU memory and then copies it to the GPU in a single call. The tests I ran (the test case from https://github.com/worldbank/REaLTabFormer/tree/main/colab) show a speed increase of over 200%, and up to 300% depending on the batch size.
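A rough sketch of the approach described above, assuming the processor's usual inputs (input_ids, scores, a per-beam prefix_allowed_tokens_fn, and a beam count); this is illustrative and may not match the exact diff:

import math
import torch

def prefix_constrained_scores(input_ids, scores, prefix_allowed_tokens_fn, num_beams):
    # Sketch of the idea in this PR (not the exact diff): build the -inf mask on
    # CPU, fill the allowed token columns there, then move it to the scores'
    # device with a single transfer instead of many small writes to GPU memory.
    mask = torch.full(scores.shape, -math.inf, device="cpu")
    for batch_id, beam_sent in enumerate(input_ids.view(-1, num_beams, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            # Assumed: the callback returns a non-empty list of allowed token ids.
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            mask[batch_id * num_beams + beam_id, allowed] = 0
    return scores + mask.to(scores.device)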
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.