Fix lookahead sample error & add update strategy (#10894)
* Fix sample error & add update strategy

* add mtl config

* fix style

* remove print
cyita authored Apr 28, 2024
1 parent 94b4e96 commit 015d07a
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions python/llm/src/ipex_llm/transformers/lookup.py
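Before the diff itself, a short note on what the commit adds: PromptLookupCandidateGenerator now grows or shrinks its draft length based on how many drafted tokens the target model actually accepted, with a device-dependent ceiling (3 on "mtl" devices, 9 otherwise). A minimal standalone sketch of that logic, with names mirroring the diff below but simplified for illustration (not the full ipex_llm implementation):

# Minimal sketch of the adaptive candidate-length strategy (illustration only).
class AdaptiveCandidateLength:
    def __init__(self, num_output_tokens: int = 10, device: str = "arc"):
        self.num_output_tokens = num_output_tokens
        # "mtl" devices get a smaller ceiling than other XPUs (see the __init__ hunk below).
        self.max_candidates = 3 if device == "mtl" else 9

    def update(self, candidate_num: int, num_matches: int):
        # Every drafted token was accepted: try drafting one more next step.
        if num_matches == self.num_output_tokens:
            self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
        # Some drafted tokens were rejected: back off by one, never below 1.
        elif candidate_num > num_matches:
            self.num_output_tokens = max(self.num_output_tokens - 1, 1)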
@@ -29,6 +29,7 @@
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.utils import get_xpu_device_type

logger = logging.getLogger("ipex_llm.lookup")

@@ -119,10 +120,16 @@ def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
device: str = "arc",
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2

if device == "mtl":
self.max_candidates = 3
else:
self.max_candidates = 9

invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
"Invalid max_matching_ngram_size or num_output_tokens")

@@ -183,25 +190,18 @@ def get_candidates(self,
# so returning None
return candidate_input_ids, None

def update_candidate_strategy(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor, num_matches: int):
def update_candidate_strategy(self, candidate_num: int, num_matches: int):
"""
Updates the candidate generation strategy based on the outcomes.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length,
config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each
vocabulary when not using beam search or log softmax for each vocabulary
token when using beam search
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
# Currently does nothing
return
if num_matches == self.num_output_tokens:
self.num_output_tokens = min(self.num_output_tokens + 1, self.max_candidates)
elif candidate_num > num_matches:
self.num_output_tokens = max(self.num_output_tokens - 1, 1)


@torch.no_grad()
@@ -217,9 +217,12 @@ def lookup_generate(self,
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
**sampling_kwargs)

device_name = get_xpu_device_type(input_ids)

candidates_generator = PromptLookupCandidateGenerator(
num_output_tokens=num_output_tokens,
max_matching_ngram_size=max_matching_ngram_size)
max_matching_ngram_size=max_matching_ngram_size,
device=device_name)

step = 0
step_verify = 0
@@ -291,6 +294,7 @@
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
output_ids = output_ids.transpose(0, 1)
else:
output_ids = greedy(logits)

@@ -303,13 +307,14 @@
# Drafts start from [1, k]
# Verified output start from [0, k - 1]
# including the one generated by the base model
max_matched = ((output_ids[:, :-1] != verify_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
.cumsum(-1) == 0).sum(-1).item()
max_matched = n_matches + 1

max_of_max_matched = output_ids.size(1)
# Accept number is max_matched, min is 1
self.accept_num.append(max_matched)
self.n_matched += max_matched - 1
self.n_matched += n_matches
self.n_drafted += candidate_length

# Clean up target model KV cache
@@ -319,6 +324,9 @@
past_key_values = _crop_past_key_values(self, past_key_values,
new_cache_size)

# Update the candidate generation strategy if needed
candidates_generator.update_candidate_strategy(candidate_length, n_matches)

input_ids = torch.cat((input_ids, output_ids), dim=-1)

step += output_ids.size(1)
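
One detail of the verification bookkeeping above is worth spelling out: n_matches is the length of the leading run of drafted tokens that the target model reproduced, computed with a cumulative-sum trick, and max_matched adds one for the token the base model generates itself. A small self-contained example with made-up tensors (shapes assumed to follow lookup_generate):

import torch

# Hypothetical batch of 1 with 6 verified positions.
verify_input_ids = torch.tensor([[7, 11, 12, 13, 14, 15]])  # last prompt token + drafted tokens
output_ids = torch.tensor([[11, 12, 13, 99, 42, 8]])         # tokens the target model produced

# Drafts start at index 1, verified outputs at index 0, so compare shifted slices.
# The cumulative sum of mismatches stays 0 only over the leading run of matches.
n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
             .cumsum(-1) == 0).sum(-1).item()
max_matched = n_matches + 1  # plus the token generated by the base model itself

print(n_matches, max_matched)  # 3 4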
