speed up PrefixConstrainedLogitsProcessor #35275

Open · wants to merge 9 commits into base: main
Conversation


@vukasin vukasin commented Dec 14, 2024

Speed up PrefixConstrainedLogitsProcessor

The __call__ method creates an empty tensor and then updates it repeatedly. This causes a significant slowdown when executing on a GPU, since GPU memory is updated over and over. This pull request builds the mask in CPU memory and then creates the tensor on the GPU with a single call. In the tests I ran (the test case from https://github.com/worldbank/REaLTabFormer/tree/main/colab), this gave a speed increase of over 200%, and up to 300% depending on the batch size.
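
For illustration, the change boils down to something like the sketch below (not the exact diff in this PR; the function name and arguments are placeholders for the processor's internals, and prefix_allowed_tokens_fn is assumed to return Python lists of token ids):

import math
import torch

def build_constrained_scores(input_ids, scores, prefix_allowed_tokens_fn, num_beams):
    # Build the mask in CPU memory so the per-row updates never touch the GPU ...
    mask = torch.full(scores.shape, -math.inf, dtype=scores.dtype, device="cpu")
    for batch_id, beam_sent in enumerate(input_ids.view(-1, num_beams, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            mask[batch_id * num_beams + beam_id, allowed] = 0
    # ... then move the finished mask to the scores' device in a single transfer.
    return scores + mask.to(scores.device)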

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1
Member

cc @gante @zucchini-nlp

@vukasin
Author

vukasin commented Dec 18, 2024

@gante @zucchini-nlp Hi, is there anything else you need me to do before this can be reviewed?

Member

@zucchini-nlp zucchini-nlp left a comment


Hey @vukasin ! Thanks for your PR.

If we want to create the mask on CPU, we can do that with torch.full. Any reason for choosing numpy?

@vukasin
Author

vukasin commented Dec 19, 2024

no particular reason, I can change it to use torch instead

@vukasin
Author

vukasin commented Dec 19, 2024

Updated the code so it sticks to torch functions and only manipulates the device

@vukasin
Author

vukasin commented Dec 19, 2024

@zucchini-nlp good call, this just became a pretty trivial change :-) but still has the same speedup (in my tests at least)

@vukasin vukasin requested a review from zucchini-nlp December 20, 2024 14:40
@vukasin
Author

vukasin commented Jan 5, 2025

@zucchini-nlp anything else I need to do to get this moving forward?

Member

@zucchini-nlp zucchini-nlp left a comment


Sorry, missed this one! Thanks for iterating, LGTM! I will request one more review from @gante, the generation code owner, before merging

@zucchini-nlp zucchini-nlp requested a review from gante January 9, 2025 12:08
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vukasin
Author

vukasin commented Jan 29, 2025

@gante anything else I need to do?

@gante
Member

gante commented Jan 30, 2025

Hi @vukasin

Thank you for your PR! For future reference, please share a stand-alone benchmarking script so that everyone can immediately assess the impact of the change :)

In this particular case, the proposed change actually results in a slowdown if everything is done as efficiently as possible. The best speed is achievable when prefix_allowed_tokens_fn consists of tensor manipulations, where the tensors never leave the GPU -- the slowdown you observe is because we are slicing mask, a tensor on the GPU, with a tensor located on the CPU. If prefix_allowed_tokens_fn relies on CPU operations, then your change is slightly faster (as measured on my machine).

I would accept two changes:
1 - document this important property that prefix_allowed_tokens_fn should have for optimal performance 🤗
2 - vectorize the mask update: collect all prefix_allowed_tokens and then update mask once. Because the update is done only once, it should be faster when prefix_allowed_tokens_fn relies on CPU ops, at no performance cost in the GPU case (see the sketch below)
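
For illustration, a vectorized update along the lines of point 2 could look roughly like this (a sketch, not the merged implementation; the helper name is made up, and prefix_allowed_tokens_fn is assumed to return Python lists of token ids):

import math
import torch

def build_constrained_scores_vectorized(input_ids, scores, prefix_allowed_tokens_fn, num_beams):
    rows, cols = [], []
    for batch_id, beam_sent in enumerate(input_ids.view(-1, num_beams, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            if len(allowed) == 0:
                raise ValueError(f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}.")
            # Collect (row, token) coordinates instead of writing into the mask beam by beam.
            rows.extend([batch_id * num_beams + beam_id] * len(allowed))
            cols.extend(allowed)
    mask = torch.full_like(scores, -math.inf)
    # One advanced-indexing write covers every allowed (row, token) pair at once.
    mask[torch.tensor(rows, device=scores.device), torch.tensor(cols, device=scores.device)] = 0
    return scores + mask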


For full reference, here's a benchmarking script

import torch
import math
import time
from tqdm import tqdm

NUM_BEAMS = 4
BATCH_SIZE = 2
VOCAB_SIZE = 50256

def prefix_allowed_tokens_fn_cpu(batch_id, input_ids):
    return input_ids.tolist()

def prefix_allowed_tokens_fn_gpu(batch_id, input_ids):
    return input_ids


def cpu_mask(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    mask = torch.full_like(scores, -math.inf, device=torch.cpu.current_device())
    for batch_id, beam_sent in enumerate(input_ids.view(-1, NUM_BEAMS, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            prefix_allowed_tokens = prefix_allowed_tokens_fn_cpu(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                    f"This means that the constraint is unsatisfiable. Please check your implementation"
                    f"of `prefix_allowed_tokens_fn` "
                )
            mask[batch_id * NUM_BEAMS + beam_id, prefix_allowed_tokens] = 0

    scores_processed = scores + mask.to(scores.device)
    return scores_processed


def gpu_mask(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    mask = torch.full_like(scores, -math.inf)
    for batch_id, beam_sent in enumerate(input_ids.view(-1, NUM_BEAMS, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            prefix_allowed_tokens = prefix_allowed_tokens_fn_gpu(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                    f"This means that the constraint is unsatisfiable. Please check your implementation"
                    f"of `prefix_allowed_tokens_fn` "
                )
            mask[batch_id * NUM_BEAMS + beam_id, prefix_allowed_tokens] = 0

    scores_processed = scores + mask
    return scores_processed


cpu_times = []
gpu_times = []
for _ in tqdm(range(1000)):
    input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE * NUM_BEAMS, 5)).to("cuda")
    scores = torch.randn(BATCH_SIZE * NUM_BEAMS, VOCAB_SIZE).to("cuda")

    start = time.time()
    cpu_mask(input_ids, scores)
    end = time.time()
    cpu_times.append(end - start)

    start = time.time()
    gpu_mask(input_ids, scores)
    end = time.time()
    gpu_times.append(end - start)

print("CPU time:", sum(cpu_times) / len(cpu_times))
print("GPU time:", sum(gpu_times) / len(gpu_times))
