
Commit

[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (vllm-project#7218)

Signed-off-by: Travis Johnson <[email protected]>
tjohnson31415 authored Aug 9, 2024
1 parent e02ac55 commit 99b4cf5
Showing 4 changed files with 66 additions and 5 deletions.
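
For context, vLLM pads the vocabulary dimension of the embedding and LM head up to a multiple of a padding value (64 by default) so the weight shards divide evenly under tensor parallelism; the logits therefore carry extra positions that must be trimmed back to the original vocab size. A minimal sketch of that rounding, assuming it mirrors pad_vocab_size from vocab_parallel_embedding.py:

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round vocab_size up to the nearest multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

# The test model's vocab of 32000 is already a multiple of 64, so the default
# adds no padding; the new test below patches pad_to to 32064 to force 64
# padded logit positions that must be sliced away.
assert pad_vocab_size(32000) == 32000
assert pad_vocab_size(32000, pad_to=32064) == 32064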
60 changes: 60 additions & 0 deletions tests/spec_decode/e2e/test_mlp_correctness.py
@@ -19,8 +19,12 @@
correctness for the target model outputs.
"""

from unittest.mock import patch

import pytest

from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size

from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)

@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality when the vocab dimension is padded
"""

    # Default pad_to is 64; the test model's vocab_size of 32000 is already
    # a multiple of 64, so force padding to 32064 to exercise the padded path.
def patched_pad_vocab_size(vocab_size, pad_to=None):
return pad_vocab_size(vocab_size, pad_to=32064)

with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
patched_pad_vocab_size):
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/logits_processor.py
@@ -91,7 +91,7 @@ def _get_logits(self, hidden_states: torch.Tensor,
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
+            logits = logits[..., :self.org_vocab_size]
return logits

def extra_repr(self) -> str:
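
The one-character change above matters because, with MLPSpeculator in the loop, the hidden states reaching _get_logits can carry an extra speculative-step dimension, so the padded logits are not necessarily 2-D; "[:, :n]" trims the second dimension regardless of what it holds, while "[..., :n]" always trims the last (vocab) dimension. A minimal sketch of the difference, assuming a padded vocab of 32064 trimmed back to 32000:

import torch

org_vocab_size = 32000
logits_2d = torch.randn(4, 32064)      # (batch, padded_vocab)
logits_3d = torch.randn(4, 3, 32064)   # (batch, k, padded_vocab)

# ":" slices the second dimension, which is the vocab only for 2-D logits.
assert logits_2d[:, :org_vocab_size].shape == (4, 32000)
assert logits_3d[:, :org_vocab_size].shape == (4, 3, 32064)   # padding kept!

# "..." always targets the last dimension, so both shapes are handled.
assert logits_2d[..., :org_vocab_size].shape == (4, 32000)
assert logits_3d[..., :org_vocab_size].shape == (4, 3, 32000)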
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -78,8 +78,8 @@ def forward(
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, draft_token_ids,
+                                           bonus_token_ids, draft_probs)

accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
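
The change above only reorders the positional arguments so they match the parameter order the checker expects; the token-id and probability tensors were previously passed out of order. A hedged sketch, with hypothetical parameter names inferred from the call site, of a keyword-only variant that would make such a swap impossible:

import torch

def _raise_if_incorrect_input(*, target_probs: torch.Tensor,
                              draft_token_ids: torch.Tensor,
                              bonus_token_ids: torch.Tensor,
                              draft_probs: torch.Tensor) -> None:
    # Illustrative checks only; the real method also validates shapes,
    # dtypes, and devices.
    if not target_probs.is_floating_point():
        raise ValueError("target_probs must be a floating-point tensor")
    if not draft_probs.is_floating_point():
        raise ValueError("draft_probs must be a floating-point tensor")
    if draft_token_ids.dtype not in (torch.int32, torch.int64):
        raise ValueError("draft_token_ids must hold integer token ids")
    if bonus_token_ids.dtype not in (torch.int32, torch.int64):
        raise ValueError("bonus_token_ids must hold integer token ids")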
5 changes: 3 additions & 2 deletions vllm/model_executor/models/mlp_speculator.py
@@ -175,13 +175,14 @@ def generate_proposals(
states.add_(z, alpha=self.emb_weight / self.state_weight)

states = self.activation(self.ln[head_index](states)) # b k d
-            # TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states
+            # TODO: not yet supporting top_k_tokens_per_head
+            states = states.flatten(0, 1)

logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)

-            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
+            output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)

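
The reordering above flattens the (batch, k) dimensions of states before the LM head and logits processor run, so the padded-vocab slice in _get_logits sees the vocab as the last dimension of a 2-D tensor, and the sampler no longer needs to flatten the logits itself. A minimal shape sketch with illustrative sizes:

import torch

b, k, d, padded_vocab, org_vocab = 4, 3, 8, 32064, 32000

states = torch.randn(b, k, d)
states = states.flatten(0, 1)                        # (b * k, d) fed to the head
logits = torch.randn(states.shape[0], padded_vocab)  # what the padded head emits
logits = logits[..., :org_vocab]                     # trimmed in _get_logits
assert logits.shape == (b * k, org_vocab)            # 2-D, ready for the sampler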
