diff --git a/nexa/general.py b/nexa/general.py
index 2289f661..ea5a1dd0 100644
--- a/nexa/general.py
+++ b/nexa/general.py
@@ -341,6 +341,7 @@ def download_file_with_progress(
     chunk_size: int = 5 * 1024 * 1024,
     max_workers: int = 20,
     use_processes: bool = default_use_processes(),
+    progress_callback=None,
     **kwargs
 ):
     file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -376,6 +377,8 @@ def download_file_with_progress(
         else concurrent.futures.ThreadPoolExecutor
     )
 
+    downloaded_size = 0  # Cumulative number of bytes downloaded so far
+
     with executor_class(max_workers=max_workers) as executor:
         future_to_chunk = {
             executor.submit(
@@ -389,13 +392,25 @@ def download_file_with_progress(
             try:
                 chunk_size, chunk_number = future.result()
                 completed_chunks[chunk_number] = True
+                downloaded_size += chunk_size
                 progress_bar.update(chunk_size)
+
+                # Report current progress to the caller, if a callback was given
+                if progress_callback:
+                    progress_callback(downloaded_size, file_size, stage="downloading")
+
             except Exception as e:
                 print(f"Error downloading chunk {chunk_number}: {e}")
 
     progress_bar.close()
 
     if all(completed_chunks):
+        # Transition to the verifying phase
+        try:
+            if progress_callback:
+                progress_callback(downloaded_size, file_size, stage="verifying")
+        except Exception as e:
+            print(f"Error in verification progress callback: {e}")
         # Create a new progress bar for combining chunks
         combine_progress = tqdm(
             total=file_size,
@@ -428,6 +443,7 @@
 
     shutil.rmtree(temp_dir)
 
+
 def download_model_from_official(model_path, model_type, **kwargs):
     try:
         model_name, model_version = model_path.split(":")
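Reviewer note: below is a minimal sketch of a callback compatible with the new progress_callback hook, assuming only what the hunk above shows: the callback receives the cumulative byte count, the total file size, and a stage keyword that is "downloading" per completed chunk and "verifying" once all chunks finish. The invocation at the bottom is hypothetical; the URL and destination path are placeholders.

from pathlib import Path

def print_progress(downloaded_size, file_size, stage="downloading"):
    # stage is "downloading" while chunks complete, then "verifying"
    # once all chunks have finished (matching the call sites above)
    if stage == "downloading" and file_size > 0:
        print(f"downloading: {downloaded_size / file_size:.0%}")
    elif stage == "verifying":
        print("verifying downloaded chunks...")

# Hypothetical call; url and the destination path are placeholders:
# download_file_with_progress(url, Path("/tmp/model.gguf"),
#                             progress_callback=print_progress)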
diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py
index 580e8205..0d0e25bc 100644
--- a/nexa/gguf/server/nexa_service.py
+++ b/nexa/gguf/server/nexa_service.py
@@ -1,7 +1,11 @@
 import json
 import logging
 import os
+from pathlib import Path
+import queue
+import shutil
 import socket
+import threading
 import time
 import uuid
 from typing import List, Optional, Dict, Any, Union, Literal
@@ -9,6 +13,8 @@
 import multiprocessing
 from PIL import Image
 import tempfile
+import concurrent
+import tqdm
 import uvicorn
 from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Query
 from fastapi.middleware.cors import CORSMiddleware
@@ -16,13 +22,16 @@
 from pydantic import BaseModel, HttpUrl, AnyUrl, Field
 import requests
 from io import BytesIO
-from PIL import Image
-import base64
 from urllib.parse import urlparse
+import asyncio
 
 from nexa.constants import (
+    NEXA_MODELS_HUB_OFFICIAL_DIR,
+    NEXA_OFFICIAL_MODELS_TYPE,
     NEXA_RUN_CHAT_TEMPLATE_MAP,
+    NEXA_RUN_MODEL_MAP_TEXT,
     NEXA_RUN_MODEL_MAP_VLM,
+    NEXA_RUN_MODEL_MAP_VOICE,
     NEXA_RUN_PROJECTOR_MAP,
     NEXA_RUN_OMNI_VLM_MAP,
     NEXA_RUN_OMNI_VLM_PROJECTOR_MAP,
@@ -31,8 +40,10 @@
     NEXA_RUN_COMPLETION_TEMPLATE_MAP,
     NEXA_RUN_MODEL_PRECISION_MAP,
     NEXA_RUN_MODEL_MAP_FUNCTION_CALLING,
-    NEXA_MODEL_LIST_PATH
+    NEXA_MODEL_LIST_PATH,
+    NEXA_OFFICIAL_BUCKET,
 )
+from nexa.gguf.converter.constants import NEXA_MODELS_HUB_DIR
 from nexa.gguf.lib_utils import is_gpu_available
 from nexa.gguf.llama.llama_chat_format import (
     Llava15ChatHandler,
@@ -40,7 +51,7 @@
     NanoLlavaChatHandler,
 )
 from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
-from nexa.general import pull_model
+from nexa.general import add_model_to_list, default_use_processes, download_file_with_progress, get_model_info, is_model_exists, pull_model
 from nexa.gguf.llama.llama import Llama
 from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
 from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
@@ -731,52 +742,242 @@ def _resp_async_generator(streamer, start_time):
         yield f"metrics: {MetricsResult(ttft=ttft, decoding_speed=decoding_times / (time.perf_counter() - start_time)).to_json()}\n\n"
     yield "data: [DONE]\n\n"
 
+# Global registry mapping a progress key to its download progress (0-100, or -1 on failure)
+download_progress = {}
+
+def pull_model_with_progress(model_path, progress_key, **kwargs):
+    """
+    Wrapper for pull_model that tracks download progress through the
+    progress_callback hook of download_file_with_progress.
+    """
+    try:
+        # Initialize progress tracking
+        download_progress[progress_key] = 0
+        current_file_completed = False
+
+        # Resolve the expected local download location
+        local_download_path = kwargs.get("local_download_path")
+        base_download_dir = Path(local_download_path) if local_download_path else NEXA_MODELS_HUB_OFFICIAL_DIR
+        # partition() tolerates model paths without an explicit ":version" suffix
+        model_name, _, model_version = model_path.partition(":")
+        file_extension = ".zip" if kwargs.get("model_type") in ["onnx", "bin"] else ".gguf"
+        filename = f"{model_version}{file_extension}"
+        file_path = base_download_dir / model_name / filename
+
+        # Record the expected file size with a single HEAD request
+        url = f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}"
+        total_size = int(requests.head(url).headers.get("Content-Length", 0))
+
+        def monitor_progress(file_path, total_size):
+            """
+            Keep a watcher alive until the current file finishes downloading.
+            """
+            nonlocal current_file_completed
+            while not current_file_completed:
+                time.sleep(0.5)
+
+        def progress_callback(downloaded_bytes, total_bytes, stage="downloading"):
+            """
+            Update the shared progress map from the byte counts reported by
+            download_file_with_progress.
+            """
+            nonlocal current_file_completed
+            if stage == "downloading":
+                if total_bytes > 0:
+                    progress = int((downloaded_bytes / total_bytes) * 100)
+                    # Hold at 99% until the verifying stage confirms completion
+                    download_progress[progress_key] = min(progress, 99)
+                if downloaded_bytes == total_bytes:
+                    current_file_completed = True  # Mark the file as completed
+            elif stage == "verifying":
+                download_progress[progress_key] = 100
+
+        # Start the monitor as a daemon thread so it cannot outlive the server
+        progress_thread = threading.Thread(
+            target=monitor_progress, args=(file_path, total_size), daemon=True
+        )
+        progress_thread.start()
+
+        # Delegate the actual download to pull_model, forwarding the callback.
+        # Pop keys that are passed explicitly so **kwargs cannot duplicate them.
+        hf = kwargs.pop("hf", False)
+        ms = kwargs.pop("ms", False)
+        try:
+            result = pull_model(
+                model_path=model_path,
+                hf=hf,
+                ms=ms,
+                progress_callback=progress_callback,
+                **kwargs,
+            )
+        finally:
+            current_file_completed = True  # Always release the monitor thread
+
+        if not result or len(result) != 2:
+            raise ValueError("Invalid response from pull_model.")
+
+        final_file_path, run_type = result
+
+        if not final_file_path or not run_type:
+            raise ValueError("Failed to download model or invalid response from pull_model.")
+
+        # Derive the model type from the downloaded file's extension
+        model_type = Path(final_file_path).suffix.lstrip(".") or "undefined"
+
+        download_progress[progress_key] = 100
+
+        return {
+            "local_path": str(final_file_path),
+            "model_type": model_type,
+            "run_type": run_type,
+        }
+    except Exception as e:
+        download_progress[progress_key] = -1  # Mark the download as failed
+        raise ValueError(f"Error in pull_model_with_progress: {e}")
+
+@app.get("/v1/check_model_type", tags=["Model"])
+async def check_model(model_path: str):
+    """
+    Check whether the model exists in the official model list and return its type.
+    """
+    # partition() tolerates paths without an explicit ":version" suffix
+    model_name, _, model_version = model_path.partition(":")
+    if model_name in NEXA_OFFICIAL_MODELS_TYPE:
+        model_type = NEXA_OFFICIAL_MODELS_TYPE[model_name].value
+        return {
+            "model_name": model_name,
+            "model_type": model_type
+        }
+    else:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Model '{model_name}' not found in the official model list."
+        )
+
 @app.post("/v1/download_model", tags=["Model"])
 async def download_model(request: DownloadModelRequest):
-    """Download a model from the model hub"""
+    """
+    Download a model from the model hub with progress tracking.
+    """
     try:
-        if request.model_path in NEXA_RUN_MODEL_MAP_VLM or request.model_path in NEXA_RUN_OMNI_VLM_MAP or request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: # models and projectors
-            if request.model_path in NEXA_RUN_MODEL_MAP_VLM:
-                downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path])
-                projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path])
-            elif request.model_path in NEXA_RUN_OMNI_VLM_MAP:
-                downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[request.model_path])
-                projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path])
-            elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
-                downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path])
-                projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path])
-
-            if not downloaded_path or not model_type:
-                return JSONResponse(content="Failed to download model. Please check whether model_path is correct.", status_code=400)
-
-            return {
-                "status": "success",
-                "message": "Successfully downloaded model and projector",
-                "model_path": request.model_path,
-                "model_local_path": downloaded_path,
-                "projector_local_path": projector_downloaded_path,
-                "model_type": model_type
-            }
-        else:
-            downloaded_path, model_type = pull_model(request.model_path)
-            if not downloaded_path or not model_type:
-                return JSONResponse(content="Failed to download model. Please check whether model_path is correct.", status_code=400)
+        # Initialize progress tracking
+        progress_key = request.model_path
+        download_progress[progress_key] = 0
+
+        def perform_download():
+            """
+            Perform the download process with progress tracking.
+            """
+ """ + try: + if request.model_path in NEXA_RUN_MODEL_MAP_VLM: + downloaded_path = pull_model_with_progress( + NEXA_RUN_MODEL_MAP_VLM[request.model_path], progress_key=progress_key + ) + projector_downloaded_path = pull_model_with_progress( + NEXA_RUN_PROJECTOR_MAP[request.model_path], progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "projector_local_path": projector_downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + elif request.model_path in NEXA_RUN_OMNI_VLM_MAP: + downloaded_path = pull_model_with_progress( + NEXA_RUN_OMNI_VLM_MAP[request.model_path], progress_key=progress_key + ) + projector_downloaded_path = pull_model_with_progress( + NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path], progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "projector_local_path": projector_downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: + downloaded_path = pull_model_with_progress( + NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path], progress_key=progress_key + ) + projector_downloaded_path = pull_model_with_progress( + NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path], progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "projector_local_path": projector_downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + else: + downloaded_path = pull_model_with_progress( + request.model_path, progress_key=progress_key + ) + return { + "status": "success", + "message": "Successfully downloaded model", + "model_path": request.model_path, + "model_local_path": downloaded_path["local_path"], + "model_type": downloaded_path["run_type"] + } + except Exception as e: + logging.error(f"Error during download: {e}") + download_progress[progress_key] = -1 # Mark download as failed + raise - return { - "status": "success", - "message": "Successfully downloaded model", - "model_path": request.model_path, - "model_local_path": downloaded_path, - "model_type": model_type - } - + # Execute the download in a background thread + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, perform_download) + + # Return the result of the download + return result + except Exception as e: + # Log error and raise HTTP exception logging.error(f"Error downloading model: {e}") raise HTTPException( status_code=500, detail=f"Failed to download model: {str(e)}" ) +async def progress_generator(model_path: str): + """ + A generator to stream download progress updates. 
+ """ + try: + while True: + progress = download_progress.get(model_path, -1) + # Check if the model_path exists in download_progress + if progress == -1: + yield f"data: {{\"error\": \"Download failed or invalid model path.\"}}\n\n" + break + + yield f"data: {{\"model_name\": \"{model_path}\", \"progress\": {progress}}}\n\n" + + if progress == 100: + # if model_path in download_progress and download_progress[model_path] == 100: + break + + await asyncio.sleep(1) + except Exception as e: + yield f"data: {{\"error\": \"Error streaming progress: {str(e)}\"}}\n\n" + + +@app.get("/v1/download_progress", tags=["Model"]) +async def get_download_progress(model_path: str): + """ + Stream the download progress for a specific model. + """ + if model_path not in download_progress: + raise HTTPException(status_code=404, detail="No download progress found for the specified model.") + + # Return a StreamingResponse + return StreamingResponse(progress_generator(model_path), media_type="text/event-stream") + @app.post("/v1/load_model", tags=["Model"]) async def load_different_model(request: LoadModelRequest): """Load a different model while maintaining the global model state"""