
Commit

first rough draft
iamlemec committed Feb 25, 2024
1 parent 4787ec3 commit 0ce2269
Showing 2 changed files with 117 additions and 328 deletions.
llama_cpp/_internals.py (85 additions, 1 deletion)
@@ -818,4 +818,88 @@ def sample(
    def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
        if apply_grammar and self.grammar is not None:
            ctx_main.grammar_accept_token(self.grammar, id)
        self.prev.append(id)

class _TokenTextQueue:
    # buffers generated tokens and releases decoded text incrementally,
    # holding back bytes that may be part of an incomplete UTF-8 character
    # or a partially matched stop sequence
    # (assumes Callable/Optional/Tuple are imported from typing at module top)

    def __init__(
        self,
        detokenize: Callable[[List[int]], bytes],
        stop_sequences: Optional[List[bytes]] = None,
    ):
        # settings
        self.detokenize = detokenize
        self.stop_sequences = stop_sequences or []

        # current state
        self.tokens: List[int] = []

    def __len__(self):
        return len(self.tokens)

    @staticmethod
    def decode_robust(bstr: bytes) -> Optional[str]:
        # decode bytes to str, returning None if bstr ends mid UTF-8 character
        try:
            return bstr.decode("utf-8")
        except UnicodeError:
            return None

    def detect_stop_token(self) -> Optional[bytes]:
        # if a full stop sequence appears in the buffered text, return the
        # text truncated at its earliest occurrence; otherwise None
        text = self.detokenize(self.tokens)
        stop_idxs = [text.index(s) for s in self.stop_sequences if s in text]
        if len(stop_idxs) > 0:
            return text[: min(stop_idxs)]
        return None
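    # hypothetical example: with stop_sequences == [b"</s>"] and buffered
    # text b"Hello.</s>junk", detect_stop_token() returns b"Hello."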

    # detect first index of partial stop sequence
    def first_stop_position(self) -> int:
        # return the byte offset at which a (possibly partial) stop sequence
        # begins at the end of the text; bytes past this offset are held back
        text = self.detokenize(self.tokens)
        length = len(text)
        first_stop_len = 0
        for s in self.stop_sequences:
            # longest prefix of s that the buffered text ends with
            for i in range(min(len(s), length), 0, -1):
                if text.endswith(s[:i]):
                    if i > first_stop_len:
                        first_stop_len = i
                    break
        return length - first_stop_len
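    # hypothetical example: with buffered text b"hello wor" and
    # stop_sequences == [b"world"], the longest partial match is b"wor",
    # so first_stop_position() returns 9 - 3 == 6 and the trailing
    # b"wor" is held back until the match resolves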

    def push_token(self, token: int):
        self.tokens.append(token)

    def pop_text(self) -> Optional[Tuple[int, bytes, str]]:
        # pop the shortest decodable prefix of the token buffer, returning
        # (token count, raw bytes, decoded text), or None if nothing can be
        # safely released yet
        if len(self) == 0:
            return None

        # attempt decode on successively longer token prefixes
        for i in range(1, len(self.tokens) + 1):
            bstr = self.detokenize(self.tokens[:i])
            text = self.decode_robust(bstr)
            if text is not None:
                break

        # the buffered tokens do not yet decode to complete UTF-8 characters
        if text is None:
            return None

        # avoid yielding if a possible stop sequence is in progress
        if len(bstr) > self.first_stop_position():
            return None

        # trim the released tokens from the buffer
        self.tokens = self.tokens[i:]

        return i, bstr, text
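    # hypothetical example: if an emoji's bytes b"\xf0\x9f" and b"\x98\x80"
    # arrive as two separate tokens, the first pop_text() returns None
    # (incomplete UTF-8) and, after the second push, pop_text() returns
    # (2, b"\xf0\x9f\x98\x80", "😀")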

    def empty_text(self) -> str:
        # flush the remaining buffer to a string, dropping trailing bytes
        # that form a partial stop sequence or invalid UTF-8
        text = ""
        position = 0
        end_position = self.first_stop_position()

        for token in self.tokens:
            last_text = self.detokenize([token])
            position += len(last_text)

            # truncate the final token at the stop boundary
            if position >= end_position:
                text += last_text[
                    : len(last_text) - (position - end_position)
                ].decode("utf-8", errors="ignore")
                break

            text += last_text.decode("utf-8", errors="ignore")

        return text
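A rough usage sketch (not part of the commit): the queue is driven from a generation loop, with the detokenize callable and the token stream supplied by the caller. Here sample_token, eos_token, and detokenize are hypothetical stand-ins for whatever the surrounding sampling code provides.

    # hypothetical driver loop for _TokenTextQueue
    queue = _TokenTextQueue(detokenize, stop_sequences=[b"</s>"])

    while (token := sample_token()) != eos_token:
        queue.push_token(token)

        # release text as soon as it decodes cleanly and cannot be
        # part of a stop sequence
        popped = queue.pop_text()
        if popped is not None:
            n_tokens, bstr, text = popped
            print(text, end="", flush=True)

        # stop once a complete stop sequence has been generated
        if queue.detect_stop_token() is not None:
            break

    # flush the remainder, trimming any trailing partial stop sequence
    print(queue.empty_text(), flush=True)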
