diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 3c7cc58ae..0a53ffeca 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -818,4 +818,88 @@ def sample(
     def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
         if apply_grammar and self.grammar is not None:
            ctx_main.grammar_accept_token(self.grammar, id)
-        self.prev.append(id)
\ No newline at end of file
+        self.prev.append(id)
+
+class _TokenTextQueue:
+    """Queue of generated tokens that only releases text once it decodes to
+    complete UTF-8 characters and cannot be part of an in-progress stop
+    sequence."""
+
+    def __init__(self, detokenize, stop_sequences: Optional[List[bytes]] = None):
+        # settings
+        self.detokenize = detokenize
+        self.stop_sequences = stop_sequences or []
+
+        # current state
+        self.tokens: List[int] = []
+
+    def __len__(self):
+        return len(self.tokens)
+
+    @staticmethod
+    def decode_robust(bstr: bytes) -> Optional[str]:
+        # return None instead of raising on a partial multi-byte sequence
+        try:
+            return bstr.decode("utf-8")
+        except UnicodeError:
+            return None
+
+    def detect_stop_token(self) -> Optional[bytes]:
+        # return the queued text up to the earliest stop sequence, or None
+        text = self.detokenize(self.tokens)
+        stop_idxs = [text.index(s) for s in self.stop_sequences if s in text]
+        if len(stop_idxs) > 0:
+            return text[: min(stop_idxs)]
+
+    def first_stop_position(self) -> int:
+        # detect first index at which a partial stop sequence could begin
+        text = self.detokenize(self.tokens)
+        length = len(text)
+        first_stop_len = 0
+        for s in self.stop_sequences:
+            for i in range(min(len(s), length), 0, -1):
+                if text.endswith(s[:i]):
+                    if i > first_stop_len:
+                        first_stop_len = i
+                    break
+        return length - first_stop_len
+
+    def push_token(self, token: int):
+        self.tokens.append(token)
+
+    def pop_text(self):
+        # returns a (token count, bytes, text) triple, or None if no complete
+        # UTF-8 character can be popped yet
+        if len(self) == 0:
+            return
+
+        # attempt decode on substrings
+        for i in range(1, len(self.tokens) + 1):
+            bstr = self.detokenize(self.tokens[:i])
+            text = self.decode_robust(bstr)
+            if text is not None:
+                break
+
+        # all remaining tokens cannot be decoded to a UTF-8 character
+        if text is None:
+            return
+
+        # avoid yield if possible stop sequence in progress
+        if len(bstr) > self.first_stop_position():
+            return
+
+        # trim token list
+        self.tokens = self.tokens[i:]
+
+        return i, bstr, text
+
+    def empty_text(self) -> str:
+        # flush the remaining queued text, truncated at a partial stop sequence
+        text = ""
+        position = 0
+        end_position = self.first_stop_position()
+
+        for token in self.tokens:
+            last_text = self.detokenize([token])
+            position += len(last_text)
+
+            if position >= end_position:
+                text += last_text[
+                    : len(last_text) - (position - end_position)
+                ].decode("utf-8", errors="ignore")
+                break
+
+            text += last_text.decode("utf-8", errors="ignore")
+
+        return text
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index bfda45ef8..02d6a8ce4 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -48,6 +48,7 @@
     _LlamaTokenDataArray,  # type: ignore
     _LlamaSamplingParams,  # type: ignore
     _LlamaSamplingContext,  # type: ignore
+    _TokenTextQueue,  # type: ignore
 )
 from ._logger import set_verbose
 from ._utils import (
@@ -523,7 +524,7 @@ def eval(self, tokens: Sequence[int]):
         # Update n_tokens
         self.n_tokens += n_tokens
 
-    def eval_parallel(self, tokens: List[int]):
+    def _eval_parallel(self, tokens: List[int]):
         """Evaluate a list of tokens in different sequences but at the same position.
 
         Args:
@@ -745,7 +746,7 @@ def longest_common_prefix(vecs):
             else:
                 return vecs[0][:max_len]
 
-    def generate_parallel(
+    def _generate_parallel(
         self,
         tokens: List[List[int]],
         max_tokens: Optional[int] = None,
@@ -964,6 +965,7 @@ def decode_batch(n_seq: int):
         else:
             return output
 
+    # TODO: reintegrate logprobs
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1090,8 +1092,13 @@ def logit_bias_processor(
         if seed is not None:
             self._ctx.set_rng_seed(seed)
 
+        # create text token queue and state
+        queue = _TokenTextQueue(self.detokenize, stop_sequences=stop_sequences)
+        all_toks = 0
+        all_text = b""
         finish_reason = "length"
         multibyte_fix = 0
+
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
@@ -1111,14 +1118,9 @@
             grammar=grammar,
         ):
             if token == self._token_eos:
-                text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
 
-            completion_tokens.append(token)
-
-            all_text = self.detokenize(completion_tokens)
-
             # Contains multi-byte UTF8
             for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
@@ -1132,345 +1134,48 @@
                multibyte_fix -= 1
                continue
 
-            any_stop = [s for s in stop_sequences if s in all_text]
-            if len(any_stop) > 0:
-                first_stop = any_stop[0]
-                text = all_text[: all_text.index(first_stop)]
+            # Add token to queue
+            queue.push_token(token)
+
+            # TODO: will this find the actual first stop?
+            stop_text = queue.detect_stop_token()
+            if stop_text is not None:
+                all_text += stop_text
                 finish_reason = "stop"
                 break
 
-            if stream:
-                remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(remaining_tokens)
-                remaining_length = len(remaining_text)
-
-                # We want to avoid yielding any characters from
-                # the generated text if they are part of a stop
-                # sequence.
-                first_stop_position = 0
-                for s in stop_sequences:
-                    for i in range(min(len(s), remaining_length), 0, -1):
-                        if remaining_text.endswith(s[:i]):
-                            if i > first_stop_position:
-                                first_stop_position = i
-                            break
-
-                token_end_position = 0
-
-                if logprobs is not None:
-                    # not sure how to handle this branch when dealing
-                    # with CJK output, so keep it unchanged
-                    for token in remaining_tokens:
-                        if token == self.token_bos():
-                            continue
-                        token_end_position += len(self.detokenize([token]))
-                        # Check if stop sequence is in the token
-                        if token_end_position > (
-                            remaining_length - first_stop_position
-                        ):
-                            break
-                        token_str = self.detokenize([token]).decode(
-                            "utf-8", errors="ignore"
-                        )
-                        text_offset = len(prompt) + len(
-                            self.detokenize(completion_tokens[:returned_tokens]).decode(
-                                "utf-8", errors="ignore"
-                            )
-                        )
-                        token_offset = len(prompt_tokens) + returned_tokens
-                        logits = self._scores[token_offset - 1, :]
-                        current_logprobs = Llama.logits_to_logprobs(logits).tolist()
-                        sorted_logprobs = list(
-                            sorted(
-                                zip(current_logprobs, range(len(current_logprobs))),
-                                reverse=True,
-                            )
-                        )
-                        top_logprob = {
-                            self.detokenize([i]).decode(
-                                "utf-8", errors="ignore"
-                            ): logprob
-                            for logprob, i in sorted_logprobs[:logprobs]
-                        }
-                        top_logprob.update({token_str: current_logprobs[int(token)]})
-                        logprobs_or_none = {
-                            "tokens": [
-                                self.detokenize([token]).decode(
-                                    "utf-8", errors="ignore"
-                                )
-                            ],
-                            "text_offset": [text_offset],
-                            "token_logprobs": [current_logprobs[int(token)]],
-                            "top_logprobs": [top_logprob],
-                        }
-                        returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize([token]).decode(
-                                        "utf-8", errors="ignore"
-                                    ),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
-                else:
-                    while len(remaining_tokens) > 0:
-                        decode_success = False
-                        for i in range(1, len(remaining_tokens) + 1):
-                            try:
-                                bs = self.detokenize(remaining_tokens[:i])
-                                ts = bs.decode("utf-8")
-                                decode_success = True
-                                break
-                            except UnicodeError:
-                                pass
-                        else:
-                            break
-                        if not decode_success:
-                            # all remaining tokens cannot be decoded to a UTF-8 character
-                            break
-                        token_end_position += len(bs)
-                        if token_end_position > (
-                            remaining_length - first_stop_position
-                        ):
-                            break
-                        remaining_tokens = remaining_tokens[i:]
-                        returned_tokens += i
-
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
-
-            if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens)
+            # Generate text chunks
+            while len(queue) > 0:
+                ret = queue.pop_text()
+                if ret is None:
+                    break
+                nt, bs, ts = ret  # (token count, bytes, text)
+                all_toks += nt
+                all_text += bs
+                yield ts
+
+            if all_toks >= max_tokens:
                 finish_reason = "length"
                 break
 
         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens)
             finish_reason = "stop"
 
         if self.verbose:
             self._ctx.print_timings()
 
-        if stream:
-            remaining_tokens = completion_tokens[returned_tokens:]
-            all_text = self.detokenize(remaining_tokens)
-            any_stop = [s for s in stop_sequences if s in all_text]
-            if len(any_stop) > 0:
-                end = min(all_text.index(stop) for stop in any_stop)
-            else:
-                end = len(all_text)
-
-            token_end_position = 0
-            for token in remaining_tokens:
-                token_end_position += len(self.detokenize([token]))
-
-                logprobs_or_none: Optional[CompletionLogprobs] = None
-                if logprobs is not None:
-                    if token == self.token_bos():
-                        continue
-                    token_str = self.detokenize([token]).decode(
-                        "utf-8", errors="ignore"
-                    )
-                    text_offset = len(prompt) + len(
-                        self.detokenize(completion_tokens[:returned_tokens])
-                    )
-                    token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self._scores[token_offset, :]
-                    current_logprobs = Llama.logits_to_logprobs(logits).tolist()
-                    sorted_logprobs = list(
-                        sorted(
-                            zip(current_logprobs, range(len(current_logprobs))),
-                            reverse=True,
-                        )
-                    )
-                    top_logprob = {
-                        self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
-                        for logprob, i in sorted_logprobs[:logprobs]
-                    }
-                    top_logprob.update({token_str: current_logprobs[int(token)]})
-                    logprobs_or_none = {
-                        "tokens": [
-                            self.detokenize([token]).decode("utf-8", errors="ignore")
-                        ],
-                        "text_offset": [text_offset],
-                        "token_logprobs": [current_logprobs[int(token)]],
-                        "top_logprobs": [top_logprob],
-                    }
-
-                if token_end_position >= end:
-                    last_text = self.detokenize([token])
-                    if token_end_position == end - 1:
-                        break
-                    returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
-                    break
-                returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
-            if self.cache:
-                if self.verbose:
-                    print("Llama._create_completion: cache save", file=sys.stderr)
-                self.cache[prompt_tokens + completion_tokens] = self.save_state()
-                print("Llama._create_completion: cache saved", file=sys.stderr)
-            return
+        yield finish_reason
 
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
             self.cache[prompt_tokens + completion_tokens] = self.save_state()
-
-        text_str = text.decode("utf-8", errors="ignore")
-
-        if echo:
-            text_str = prompt + text_str
-
-        if suffix is not None:
-            text_str = text_str + suffix
-
-        logprobs_or_none: Optional[CompletionLogprobs] = None
-        if logprobs is not None:
-            text_offset = 0 if echo else len(prompt)
-            token_offset = 0 if echo else len(prompt_tokens[1:])
-            text_offsets: List[int] = []
-            token_logprobs: List[Optional[float]] = []
-            tokens: List[str] = []
-            top_logprobs: List[Optional[Dict[str, float]]] = []
-
-            if echo:
-                # Remove leading BOS token
-                all_tokens = prompt_tokens[1:] + completion_tokens
-            else:
-                all_tokens = completion_tokens
-
-            all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore")
-                for token in all_tokens
-            ]
-            all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
-            # TODO: may be able to change this loop to use np.take_along_dim
-            for idx, (token, token_str, logprobs_token) in enumerate(
-                zip(all_tokens, all_token_strs, all_logprobs)
-            ):
-                if token == self.token_bos():
-                    continue
-                text_offsets.append(
-                    text_offset
-                    + len(
-                        self.detokenize(all_tokens[:idx]).decode(
"utf-8", errors="ignore" - ) - ) - ) - tokens.append(token_str) - sorted_logprobs = list( - sorted( - zip(logprobs_token, range(len(logprobs_token))), reverse=True - ) - ) - token_logprobs.append(logprobs_token[int(token)]) - top_logprob: Optional[Dict[str, float]] = { - self.detokenize([i]).decode("utf-8", errors="ignore"): logprob - for logprob, i in sorted_logprobs[:logprobs] - } - top_logprob.update({token_str: logprobs_token[int(token)]}) - top_logprobs.append(top_logprob) - # Weird idosincracy of the OpenAI API where - # token_logprobs and top_logprobs are null for - # the first token. - if echo and len(all_tokens) > 0: - token_logprobs[0] = None - top_logprobs[0] = None - logprobs_or_none = { - "tokens": tokens, - "text_offset": text_offsets, - "token_logprobs": token_logprobs, - "top_logprobs": top_logprobs, - } - - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": text_str, - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": len(prompt_tokens), - "completion_tokens": len(completion_tokens), - "total_tokens": len(prompt_tokens) + len(completion_tokens), - }, - } + print("Llama._create_completion: cache saved", file=sys.stderr) def create_completion( self,