From 17c5df3081b1f525b0051ac8900268c02b9f4005 Mon Sep 17 00:00:00 2001 From: MaokunZhang Date: Thu, 2 Jan 2025 14:01:27 -0800 Subject: [PATCH 1/4] Update APIs: 1. check_model_type 2. download_model 3. get_download_progress --- nexa/general.py | 10 ++ nexa/gguf/server/nexa_service.py | 263 +++++++++++++++++++++++++++---- 2 files changed, 243 insertions(+), 30 deletions(-) diff --git a/nexa/general.py b/nexa/general.py index 2289f661..19565f58 100644 --- a/nexa/general.py +++ b/nexa/general.py @@ -341,6 +341,7 @@ def download_file_with_progress( chunk_size: int = 5 * 1024 * 1024, max_workers: int = 20, use_processes: bool = default_use_processes(), + progress_callback=None, **kwargs ): file_path.parent.mkdir(parents=True, exist_ok=True) @@ -376,6 +377,8 @@ def download_file_with_progress( else concurrent.futures.ThreadPoolExecutor ) + downloaded_size = 0 # Track the total downloaded size + with executor_class(max_workers=max_workers) as executor: future_to_chunk = { executor.submit( @@ -389,7 +392,13 @@ def download_file_with_progress( try: chunk_size, chunk_number = future.result() completed_chunks[chunk_number] = True + downloaded_size += chunk_size progress_bar.update(chunk_size) + + # Call the progress callback with the current progress + if progress_callback: + progress_callback(downloaded_size, file_size) + except Exception as e: print(f"Error downloading chunk {chunk_number}: {e}") @@ -428,6 +437,7 @@ def download_file_with_progress( shutil.rmtree(temp_dir) + def download_model_from_official(model_path, model_type, **kwargs): try: model_name, model_version = model_path.split(":") diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 3db82c8d..6c166ae2 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -1,7 +1,11 @@ import json import logging import os +from pathlib import Path +import queue +import shutil import socket +import threading import time import uuid from typing import List, 
Optional, Dict, Any, Union, Literal @@ -9,6 +13,8 @@ import multiprocessing from PIL import Image import tempfile +import concurrent +import tqdm import uvicorn from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Query from fastapi.middleware.cors import CORSMiddleware @@ -19,10 +25,16 @@ from PIL import Image import base64 from urllib.parse import urlparse +import asyncio +from fastapi.responses import StreamingResponse from nexa.constants import ( + NEXA_MODELS_HUB_OFFICIAL_DIR, + NEXA_OFFICIAL_MODELS_TYPE, NEXA_RUN_CHAT_TEMPLATE_MAP, + NEXA_RUN_MODEL_MAP_TEXT, NEXA_RUN_MODEL_MAP_VLM, + NEXA_RUN_MODEL_MAP_VOICE, NEXA_RUN_PROJECTOR_MAP, NEXA_RUN_OMNI_VLM_MAP, NEXA_RUN_OMNI_VLM_PROJECTOR_MAP, @@ -31,8 +43,11 @@ NEXA_RUN_COMPLETION_TEMPLATE_MAP, NEXA_RUN_MODEL_PRECISION_MAP, NEXA_RUN_MODEL_MAP_FUNCTION_CALLING, - NEXA_MODEL_LIST_PATH + NEXA_MODEL_LIST_PATH, + NEXA_MODEL_LIST_PATH, + NEXA_OFFICIAL_BUCKET, ) +from nexa.gguf.converter.constants import NEXA_MODELS_HUB_DIR from nexa.gguf.lib_utils import is_gpu_available from nexa.gguf.llama.llama_chat_format import ( Llava15ChatHandler, @@ -40,7 +55,7 @@ NanoLlavaChatHandler, ) from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr -from nexa.general import pull_model +from nexa.general import add_model_to_list, default_use_processes, download_file_with_progress, get_model_info, is_model_exists, pull_model from nexa.gguf.llama.llama import Llama from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference @@ -708,45 +723,233 @@ def _resp_async_generator(streamer): yield f"data: {json.dumps(chunk)}\n\n" yield "data: [DONE]\n\n" +# Global variable for download progress tracking +download_progress = {} + +def pull_model_with_progress(model_path, progress_key, **kwargs): + """ + Wrapper for pull_model to track download progress using download_file_with_progress. 
+ """ + try: + # Initialize progress tracking + download_progress[progress_key] = 0 + + # Extract local download path + local_download_path = kwargs.get('local_download_path') + base_download_dir = Path(local_download_path) if local_download_path else NEXA_MODELS_HUB_OFFICIAL_DIR + model_name, model_version = model_path.split(":") + file_extension = ".zip" if kwargs.get("model_type") in ["onnx", "bin"] else ".gguf" + filename = f"{model_version}{file_extension}" + file_path = base_download_dir / model_name / filename + + # Progress tracker + def monitor_progress(file_path, total_size): + """ + Monitor file size growth to estimate progress. + """ + while not os.path.exists(file_path): + time.sleep(0.5) + + while True: + try: + current_size = os.path.getsize(file_path) + progress = min(int((current_size / total_size) * 100), 99) + download_progress[progress_key] = progress + + # Break the loop if the file size stops growing + if current_size >= total_size: + break + time.sleep(0.5) # Adjust frequency of checking + except FileNotFoundError: + # Handle cases where the file gets deleted during download + break + + def progress_callback(downloaded_chunks, total_chunks): + """ + Callback to update progress based on downloaded chunks. 
+            """ + if total_chunks > 0: + progress = int((downloaded_chunks / total_chunks) * 100) + download_progress[progress_key] = progress + + # Call pull_model to start the download + url = f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}" + response = requests.head(url) + total_size = int(response.headers.get("Content-Length", 0)) + + # Start monitoring progress in a background thread + from threading import Thread + progress_thread = Thread(target=monitor_progress, args=(file_path, total_size)) + progress_thread.start() + + # Call download_file_with_progress with progress callback + result = pull_model( + model_path= model_path, + hf=kwargs.get("hf", False), + ms=kwargs.get("ms", False), + progress_callback=progress_callback, + **kwargs, + ) + + final_file_path, run_type = result + + if not final_file_path or not run_type: + raise ValueError("Failed to download model or invalid response from pull_model.") + + # Extract model type from the returned file path or extension + model_type = Path(final_file_path).suffix.strip(".") or "undefined" + + return { + "local_path": str(final_file_path), + "model_type": model_type, + "run_type": run_type, + } + except Exception as e: + download_progress[progress_key] = -1 # Mark download as failed + raise ValueError(f"Error in pull_model_with_progress: {e}") + +@app.get("/v1/check_model_type", tags=["Model"]) +async def check_model(model_path: str): + """ + Check if the model exists and return its type. + """ + model_name, model_version = model_path.split(":") + if model_name in NEXA_OFFICIAL_MODELS_TYPE: + model_type = NEXA_OFFICIAL_MODELS_TYPE[model_name].value + return { + "model_name": model_name, + "model_type": model_type + } + else: + raise HTTPException( + status_code=404, + detail=f"Model '{model_name}' not found in the official model list." 
+ ) + @app.post("/v1/download_model", tags=["Model"]) async def download_model(request: DownloadModelRequest): - """Download a model from the model hub""" + """ + Download a model from the model hub with progress tracking. + """ try: - if request.model_path in NEXA_RUN_MODEL_MAP_VLM or request.model_path in NEXA_RUN_OMNI_VLM_MAP or request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: # models and projectors - if request.model_path in NEXA_RUN_MODEL_MAP_VLM: - downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path]) - projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path]) - elif request.model_path in NEXA_RUN_OMNI_VLM_MAP: - downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[request.model_path]) - projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path]) - elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: - downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path]) - projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path]) - return { - "status": "success", - "message": "Successfully downloaded model and projector", - "model_path": request.model_path, - "model_local_path": downloaded_path, - "projector_local_path": projector_downloaded_path, - "model_type": model_type - } - else: - downloaded_path, model_type = pull_model(request.model_path) - return { - "status": "success", - "message": "Successfully downloaded model", - "model_path": request.model_path, - "model_local_path": downloaded_path, - "model_type": model_type - } + # Initialize progress tracking + progress_key = request.model_path + download_progress[progress_key] = 0 + + def perform_download(): + """ + Perform the download process with progress tracking. 
+ """ + try: + if request.model_path in NEXA_RUN_MODEL_MAP_VLM: + downloaded_path = pull_model_with_progress( + NEXA_RUN_MODEL_MAP_VLM[request.model_path], progress_key=progress_key + ) + projector_downloaded_path = pull_model_with_progress( + NEXA_RUN_PROJECTOR_MAP[request.model_path], progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "projector_local_path": projector_downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + elif request.model_path in NEXA_RUN_OMNI_VLM_MAP: + downloaded_path = pull_model_with_progress( + NEXA_RUN_OMNI_VLM_MAP[request.model_path], progress_key=progress_key + ) + projector_downloaded_path = pull_model_with_progress( + NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path], progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "projector_local_path": projector_downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: + downloaded_path = pull_model_with_progress( + NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path], progress_key=progress_key + ) + projector_downloaded_path = pull_model_with_progress( + NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path], progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "projector_local_path": projector_downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + else: + downloaded_path = pull_model_with_progress( + request.model_path, progress_key=progress_key + ) + return { + "status": "success", + 
"message": "Successfully downloaded model", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + except Exception as e: + logging.error(f"Error during download: {e}") + download_progress[progress_key] = -1 # Mark download as failed + raise + + # Execute the download in a background thread + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, perform_download) + + # Return the result of the download + return result except Exception as e: + # Log error and raise HTTP exception logging.error(f"Error downloading model: {e}") raise HTTPException( status_code=500, detail=f"Failed to download model: {str(e)}" ) +async def progress_generator(model_path: str): + """ + A generator to stream download progress updates. + """ + try: + while True: + progress = download_progress.get(model_path, -1) + # Check if the model_path exists in download_progress + if progress == -1: + yield f"data: {{\"error\": \"Download failed or invalid model path.\"}}\n\n" + break + + yield f"data: {{\"model_name\": \"{model_path}\", \"progress\": {progress}}}\n\n" + + if progress >= 100: + break + + await asyncio.sleep(1) + except Exception as e: + yield f"data: {{\"error\": \"Error streaming progress: {str(e)}\"}}\n\n" + + +@app.get("/v1/download_progress", tags=["Model"]) +async def get_download_progress(model_path: str): + """ + Stream the download progress for a specific model. 
+ """ + if model_path not in download_progress: + raise HTTPException(status_code=404, detail="No download progress found for the specified model.") + + # Return a StreamingResponse + return StreamingResponse(progress_generator(model_path), media_type="text/event-stream") + @app.post("/v1/load_model", tags=["Model"]) async def load_different_model(request: LoadModelRequest): """Load a different model while maintaining the global model state""" From c72f1d4b6a7ed21760d10b26b013e4bc81296790 Mon Sep 17 00:00:00 2001 From: MaokunZhang Date: Thu, 2 Jan 2025 22:47:13 -0800 Subject: [PATCH 2/4] Update download_progress API --- nexa/general.py | 8 ++- nexa/gguf/server/nexa_service.py | 83 ++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 38 deletions(-) diff --git a/nexa/general.py b/nexa/general.py index 19565f58..ea5a1dd0 100644 --- a/nexa/general.py +++ b/nexa/general.py @@ -397,7 +397,7 @@ def download_file_with_progress( # Call the progress callback with the current progress if progress_callback: - progress_callback(downloaded_size, file_size) + progress_callback(downloaded_size, file_size, stage="downloading") except Exception as e: print(f"Error downloading chunk {chunk_number}: {e}") @@ -405,6 +405,12 @@ def download_file_with_progress( progress_bar.close() if all(completed_chunks): + # Transition to Verifying phase + try: + if progress_callback: + progress_callback(downloaded_size, file_size, stage="verifying") + except Exception as e: + print(f"Error Verifying chunk {chunk_number}: {e}") # Create a new progress bar for combining chunks combine_progress = tqdm( total=file_size, diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 6f13dc84..1791e755 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -756,7 +756,8 @@ def pull_model_with_progress(model_path, progress_key, **kwargs): try: # Initialize progress tracking download_progress[progress_key] = 0 - + current_file_completed 
= False + + # Extract local download path local_download_path = kwargs.get('local_download_path') base_download_dir = Path(local_download_path) if local_download_path else NEXA_MODELS_HUB_OFFICIAL_DIR @@ -764,64 +765,71 @@ def pull_model_with_progress(model_path, progress_key, **kwargs): file_extension = ".zip" if kwargs.get("model_type") in ["onnx", "bin"] else ".gguf" filename = f"{model_version}{file_extension}" file_path = base_download_dir / model_name / filename - + + # Record expected file details + expected_files = [ + { + "path": file_path, + "size": int(requests.head(f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}").headers.get("Content-Length", 0)) + } + ] + # Progress tracker def monitor_progress(file_path, total_size): """ Monitor file size growth to estimate progress. """ - while not os.path.exists(file_path): + nonlocal current_file_completed + while not current_file_completed: + # Update downloading progress time.sleep(0.5) - - while True: - try: - current_size = os.path.getsize(file_path) - progress = min(int((current_size / total_size) * 100), 99) - download_progress[progress_key] = progress - - # Break the loop if the file size stops growing - if current_size >= total_size: - break - time.sleep(0.5) # Adjust frequency of checking - except FileNotFoundError: - # Handle cases where the file gets deleted during download - break - def progress_callback(downloaded_chunks, total_chunks): + def progress_callback(downloaded_chunks, total_chunks, stage="downloading"): """ Callback to update progress based on downloaded chunks. 
""" - if total_chunks > 0: - progress = int((downloaded_chunks / total_chunks) * 100) - download_progress[progress_key] = progress - + nonlocal current_file_completed + if stage == "downloading": + if total_chunks > 0: + progress = int((downloaded_chunks / total_chunks) * 100) + download_progress[progress_key] = min(progress, 99) + if downloaded_chunks == total_chunks: + current_file_completed = True # Mark file as completed + elif stage == "verifying": + download_progress[progress_key] = 100 + # Call pull_model to start the download url = f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}" response = requests.head(url) total_size = int(response.headers.get("Content-Length", 0)) - + # Start monitoring progress in a background thread from threading import Thread progress_thread = Thread(target=monitor_progress, args=(file_path, total_size)) progress_thread.start() - + # Call download_file_with_progress with progress callback result = pull_model( model_path= model_path, hf=kwargs.get("hf", False), ms=kwargs.get("ms", False), - progress_callback=progress_callback, + progress_callback=lambda downloaded, total, stage: progress_callback(downloaded, total, stage), **kwargs, ) + if not result or len(result) != 2: + raise ValueError("Invalid response from pull_model.") + final_file_path, run_type = result - + if not final_file_path or not run_type: + raise ValueError("Failed to download model or invalid response from pull_model.") - + # Extract model type from the returned file path or extension model_type = Path(final_file_path).suffix.strip(".") or "undefined" - + + download_progress[progress_key] = 100 + return { "local_path": str(final_file_path), "model_type": model_type, @@ -858,7 +866,7 @@ async def download_model(request: DownloadModelRequest): # Initialize progress tracking progress_key = request.model_path download_progress[progress_key] = 0 - + def perform_download(): """ Perform the download process with progress tracking. 
@@ -924,14 +932,14 @@ def perform_download(): logging.error(f"Error during download: {e}") download_progress[progress_key] = -1 # Mark download as failed raise - + # Execute the download in a background thread loop = asyncio.get_event_loop() result = await loop.run_in_executor(None, perform_download) - + # Return the result of the download return result - + except Exception as e: # Log error and raise HTTP exception logging.error(f"Error downloading model: {e}") @@ -951,12 +959,13 @@ async def progress_generator(model_path: str): if progress == -1: yield f"data: {{\"error\": \"Download failed or invalid model path.\"}}\n\n" break - + yield f"data: {{\"model_name\": \"{model_path}\", \"progress\": {progress}}}\n\n" - - if progress >= 100: + + if progress == 100: + # if model_path in download_progress and download_progress[model_path] == 100: break - + await asyncio.sleep(1) except Exception as e: yield f"data: {{\"error\": \"Error streaming progress: {str(e)}\"}}\n\n" @@ -969,7 +978,7 @@ async def get_download_progress(model_path: str): """ if model_path not in download_progress: raise HTTPException(status_code=404, detail="No download progress found for the specified model.") - + # Return a StreamingResponse return StreamingResponse(progress_generator(model_path), media_type="text/event-stream") From 719749938c983857d0af71da510be59520c0eb0e Mon Sep 17 00:00:00 2001 From: MaokunZhang Date: Fri, 3 Jan 2025 12:01:56 -0800 Subject: [PATCH 3/4] solve duplicate import --- nexa/gguf/server/nexa_service.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 1791e755..025af513 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -22,11 +22,8 @@ from pydantic import BaseModel, HttpUrl, AnyUrl, Field import requests from io import BytesIO -from PIL import Image -import base64 from urllib.parse import urlparse import asyncio -from fastapi.responses import 
StreamingResponse from nexa.constants import ( NEXA_MODELS_HUB_OFFICIAL_DIR, From 5d1c897e8068ba87ef4728d06329af4871f75c19 Mon Sep 17 00:00:00 2001 From: MaokunZhang Date: Fri, 3 Jan 2025 12:04:34 -0800 Subject: [PATCH 4/4] solve duplicate import --- nexa/gguf/server/nexa_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 025af513..0d0e25bc 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -41,7 +41,6 @@ NEXA_RUN_MODEL_PRECISION_MAP, NEXA_RUN_MODEL_MAP_FUNCTION_CALLING, NEXA_MODEL_LIST_PATH, - NEXA_MODEL_LIST_PATH, NEXA_OFFICIAL_BUCKET, ) from nexa.gguf.converter.constants import NEXA_MODELS_HUB_DIR