logits API support #67

Merged: 25 commits, Sep 12, 2024

Changes from all commits
d4e88d3
implement logits and logprobs output, add debug print statements
qmeng222 Aug 29, 2024
f833649
debug (in progress)
qmeng222 Aug 30, 2024
946da23
Merge branch 'main' of https://github.com/NexaAI/nexa-sdk into qingyi…
qmeng222 Aug 30, 2024
e69c81c
implement the expected structure in response body
qmeng222 Sep 2, 2024
20f10a6
debug null logprobs in Nexa SDK server
qmeng222 Sep 3, 2024
e79a043
add additional intermediate print statements for enhanced debugging
qmeng222 Sep 3, 2024
3bad104
handle and accumulate logprobs for completion format
qmeng222 Sep 3, 2024
09f0d5c
clean code
qmeng222 Sep 4, 2024
12c1904
update NexaTextInference class to handle logprobs
qmeng222 Sep 4, 2024
7c15aa0
add logprobs functionality to Llama class
qmeng222 Sep 4, 2024
c4cd6fb
troubleshoot the CI issue
qmeng222 Sep 4, 2024
5a090c9
solve pr conflict: update completion chunk format for streaming
qmeng222 Sep 4, 2024
9fc5b98
Merge remote-tracking branch 'origin/main' into qingying-logits
qmeng222 Sep 4, 2024
fcdfc6c
fix logprobs handling consistency in llama.py
qmeng222 Sep 5, 2024
21d5cbc
implement logprobs and top_logprobs functionality
qmeng222 Sep 5, 2024
89780b9
fix TypeError in create_completion by conditionally passing top_logprobs
qmeng222 Sep 5, 2024
9ba5684
add top_logprobs parameter to Llama.create_completion method
qmeng222 Sep 5, 2024
49615ec
ensure logprobs is being passed as a boolean throughout the call chain
qmeng222 Sep 5, 2024
1c77636
set logprobs: true and top_logprobs: 4 by default for both the /v1/co…
qmeng222 Sep 5, 2024
a37f501
remove redundant logprobs CLI argument
qmeng222 Sep 5, 2024
329993f
comment out top_logprobs CLI argument
qmeng222 Sep 5, 2024
771b301
remove intermediate printouts
qmeng222 Sep 10, 2024
ce20a9e
minor changes: uncomment and reset defaults
qmeng222 Sep 11, 2024
e1ee1b6
clean code
qmeng222 Sep 11, 2024
08ba622
Merge branch 'main' of https://github.com/NexaAI/nexa-sdk into qingyi…
JoyboyBrian Sep 11, 2024
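
Taken together, these commits thread logprobs and top_logprobs from the sampler in llama.py up through NexaTextInference and the completion endpoints. For orientation before the diffs, the logprobs block accumulated in _create_completion ends up shaped roughly like this; the key names come from the llama.py diff below, while the token strings and values here are illustrative only:

    # Illustrative shape only; one entry per sampled token, with top_logprobs
    # holding up to `top_logprobs` alternatives (4 by this PR's default).
    logprobs_block = {
        "tokens": ["Hello", ",", " world"],
        "text_offset": [0, 5, 6],
        "token_logprobs": [-0.12, -0.48, -0.95],
        "top_logprobs": [
            {"Hello": -0.12, "Hi": -2.31, "Hey": -3.02, "HELLO": -4.77},
            {",": -0.48, "!": -1.55, " and": -2.90, ".": -3.41},
            {" world": -0.95, " World": -1.67, " there": -3.12, " everyone": -3.58},
        ],
    }
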
87 changes: 72 additions & 15 deletions nexa/gguf/llama/llama.py
@@ -89,7 +89,7 @@ def __init__(
yarn_beta_fast: float = 32.0,
yarn_beta_slow: float = 1.0,
yarn_orig_ctx: int = 0,
logits_all: bool = False,
logits_all: bool = True, # switch
embedding: bool = False,
offload_kqv: bool = True,
flash_attn: bool = False,
@@ -335,9 +335,10 @@ def __init__(
yarn_beta_slow if yarn_beta_slow != 0.0 else 0
)
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.logits_all = (
logits_all if draft_model is None else True
) # Must be set to True for speculative decoding
# self.context_params.logits_all = (
# logits_all if draft_model is None else True
# ) # Must be set to True for speculative decoding
self.context_params.logits_all = True
self.context_params.embeddings = embedding # TODO: Rename to embeddings
self.context_params.offload_kqv = offload_kqv
self.context_params.flash_attn = flash_attn
@@ -662,6 +663,8 @@ def sample(
mirostat_tau: float = 5.0,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
grammar: Optional[LlamaGrammar] = None,
idx: Optional[int] = None,
):
@@ -718,7 +721,27 @@ def sample(
id=id,
apply_grammar=grammar is not None,
)
return id

if logprobs is not None and (top_logprobs is not None and top_logprobs > 0):
sampled_logprobs = self.logits_to_logprobs(logits)
token_logprob = float(sampled_logprobs[id])

top_logprobs_dict = None
if top_logprobs is not None:
sorted_indices = sampled_logprobs.argsort()[::-1]
top_indices = sorted_indices[:top_logprobs]
top_logprobs_dict = {
self.detokenize([i]).decode("utf-8", errors="ignore"): float(sampled_logprobs[i])
for i in top_indices
}

return {
"token": id,
"token_logprob": token_logprob,
"top_logprobs": top_logprobs_dict
}
else:
return id

def generate(
self,
@@ -738,6 +761,8 @@
mirostat_eta: float = 0.1,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
@@ -777,7 +802,7 @@
self.n_tokens = longest_prefix
if self.verbose:
print(f"Llama.generate: {longest_prefix} prefix-match hit, "
f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)

# Reset the model state
if reset:
@@ -794,7 +819,7 @@
while True:
self.eval(tokens)
while sample_idx < self.n_tokens:
token = self.sample(
result = self.sample(
top_k=top_k,
top_p=top_p,
min_p=min_p,
@@ -808,17 +833,26 @@
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
logits_processor=logits_processor,
logprobs=logprobs,
top_logprobs=top_logprobs,
grammar=grammar,
penalize_nl=penalize_nl,
idx=sample_idx,
)

if isinstance(result, dict):
token = result["token"]
logprobs_info = result
else:
token = result
logprobs_info = None

sample_idx += 1
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
return
tokens_or_none = yield token
tokens_or_none = yield token, logprobs_info
tokens.clear()
tokens.append(token)
if tokens_or_none is not None:
@@ -1011,7 +1045,7 @@ def _create_completion(
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
logprobs: Optional[bool] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
@@ -1027,6 +1061,7 @@
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
top_logprobs: Optional[int] = None,
grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
) -> Union[
@@ -1183,7 +1218,9 @@ def logit_bias_processor(

finish_reason = "length"
multibyte_fix = 0
for token in self.generate(
logprobs_or_none = None

for token, logprobs_info in self.generate(
prompt_tokens,
top_k=top_k,
top_p=top_p,
@@ -1199,6 +1236,8 @@ def logit_bias_processor(
repeat_penalty=repeat_penalty,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
logprobs=logprobs,
top_logprobs=top_logprobs,
grammar=grammar,
):
assert self._model.model is not None
@@ -1209,6 +1248,20 @@

completion_tokens.append(token)

if logprobs_info and logprobs_or_none is None:
logprobs_or_none = {
"tokens": [],
"text_offset": [],
"token_logprobs": [],
"top_logprobs": []
}

if logprobs_info:
logprobs_or_none["tokens"].append(self.detokenize([token]).decode("utf-8", errors="ignore"))
logprobs_or_none["text_offset"].append(len(self.detokenize(completion_tokens[:-1])))
logprobs_or_none["token_logprobs"].append(logprobs_info["token_logprob"])
logprobs_or_none["top_logprobs"].append(logprobs_info["top_logprobs"])

all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)

# Contains multi-byte UTF8
@@ -1407,7 +1460,7 @@ def logit_bias_processor(
)
)

logprobs_or_none: Optional[CompletionLogprobs] = None
# logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
if token == bos_token_id:
continue
@@ -1492,7 +1545,10 @@ def logit_bias_processor(
{
"text": "",
"index": 0,
"logprobs": None,
"delta": {
"content": "",
},
"logprobs": logprobs_or_none,
"finish_reason": finish_reason,
}
],
@@ -1507,7 +1563,7 @@
if suffix_token_id < 0 and suffix is not None:
text_str = text_str + suffix

logprobs_or_none: Optional[CompletionLogprobs] = None
# logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
text_offset = 0 if echo else len(prompt)
token_offset = 0 if echo else len(prompt_tokens[1:])
@@ -1603,7 +1659,7 @@ def create_completion(
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
logprobs: Optional[bool] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
@@ -1621,6 +1677,7 @@
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
top_logprobs: Optional[int] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
"""Generate text from a prompt.

@@ -1700,7 +1757,7 @@ def __call__(
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
logprobs: Optional[bool] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
frequency_penalty: float = 0.0,
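
With the llama.py changes above, Llama.sample can return a dict ("token", "token_logprob", "top_logprobs") instead of a bare token id, and Llama.generate now yields (token, logprobs_info) pairs, where logprobs_info is None unless logprobs and top_logprobs were requested. A minimal consumer sketch, not code from this PR, assuming an already-constructed nexa.gguf.llama.Llama instance `llm` and the usual tokenize/detokenize/token_eos helpers from the vendored llama.py:

    # Sketch only: iterate the new (token, logprobs_info) pairs from generate().
    prompt_tokens = llm.tokenize(b"The capital of France is")

    for token, logprobs_info in llm.generate(
        prompt_tokens,
        top_k=50,
        top_p=0.95,
        temp=0.7,
        logprobs=True,      # ask sample() to return the dict form
        top_logprobs=4,     # alternatives reported per position
    ):
        piece = llm.detokenize([token]).decode("utf-8", errors="ignore")
        if logprobs_info is not None:
            # logprobs_info carries "token", "token_logprob", and "top_logprobs"
            print(piece, logprobs_info["token_logprob"], logprobs_info["top_logprobs"])
        else:
            print(piece)
        if token == llm.token_eos():
            break
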
34 changes: 28 additions & 6 deletions nexa/gguf/nexa_inference_text.py
@@ -49,10 +49,13 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
self.params = DEFAULT_TEXT_GEN_PARAMS
self.params.update(kwargs)
self.model = None

self.model_path = model_path
self.downloaded_path = local_path


self.logprobs = kwargs.get('logprobs', None)
self.top_logprobs = kwargs.get('top_logprobs', None)

if self.downloaded_path is None:
self.downloaded_path, run_type = pull_model(self.model_path)

@@ -80,6 +83,7 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
"Failed to load model or tokenizer. Exiting.", exc_info=True
)
exit(1)

def create_embedding(
self,
input: Union[str, List[str]],
@@ -191,7 +195,7 @@ def run(self):
logging.error(f"Error during generation: {e}", exc_info=True)
print("\n")

def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, stream=False, stop=None):
def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, stream=False, stop=None, logprobs=None, top_logprobs=None):
"""
Used for SDK. Generate completion for a chat conversation.

@@ -207,9 +211,12 @@ def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top
Returns:
Iterator: Iterator for the completion.
"""
return self.model.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, stream=stream, stop=stop)
if logprobs and top_logprobs is None:
top_logprobs = 4

def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, echo=False, stream=False, stop=None):
return self.model.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, stream=stream, stop=stop, logprobs=logprobs, top_logprobs=top_logprobs)

def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, echo=False, stream=False, stop=None, logprobs=None, top_logprobs=None):
"""
Used for SDK. Generate completion for a given prompt.

@@ -226,7 +233,10 @@ def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50,
Returns:
Iterator: Iterator for the completion.
"""
return self.model.create_completion(prompt=prompt, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, echo=echo, stream=stream, stop=stop)
if logprobs and top_logprobs is None:
top_logprobs = 4

return self.model.create_completion(prompt=prompt, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, echo=echo, stream=stream, stop=stop, logprobs=logprobs, top_logprobs=top_logprobs)


def _chat(self, user_input: str) -> Iterator:
Expand All @@ -239,6 +249,8 @@ def _chat(self, user_input: str) -> Iterator:
top_p=self.params["top_p"],
stream=True,
stop=self.stop_words,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,
)

def _complete(self, user_input: str) -> Iterator:
Expand All @@ -256,6 +268,8 @@ def _complete(self, user_input: str) -> Iterator:
echo=False, # Echo the prompt back in the output
stream=True,
stop=self.stop_words,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,
)

def run_streamlit(self, model_path: str):
Expand Down Expand Up @@ -322,10 +336,18 @@ def run_streamlit(self, model_path: str):
action="store_true",
help="Run the inference in Streamlit UI",
)
# parser.add_argument(
# "-tlps",
# "--top_logprobs",
# type=int,
# default=None, # -tlps 5
# help="Number of most likely tokens to return at each token position",
# )
args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
stop_words = kwargs.pop("stop_words", [])

inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs)
if args.streamlit:
inference.run_streamlit(model_path)
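
At the SDK level, logprobs and top_logprobs can now be supplied either as NexaTextInference keyword arguments (picked up in __init__ via kwargs) or per call to create_completion / create_chat_completion, with top_logprobs falling back to 4 when only logprobs is set. A minimal usage sketch, assuming a model identifier that pull_model can resolve (the name below is only a placeholder) and a standard OpenAI-style completion payload:

    from nexa.gguf.nexa_inference_text import NexaTextInference

    # logprobs/top_logprobs arrive via **kwargs and are stored on the instance.
    inference = NexaTextInference("gemma", logprobs=True, top_logprobs=4)  # placeholder model name

    response = inference.create_completion(
        "Write a haiku about the sea.",
        temperature=0.7,
        max_tokens=64,
        stream=False,
        logprobs=True,     # top_logprobs would default to 4 here if omitted
        top_logprobs=4,
    )

    # Assumes the usual OpenAI-style shape: choices[0]["logprobs"] holds the
    # tokens / text_offset / token_logprobs / top_logprobs arrays built in llama.py.
    print(response["choices"][0]["logprobs"])
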