diff --git a/nexa/gguf/llama/llama.py b/nexa/gguf/llama/llama.py
index 087d434c..87837545 100644
--- a/nexa/gguf/llama/llama.py
+++ b/nexa/gguf/llama/llama.py
@@ -89,7 +89,7 @@ def __init__(
         yarn_beta_fast: float = 32.0,
         yarn_beta_slow: float = 1.0,
         yarn_orig_ctx: int = 0,
-        logits_all: bool = False,
+        logits_all: bool = True, # switch
         embedding: bool = False,
         offload_kqv: bool = True,
         flash_attn: bool = False,
@@ -335,9 +335,10 @@ def __init__(
             yarn_beta_slow if yarn_beta_slow != 0.0 else 0
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
-        self.context_params.logits_all = (
-            logits_all if draft_model is None else True
-        ) # Must be set to True for speculative decoding
+        # self.context_params.logits_all = (
+        #     logits_all if draft_model is None else True
+        # ) # Must be set to True for speculative decoding
+        self.context_params.logits_all = True
         self.context_params.embeddings = embedding # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
         self.context_params.flash_attn = flash_attn
@@ -662,6 +663,8 @@ def sample(
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         grammar: Optional[LlamaGrammar] = None,
         idx: Optional[int] = None,
     ):
@@ -718,7 +721,27 @@ def sample(
             id=id,
             apply_grammar=grammar is not None,
         )
-        return id
+
+        if logprobs is not None and (top_logprobs is not None and top_logprobs > 0):
+            sampled_logprobs = self.logits_to_logprobs(logits)
+            token_logprob = float(sampled_logprobs[id])
+
+            top_logprobs_dict = None
+            if top_logprobs is not None:
+                sorted_indices = sampled_logprobs.argsort()[::-1]
+                top_indices = sorted_indices[:top_logprobs]
+                top_logprobs_dict = {
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): float(sampled_logprobs[i])
+                    for i in top_indices
+                }
+
+            return {
+                "token": id,
+                "token_logprob": token_logprob,
+                "top_logprobs": top_logprobs_dict
+            }
+        else:
+            return id
 
     def generate(
         self,
@@ -738,6 +761,8 @@ def generate(
         mirostat_eta: float = 0.1,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
@@ -777,7 +802,7 @@ def generate(
                 self.n_tokens = longest_prefix
                 if self.verbose:
                     print(f"Llama.generate: {longest_prefix} prefix-match hit, "
-                          f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
+                        f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
 
         # Reset the model state
         if reset:
@@ -794,7 +819,7 @@ def generate(
         while True:
             self.eval(tokens)
             while sample_idx < self.n_tokens:
-                token = self.sample(
+                result = self.sample(
                     top_k=top_k,
                     top_p=top_p,
                     min_p=min_p,
@@ -808,17 +833,26 @@ def generate(
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
                     logits_processor=logits_processor,
+                    logprobs=logprobs,
+                    top_logprobs=top_logprobs,
                     grammar=grammar,
                     penalize_nl=penalize_nl,
                     idx=sample_idx,
                 )
 
+                if isinstance(result, dict):
+                    token = result["token"]
+                    logprobs_info = result
+                else:
+                    token = result
+                    logprobs_info = None
+
                 sample_idx += 1
                 if stopping_criteria is not None and stopping_criteria(
                     self._input_ids, self._scores[-1, :]
                 ):
                     return
-                tokens_or_none = yield token
+                tokens_or_none = yield token, logprobs_info
                 tokens.clear()
                 tokens.append(token)
                 if tokens_or_none is not None:
@@ -1011,7 +1045,7 @@ def _create_completion(
         top_p: float = 0.95,
         min_p: float = 0.05,
         typical_p: float = 1.0,
-        logprobs: Optional[int] = None,
+        logprobs: Optional[bool] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
@@ -1027,6 +1061,7 @@ def _create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        top_logprobs: Optional[int] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
@@ -1183,7 +1218,9 @@ def logit_bias_processor(
         finish_reason = "length"
         multibyte_fix = 0
 
-        for token in self.generate(
+        logprobs_or_none = None
+
+        for token, logprobs_info in self.generate(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
@@ -1199,6 +1236,8 @@ def logit_bias_processor(
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             grammar=grammar,
         ):
             assert self._model.model is not None
@@ -1209,6 +1248,20 @@ def logit_bias_processor(
 
             completion_tokens.append(token)
 
+            if logprobs_info and logprobs_or_none is None:
+                logprobs_or_none = {
+                    "tokens": [],
+                    "text_offset": [],
+                    "token_logprobs": [],
+                    "top_logprobs": []
+                }
+
+            if logprobs_info:
+                logprobs_or_none["tokens"].append(self.detokenize([token]).decode("utf-8", errors="ignore"))
+                logprobs_or_none["text_offset"].append(len(self.detokenize(completion_tokens[:-1])))
+                logprobs_or_none["token_logprobs"].append(logprobs_info["token_logprob"])
+                logprobs_or_none["top_logprobs"].append(logprobs_info["top_logprobs"])
+
             all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
 
             # Contains multi-byte UTF8
@@ -1407,7 +1460,7 @@ def logit_bias_processor(
                         )
                     )
 
-                logprobs_or_none: Optional[CompletionLogprobs] = None
+                # logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
                     if token == bos_token_id:
                         continue
@@ -1492,7 +1545,10 @@ def logit_bias_processor(
                     {
                         "text": "",
                         "index": 0,
-                        "logprobs": None,
+                        "delta": {
+                            "content": "",
+                        },
+                        "logprobs": logprobs_or_none,
                         "finish_reason": finish_reason,
                     }
                 ],
@@ -1507,7 +1563,7 @@ def logit_bias_processor(
         if suffix_token_id < 0 and suffix is not None:
             text_str = text_str + suffix
 
-        logprobs_or_none: Optional[CompletionLogprobs] = None
+        # logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
             text_offset = 0 if echo else len(prompt)
             token_offset = 0 if echo else len(prompt_tokens[1:])
@@ -1603,7 +1659,7 @@ def create_completion(
         top_p: float = 0.95,
         min_p: float = 0.05,
         typical_p: float = 1.0,
-        logprobs: Optional[int] = None,
+        logprobs: Optional[bool] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
@@ -1621,6 +1677,7 @@ def create_completion(
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        top_logprobs: Optional[int] = None,
    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1700,7 +1757,7 @@ def __call__(
         top_p: float = 0.95,
         min_p: float = 0.05,
         typical_p: float = 1.0,
-        logprobs: Optional[int] = None,
+        logprobs: Optional[bool] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         frequency_penalty: float = 0.0,
diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py
index edbd63c9..7c376943 100644
--- a/nexa/gguf/nexa_inference_text.py
+++ b/nexa/gguf/nexa_inference_text.py
@@ -49,10 +49,13 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
         self.params = DEFAULT_TEXT_GEN_PARAMS
         self.params.update(kwargs)
         self.model = None
-
+
         self.model_path = model_path
         self.downloaded_path = local_path
-
+
+        self.logprobs = kwargs.get('logprobs', None)
+        self.top_logprobs = kwargs.get('top_logprobs', None)
+
         if self.downloaded_path is None:
             self.downloaded_path, run_type = pull_model(self.model_path)
 
@@ -80,6 +83,7 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
                 "Failed to load model or tokenizer. Exiting.", exc_info=True
             )
             exit(1)
+
     def create_embedding(
         self,
         input: Union[str, List[str]],
@@ -191,7 +195,7 @@ def run(self):
                 logging.error(f"Error during generation: {e}", exc_info=True)
             print("\n")
 
-    def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, stream=False, stop=None):
+    def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, stream=False, stop=None, logprobs=None, top_logprobs=None):
         """
         Used for SDK. Generate completion for a chat conversation.
 
@@ -207,9 +211,12 @@ def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top
         Returns:
             Iterator: Iterator for the completion.
         """
-        return self.model.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, stream=stream, stop=stop)
+        if logprobs and top_logprobs is None:
+            top_logprobs = 4
 
-    def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, echo=False, stream=False, stop=None):
+        return self.model.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, stream=stream, stop=stop, logprobs=logprobs, top_logprobs=top_logprobs)
+
+    def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, echo=False, stream=False, stop=None, logprobs=None, top_logprobs=None):
         """
         Used for SDK. Generate completion for a given prompt.
 
@@ -226,7 +233,10 @@ def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50,
         Returns:
             Iterator: Iterator for the completion.
         """
-        return self.model.create_completion(prompt=prompt, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, echo=echo, stream=stream, stop=stop)
+        if logprobs and top_logprobs is None:
+            top_logprobs = 4
+
+        return self.model.create_completion(prompt=prompt, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, echo=echo, stream=stream, stop=stop, logprobs=logprobs, top_logprobs=top_logprobs)
 
 
     def _chat(self, user_input: str) -> Iterator:
@@ -239,6 +249,8 @@ def _chat(self, user_input: str) -> Iterator:
             top_p=self.params["top_p"],
             stream=True,
             stop=self.stop_words,
+            logprobs=self.logprobs,
+            top_logprobs=self.top_logprobs,
         )
 
     def _complete(self, user_input: str) -> Iterator:
@@ -256,6 +268,8 @@ def _complete(self, user_input: str) -> Iterator:
             echo=False, # Echo the prompt back in the output
             stream=True,
             stop=self.stop_words,
+            logprobs=self.logprobs,
+            top_logprobs=self.top_logprobs,
         )
 
     def run_streamlit(self, model_path: str):
@@ -322,10 +336,18 @@ def run_streamlit(self, model_path: str):
         action="store_true",
         help="Run the inference in Streamlit UI",
     )
+    # parser.add_argument(
+    #     "-tlps",
+    #     "--top_logprobs",
+    #     type=int,
+    #     default=None, # -tlps 5
+    #     help="Number of most likely tokens to return at each token position",
+    # )
     args = parser.parse_args()
     kwargs = {k: v for k, v in vars(args).items() if v is not None}
     model_path = kwargs.pop("model_path")
     stop_words = kwargs.pop("stop_words", [])
+
     inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs)
     if args.streamlit:
         inference.run_streamlit(model_path)
diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py
index a00047bd..ce914656 100644
--- a/nexa/gguf/server/nexa_service.py
+++ b/nexa/gguf/server/nexa_service.py
@@ -31,6 +31,7 @@
 from nexa.gguf.sd.stable_diffusion import StableDiffusion
 from faster_whisper import WhisperModel
 import argparse
+
 logging.basicConfig(level=logging.INFO)
 
 app = FastAPI()
@@ -58,7 +59,8 @@ class GenerationRequest(BaseModel):
     top_k: int = 50
     top_p: float = 1.0
     stop_words: Optional[List[str]] = None
-
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 4
 
 async def load_model():
     global model, chat_format, completion_template, model_path
@@ -74,6 +76,7 @@ async def load_model():
                 verbose=False,
                 chat_format=chat_format,
                 n_gpu_layers=-1 if is_gpu_available() else 0,
+                logits_all=True
             )
         except Exception as e:
             logging.error(
@@ -83,7 +86,8 @@ async def load_model():
                 model_path=downloaded_path,
                 verbose=False,
                 chat_format=chat_format,
-                n_gpu_layers=0, # hardcode to use CPU
+                n_gpu_layers=0, # hardcode to use CPU,
+                logits_all=True
             )
         logging.info(f"model loaded as {model}")
@@ -97,6 +101,7 @@ async def load_model():
                 verbose=False,
                 chat_format=chat_format,
                 n_gpu_layers=-1 if is_gpu_available() else 0,
+                logits_all=True # switch on
             )
         except Exception as e:
             logging.error(
@@ -107,6 +112,7 @@ async def load_model():
                 verbose=False,
                 chat_format=chat_format,
                 n_gpu_layers=0, # hardcode to use CPU
+                logits_all=True
             )
         logging.info(f"model loaded as {model}")
     elif run_type == "Computer Vision":
@@ -139,28 +145,47 @@ async def startup_event():
     await load_model()
 
 async def nexa_run_text_generation(
-    prompt, temperature, stop_words, max_new_tokens, top_k, top_p
-) -> str:
+    prompt, temperature, stop_words, max_new_tokens, top_k, top_p, logprobs=None, top_logprobs=None
+) -> Dict[str, Any]:
     global model, chat_format, completion_template, conversation_history
     if model is None:
         raise ValueError("Model is not loaded. Please check the model path and try again.")
 
+
     generated_text = ""
+    logprobs_or_none = None # init to store the logprobs if requested
     if chat_format:
         conversation_history.append({"role": "user", "content": prompt})
-        streamer = model.create_chat_completion(
-            messages=conversation_history,
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            top_k=top_k,
-            top_p=top_p,
-            stream=True,
-            stop=stop_words,
-        )
+
+        params = {
+            'messages': conversation_history,
+            'temperature': temperature,
+            'max_tokens': max_new_tokens,
+            'top_k': top_k,
+            'top_p': top_p,
+            'stream': True,
+            'stop': stop_words,
+            'logprobs': logprobs,
+            'top_logprobs': top_logprobs,
+        }
+
+        params_json = json.dumps(params, default=str, indent=2)
+
+        streamer = model.create_chat_completion(**params)
+
         for chunk in streamer:
             delta = chunk["choices"][0]["delta"]
             if "content" in delta:
                 generated_text += delta["content"]
+
+            if logprobs and "logprobs" in chunk["choices"][0]:
+                if logprobs_or_none is None:
+                    logprobs_or_none = chunk["choices"][0]["logprobs"]
+                else:
+                    for key in logprobs_or_none: # tokens, token_logprobs, top_logprobs, text_offset
+                        if key in chunk["choices"][0]["logprobs"]:
+                            logprobs_or_none[key].extend(chunk["choices"][0]["logprobs"][key]) # accumulate data from each chunk
+
     else:
         prompt = completion_template.format(prompt) if completion_template else prompt
         streamer = model.create_completion(
@@ -171,15 +196,31 @@ async def nexa_run_text_generation(
             top_p=top_p,
             stream=True,
             stop=stop_words,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
         )
+
         for chunk in streamer:
+            pp.pprint(chunk)
             delta = chunk["choices"][0]["text"]
             generated_text += delta
+            if logprobs and "logprobs" in chunk["choices"][0]:
+                if logprobs_or_none is None:
+                    logprobs_or_none = chunk["choices"][0]["logprobs"]
+                else:
+                    for key in logprobs_or_none: # tokens, token_logprobs, top_logprobs, text_offset
+                        if key in chunk["choices"][0]["logprobs"]:
+                            logprobs_or_none[key].extend(chunk["choices"][0]["logprobs"][key]) # accumulate data from each chunk
+
 
     if is_chat_mode:
         conversation_history.append({"role": "assistant", "content": generated_text})
 
-    return generated_text
+    result = {
+        "result": generated_text,
+        "logprobs": logprobs_or_none
+    }
+    return result
 
 
 @app.get("/", response_class=HTMLResponse)
@@ -191,9 +232,16 @@ async def read_root(request: Request):
 
 @app.post("/v1/completions")
 async def generate_text(request: GenerationRequest):
+    logging.info(f"[/v1/completions] Request logprobs: {request.logprobs}, top_logprobs: {request.top_logprobs}")
     try:
         result = await nexa_run_text_generation(**request.dict())
-        return JSONResponse(content={"result": result})
+        # return JSONResponse(content={"result": result})
+        return JSONResponse(content={
+            "choices": [{
+                "text": result["result"],
+                "logprobs": result["logprobs"]
+            }]
+        })
     except Exception as e:
         logging.error(f"Error in text generation: {e}")
         raise HTTPException(status_code=500, detail=str(e))
@@ -214,6 +262,8 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = 0.1
     stream: Optional[bool] = False
     stop_words: Optional[List[str]] = []
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 4
 
 class FunctionDefinitionRequestClass(BaseModel):
     type: str = "function"
@@ -368,18 +418,22 @@ async def img2img(request: ImageGenerationRequest):
         logging.error(f"Error in img2img generation: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
+
 @app.post("/v1/chat/completions")
 async def chat_completions(request: ChatCompletionRequest):
+    logging.info(f"[/v1/chat/completions] Request logprobs: {request.logprobs}, top_logprobs: {request.top_logprobs}")
     try:
         generation_kwargs = GenerationRequest(
             prompt="" if len(request.messages) == 0 else request.messages[-1].content,
             temperature=request.temperature,
             max_new_tokens=request.max_tokens,
             stop_words=request.stop_words,
+            logprobs=request.logprobs,
+            top_logprobs=request.top_logprobs,
         ).dict()
 
         if request.stream:
-            # Run the generation and stream the response
+            # run the generation and stream the response:
             async def stream_generator():
                 streamer = await nexa_run_text_generation(**generation_kwargs)
                 async for chunk in _resp_async_generator(streamer):
@@ -388,13 +442,16 @@ async def stream_generator():
             return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
 
         else:
-            # Generate text synchronously and return the response
-            resp_content = await nexa_run_text_generation(**generation_kwargs)
+            # generate text synchronously and return the response:
+            result = await nexa_run_text_generation(**generation_kwargs)
             return {
                 "id": str(uuid.uuid4()),
                 "object": "chat.completion",
                 "created": time.time(),
-                "choices": [{"message": Message(role="assistant", content=resp_content)}],
+                "choices": [{
+                    "message": Message(role="assistant", content=result["result"]),
+                    "logprobs": result["logprobs"] if "logprobs" in result else None,
+                }],
             }
     except Exception as e:
         logging.error(f"Error in chat completions: {e}")
@@ -502,4 +559,4 @@ def run_nexa_ai_service(model_path_arg, **kwargs):
         help="Enable automatic reloading on code changes",
     )
     args = parser.parse_args()
-    run_nexa_ai_service(args.model_path, host=args.host, port=args.port, reload=args.reload)
+    run_nexa_ai_service(args.model_path, host=args.host, port=args.port, reload=args.reload)
\ No newline at end of file