Update APIs #336

Merged
merged 5 commits into from
Jan 3, 2025
Changes from 3 commits
16 changes: 16 additions & 0 deletions nexa/general.py
@@ -341,6 +341,7 @@ def download_file_with_progress(
chunk_size: int = 5 * 1024 * 1024,
max_workers: int = 20,
use_processes: bool = default_use_processes(),
progress_callback=None,
**kwargs
):
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -376,6 +377,8 @@ def download_file_with_progress(
else concurrent.futures.ThreadPoolExecutor
)

downloaded_size = 0 # Track the total downloaded size

with executor_class(max_workers=max_workers) as executor:
future_to_chunk = {
executor.submit(
@@ -389,13 +392,25 @@
try:
chunk_size, chunk_number = future.result()
completed_chunks[chunk_number] = True
downloaded_size += chunk_size
progress_bar.update(chunk_size)

# Call the progress callback with the current progress
if progress_callback:
progress_callback(downloaded_size, file_size, stage="downloading")

except Exception as e:
print(f"Error downloading chunk {chunk_number}: {e}")

progress_bar.close()

if all(completed_chunks):
# Transition to the verifying phase
try:
if progress_callback:
progress_callback(downloaded_size, file_size, stage="verifying")
except Exception as e:
print(f"Error in verifying-stage progress callback: {e}")
# Create a new progress bar for combining chunks
combine_progress = tqdm(
total=file_size,
@@ -428,6 +443,7 @@ def download_file_with_progress(
shutil.rmtree(temp_dir)


def download_model_from_official(model_path, model_type, **kwargs):
try:
model_name, model_version = model_path.split(":")
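
The progress_callback hook added above receives byte counts plus a stage label ("downloading" per completed chunk, then "verifying" once all chunks finish). A minimal sketch of a callback a caller could pass; the printing logic is illustrative, not part of this PR:

def print_progress(downloaded_bytes, total_bytes, stage="downloading"):
    # Report percentage complete for the current stage.
    percent = downloaded_bytes * 100 // total_bytes if total_bytes else 0
    print(f"[{stage}] {percent}% ({downloaded_bytes}/{total_bytes} bytes)")
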
279 changes: 242 additions & 37 deletions nexa/gguf/server/nexa_service.py
@@ -1,14 +1,20 @@
import json
import logging
import os
from pathlib import Path
import queue
import shutil
import socket
import threading
import time
import uuid
from typing import List, Optional, Dict, Any, Union, Literal
import base64
import multiprocessing
from PIL import Image
import tempfile
import concurrent
import tqdm
import uvicorn
from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Query
from fastapi.middleware.cors import CORSMiddleware
@@ -19,10 +25,16 @@
from PIL import Image
import base64
from urllib.parse import urlparse
import asyncio
from fastapi.responses import StreamingResponse
Collaborator

duplicate import

from nexa.constants import (
NEXA_MODELS_HUB_OFFICIAL_DIR,
NEXA_OFFICIAL_MODELS_TYPE,
NEXA_RUN_CHAT_TEMPLATE_MAP,
NEXA_RUN_MODEL_MAP_TEXT,
NEXA_RUN_MODEL_MAP_VLM,
NEXA_RUN_MODEL_MAP_VOICE,
NEXA_RUN_PROJECTOR_MAP,
NEXA_RUN_OMNI_VLM_MAP,
NEXA_RUN_OMNI_VLM_PROJECTOR_MAP,
@@ -31,16 +43,19 @@
NEXA_RUN_COMPLETION_TEMPLATE_MAP,
NEXA_RUN_MODEL_PRECISION_MAP,
NEXA_RUN_MODEL_MAP_FUNCTION_CALLING,
NEXA_MODEL_LIST_PATH
NEXA_MODEL_LIST_PATH,
NEXA_OFFICIAL_BUCKET,
)
from nexa.gguf.converter.constants import NEXA_MODELS_HUB_DIR
from nexa.gguf.lib_utils import is_gpu_available
from nexa.gguf.llama.llama_chat_format import (
Llava15ChatHandler,
Llava16ChatHandler,
NanoLlavaChatHandler,
)
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
from nexa.general import pull_model
from nexa.general import add_model_to_list, default_use_processes, download_file_with_progress, get_model_info, is_model_exists, pull_model
from nexa.gguf.llama.llama import Llama
from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
@@ -731,52 +746,242 @@ def _resp_async_generator(streamer, start_time):
yield f"metrics: {MetricsResult(ttft=ttft, decoding_speed=decoding_times / (time.perf_counter() - start_time)).to_json()}\n\n"
yield "data: [DONE]\n\n"

# Global variable for download progress tracking
download_progress = {}

def pull_model_with_progress(model_path, progress_key, **kwargs):
"""
Wrapper for pull_model to track download progress using download_file_with_progress.
"""
try:
# Initialize progress tracking
download_progress[progress_key] = 0
current_file_completed = False

# Extract local download path
local_download_path = kwargs.get('local_download_path')
base_download_dir = Path(local_download_path) if local_download_path else NEXA_MODELS_HUB_OFFICIAL_DIR
model_name, model_version = model_path.split(":")
file_extension = ".zip" if kwargs.get("model_type") in ["onnx", "bin"] else ".gguf"
filename = f"{model_version}{file_extension}"
file_path = base_download_dir / model_name / filename

# Record the expected file size once (reused for progress monitoring below)
total_size = int(requests.head(f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}").headers.get("Content-Length", 0))

# Keep-alive watcher for the current file
def monitor_progress(file_path, total_size):
"""
Runs until the current file is marked complete; the actual progress
numbers are pushed by progress_callback below.
"""
nonlocal current_file_completed
while not current_file_completed:
time.sleep(0.5)

def progress_callback(downloaded_bytes, total_bytes, stage="downloading"):
"""
Callback to update progress; receives byte counts from
download_file_with_progress.
"""
nonlocal current_file_completed
if stage == "downloading":
if total_bytes > 0:
progress = int((downloaded_bytes / total_bytes) * 100)
download_progress[progress_key] = min(progress, 99)
if downloaded_bytes == total_bytes:
current_file_completed = True  # Mark file as completed
elif stage == "verifying":
download_progress[progress_key] = 100

# Run the watcher as a daemon thread so it cannot outlive a failed download
progress_thread = threading.Thread(target=monitor_progress, args=(file_path, total_size), daemon=True)
progress_thread.start()

# Call pull_model, forwarding the progress callback
result = pull_model(
model_path=model_path,
hf=kwargs.pop("hf", False),
ms=kwargs.pop("ms", False),
progress_callback=progress_callback,
**kwargs,
)

if not result or len(result) != 2:
raise ValueError("Invalid response from pull_model.")

final_file_path, run_type = result

if not final_file_path or not run_type:
raise ValueError("Failed to download model or invalid response from pull_model.")

# Extract model type from the returned file path or extension
model_type = Path(final_file_path).suffix.strip(".") or "undefined"

download_progress[progress_key] = 100

return {
"local_path": str(final_file_path),
"model_type": model_type,
"run_type": run_type,
}
except Exception as e:
download_progress[progress_key] = -1 # Mark download as failed
raise ValueError(f"Error in pull_model_with_progress: {e}")

@app.get("/v1/check_model_type", tags=["Model"])
async def check_model(model_path: str):
"""
Check if the model exists and return its type.
"""
model_name = model_path.split(":")[0]  # tolerate paths without a ":version" suffix
if model_name in NEXA_OFFICIAL_MODELS_TYPE:
model_type = NEXA_OFFICIAL_MODELS_TYPE[model_name].value
return {
"model_name": model_name,
"model_type": model_type
}
else:
raise HTTPException(
status_code=404,
detail=f"Model '{model_name}' not found in the official model list."
)
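
A quick client-side call against this endpoint (host, port, and model path are assumptions for illustration):

import requests

resp = requests.get(
    "http://localhost:8000/v1/check_model_type",
    params={"model_path": "gemma2:2b"},  # hypothetical model path
)
print(resp.status_code, resp.json())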

@app.post("/v1/download_model", tags=["Model"])
async def download_model(request: DownloadModelRequest):
"""Download a model from the model hub"""
"""
Download a model from the model hub with progress tracking.
"""
try:
if request.model_path in NEXA_RUN_MODEL_MAP_VLM or request.model_path in NEXA_RUN_OMNI_VLM_MAP or request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: # models and projectors
if request.model_path in NEXA_RUN_MODEL_MAP_VLM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path])
elif request.model_path in NEXA_RUN_OMNI_VLM_MAP:
downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path])
elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path])

if not downloaded_path or not model_type:
return JSONResponse(content="Failed to download model. Please check whether model_path is correct.", status_code=400)

return {
"status": "success",
"message": "Successfully downloaded model and projector",
"model_path": request.model_path,
"model_local_path": downloaded_path,
"projector_local_path": projector_downloaded_path,
"model_type": model_type
}
else:
downloaded_path, model_type = pull_model(request.model_path)
if not downloaded_path or not model_type:
return JSONResponse(content="Failed to download model. Please check whether model_path is correct.", status_code=400)
# Initialize progress tracking
progress_key = request.model_path
download_progress[progress_key] = 0

def perform_download():
"""
Perform the download process with progress tracking.
"""
try:
if request.model_path in NEXA_RUN_MODEL_MAP_VLM:
downloaded_path = pull_model_with_progress(
NEXA_RUN_MODEL_MAP_VLM[request.model_path], progress_key=progress_key
)
projector_downloaded_path = pull_model_with_progress(
NEXA_RUN_PROJECTOR_MAP[request.model_path], progress_key=progress_key
)
return {
"status": "success",
"message": "Successfully downloaded model and projector",
"model_path": request.model_path,
"model_local_path": downloaded_path["local_path"],
"projector_local_path": projector_downloaded_path["local_path"],
"model_type": downloaded_path["run_type"]
}
elif request.model_path in NEXA_RUN_OMNI_VLM_MAP:
downloaded_path = pull_model_with_progress(
NEXA_RUN_OMNI_VLM_MAP[request.model_path], progress_key=progress_key
)
projector_downloaded_path = pull_model_with_progress(
NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path], progress_key=progress_key
)
return {
"status": "success",
"message": "Successfully downloaded model and projector",
"model_path": request.model_path,
"model_local_path": downloaded_path["local_path"],
"projector_local_path": projector_downloaded_path["local_path"],
"model_type": downloaded_path["run_type"]
}
elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
downloaded_path = pull_model_with_progress(
NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path], progress_key=progress_key
)
projector_downloaded_path = pull_model_with_progress(
NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path], progress_key=progress_key
)
return {
"status": "success",
"message": "Successfully downloaded model and projector",
"model_path": request.model_path,
"model_local_path": downloaded_path["local_path"],
"projector_local_path": projector_downloaded_path["local_path"],
"model_type": downloaded_path["run_type"]
}
else:
downloaded_path = pull_model_with_progress(
request.model_path, progress_key=progress_key
)
return {
"status": "success",
"message": "Successfully downloaded model",
"model_path": request.model_path,
"model_local_path": downloaded_path["local_path"],
"model_type": downloaded_path["run_type"]
}
except Exception as e:
logging.error(f"Error during download: {e}")
download_progress[progress_key] = -1 # Mark download as failed
raise

return {
"status": "success",
"message": "Successfully downloaded model",
"model_path": request.model_path,
"model_local_path": downloaded_path,
"model_type": model_type
}

# Execute the download in a background thread
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, perform_download)

# Return the result of the download
return result

except Exception as e:
# Log error and raise HTTP exception
logging.error(f"Error downloading model: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to download model: {str(e)}"
)
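
A sketch of kicking off a download from a client; the body mirrors DownloadModelRequest as used above, and host/port are assumptions:

import requests

resp = requests.post(
    "http://localhost:8000/v1/download_model",
    json={"model_path": "gemma2:2b"},  # hypothetical model path
)
print(resp.json())  # status, model_local_path, model_type on success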

async def progress_generator(model_path: str):
"""
A generator to stream download progress updates.
"""
try:
while True:
progress = download_progress.get(model_path, -1)
# A missing key or a recorded failure is reported as -1
if progress == -1:
yield f"data: {{\"error\": \"Download failed or invalid model path.\"}}\n\n"
break

yield f"data: {{\"model_name\": \"{model_path}\", \"progress\": {progress}}}\n\n"

if progress == 100:
break

await asyncio.sleep(1)
except Exception as e:
yield f"data: {{\"error\": \"Error streaming progress: {str(e)}\"}}\n\n"


@app.get("/v1/download_progress", tags=["Model"])
async def get_download_progress(model_path: str):
"""
Stream the download progress for a specific model.
"""
if model_path not in download_progress:
raise HTTPException(status_code=404, detail="No download progress found for the specified model.")

# Return a StreamingResponse
return StreamingResponse(progress_generator(model_path), media_type="text/event-stream")
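
And a sketch of consuming the SSE stream this endpoint returns, assuming the requests library (a browser client would use EventSource instead):

import requests

with requests.get(
    "http://localhost:8000/v1/download_progress",
    params={"model_path": "gemma2:2b"},  # hypothetical model path
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data: "):
            print(line[len("data: "):])  # JSON payload with model_name and progress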

@app.post("/v1/load_model", tags=["Model"])
async def load_different_model(request: LoadModelRequest):
"""Load a different model while maintaining the global model state"""
Expand Down