Commit
Merge branch 'localui-group' of https://github.com/NexaAI/nexa-sdk into maokun-local
MaokunZhang committed Jan 2, 2025
2 parents 17c5df3 + 2f75117 commit db0c615
Showing 1 changed file with 108 additions and 6 deletions.
114 changes: 108 additions & 6 deletions nexa/gguf/server/nexa_service.py
@@ -245,6 +245,9 @@ class DownloadModelRequest(BaseModel):
"protected_namespaces": ()
}

class ActionRequest(BaseModel):
prompt: str = ""

class StreamASRProcessor:
def __init__(self, asr, task, language):
self.asr = asr
@@ -300,6 +303,20 @@ def ts_words(self, segments):
words.append((w.start, w.end, w.word))
return words

class MetricsResult:
def __init__(self, ttft: float, decoding_speed: float):
self.ttft = ttft
self.decoding_speed = decoding_speed

def to_dict(self):
return {
'ttft': round(self.ttft, 2),
'decoding_speed': round(self.decoding_speed, 2)
}

def to_json(self):
return json.dumps(self.to_dict())
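# A minimal usage sketch for MetricsResult (illustrative): ttft is
# time-to-first-token in seconds, decoding_speed is tokens per second, and
# both are rounded to two decimals when serialized.
#
#     m = MetricsResult(ttft=0.1234, decoding_speed=45.678)
#     m.to_dict()  # {'ttft': 0.12, 'decoding_speed': 45.68}
#     m.to_json()  # '{"ttft": 0.12, "decoding_speed": 45.68}'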

# helper functions
async def load_model():
global model, chat_format, completion_template, model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path
@@ -711,16 +728,22 @@ async def read_root(request: Request):
)


def _resp_async_generator(streamer):
def _resp_async_generator(streamer, start_time):
_id = str(uuid.uuid4())
ttft = 0  # time to first token (seconds)
decoding_times = 0  # tokens decoded so far
for token in streamer:
ttft = time.perf_counter() - start_time if ttft == 0 else ttft
decoding_times += 1
chunk = {
"id": _id,
"object": "chat.completion.chunk",
"created": time.time(),
"choices": [{"delta": {"content": token}}],
}
yield f"data: {json.dumps(chunk)}\n\n"

yield f"metrics: {MetricsResult(ttft=ttft, decoding_speed=decoding_times / (time.perf_counter() - start_time)).to_json()}\n\n"
yield "data: [DONE]\n\n"

# Global variable for download progress tracking
@@ -1050,8 +1073,9 @@ async def generate_text(request: GenerationRequest):
generation_kwargs = request.dict()
if request.stream:
# Run the generation and stream the response
start_time = time.perf_counter()
streamer = nexa_run_text_generation(is_chat_completion=False, **generation_kwargs)
return StreamingResponse(_resp_async_generator(streamer), media_type="application/x-ndjson")
return StreamingResponse(_resp_async_generator(streamer, start_time), media_type="application/x-ndjson")
else:
# Generate text synchronously and return the response
result = nexa_run_text_generation(is_chat_completion=False, **generation_kwargs)
Expand Down Expand Up @@ -1095,8 +1119,9 @@ async def text_chat_completions(request: ChatCompletionRequest):
).dict()

if request.stream:
start_time = time.perf_counter()
streamer = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs)
return StreamingResponse(_resp_async_generator(streamer), media_type="application/x-ndjson")
return StreamingResponse(_resp_async_generator(streamer, start_time), media_type="application/x-ndjson")

result = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs)
return {
Expand Down Expand Up @@ -1144,7 +1169,8 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest):
processed_messages.append({"role": msg.role, "content": processed_content})
else:
processed_messages.append({"role": msg.role, "content": msg.content})


start_time = time.perf_counter()
response = model.create_chat_completion(
messages=processed_messages,
max_tokens=request.max_tokens,
@@ -1156,7 +1182,8 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest):
)

if request.stream:
return StreamingResponse(_resp_async_generator(response), media_type="application/x-ndjson")

return StreamingResponse(_resp_async_generator(response, start_time), media_type="application/x-ndjson")
return response

except HTTPException as e:
@@ -1165,13 +1192,18 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest):
logging.error(f"Error in multimodal chat completions: {e}")
raise HTTPException(status_code=500, detail=str(e))

async def _resp_omnivlm_async_generator(model, prompt: str, image_path: str):
async def _resp_omnivlm_async_generator(model: NexaOmniVlmInference, prompt: str, image_path: str):
_id = str(uuid.uuid4())
ttft = 0  # time to first token (seconds)
start_time = time.perf_counter()
decoding_times = 0  # tokens decoded so far
try:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")

for token in model.inference_streaming(prompt, image_path):
ttft = time.perf_counter() - start_time if ttft == 0 else ttft
decoding_times += 1
chunk = {
"id": _id,
"object": "chat.completion.chunk",
@@ -1183,6 +1215,7 @@ async def _resp_omnivlm_async_generator(model, prompt: str, image_path: str):
}]
}
yield f"data: {json.dumps(chunk)}\n\n"
yield f"metrics: {MetricsResult(ttft=ttft, decoding_speed=decoding_times / (time.perf_counter() - start_time)).to_json()}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logging.error(f"Error in OmniVLM streaming: {e}")
@@ -1500,6 +1533,9 @@ async def audio_chat_completions(
stream: Optional[bool] = Query(False, description="Whether to stream the response"),
):
temp_file = None
ttft = 0  # time to first token (seconds)
start_time = time.perf_counter()
decoding_times = 0  # tokens decoded so far

try:
if model_type != "AudioLM":
@@ -1516,8 +1552,11 @@

if stream:
async def stream_with_cleanup():
nonlocal ttft, decoding_times, start_time
try:
for token in model.inference_streaming(audio_path, prompt or ""):
ttft = time.perf_counter() - start_time if ttft == 0 else ttft
decoding_times += 1
chunk = {
"id": str(uuid.uuid4()),
"object": "chat.completion.chunk",
@@ -1529,6 +1568,7 @@ async def stream_with_cleanup():
}]
}
yield f"data: {json.dumps(chunk)}\n\n"
yield f"metrics: {MetricsResult(ttft=ttft, decoding_speed=decoding_times / (time.perf_counter() - start_time)).to_json()}\n\n"
yield "data: [DONE]\n\n"
finally:
temp_file.close()
@@ -1620,6 +1660,68 @@ async def create_embedding(request: EmbeddingRequest):
logging.error(f"Error in embedding generation: {e}")
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/action", tags=["Actions"])
async def action(request: ActionRequest):
try:
# Extract content between <nexa_X> and <nexa_end>
prompt = request.prompt
import re

# Use regex to match <nexa_X> pattern
match = re.match(r"<nexa_\d+>(.*?)<nexa_end>", prompt)
if not match:
raise ValueError("Invalid prompt format. Must be wrapped in <nexa_X> and <nexa_end>")

# Extract the function call content
function_content = match.group(1)

# Parse function name and parameters
function_name = function_content[:function_content.index("(")]
params_str = function_content[function_content.index("(")+1:function_content.rindex(")")]

# Parse parameters into dictionary
params = {}
for param in params_str.split(","):
if "=" in param:
key, value = param.split("=", 1)
params[key.strip()] = value.strip().strip("'").strip('"')

# Handle different function types
if function_name == "query_plane_ticket":
# Validate required parameters
required_params = ["year", "date", "time", "departure", "destination"]
for param in required_params:
if param not in params:
raise ValueError(f"Missing required parameter: {param}")

# Construct the date string in required format
date_str = f"{params['date']}/{params['year']}"

# Build the URL
url = (f"https://www.expedia.com/Flights-Search?"
f"leg1=from:{params['departure']},to:{params['destination']},"
f"departure:{date_str}T&"
f"passengers=adults:1&trip=oneway&mode=search")

return {
"status": "success",
"function": function_name,
"parameters": params,
"url": url
}
else:
# Handle other function types in the future
return {
"status": "error",
"message": f"Unsupported function: {function_name}"
}

except Exception as e:
return {
"status": "error",
"message": str(e)
}
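# Example call to this endpoint (illustrative; host and port are assumptions).
# The prompt wraps a single function call in <nexa_X>...<nexa_end> markers,
# which the regex above extracts and parses into a function name and keyword
# parameters:
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/v1/action",
#         json={"prompt": "<nexa_0>query_plane_ticket(year='2025', "
#                         "date='01/02', time='10:00', departure='SEA', "
#                         "destination='JFK')<nexa_end>"},
#     )
#     resp.json()  # {"status": "success", "function": "query_plane_ticket", ...}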

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the Nexa AI Text Generation Service"
