diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 0a04ce57..a3e995f9 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -29,6 +29,7 @@ NEXA_MODELS_HUB_OFFICIAL_DIR, NEXA_OFFICIAL_MODELS_TYPE, NEXA_RUN_CHAT_TEMPLATE_MAP, + NEXA_RUN_MODEL_MAP, NEXA_RUN_MODEL_MAP_TEXT, NEXA_RUN_MODEL_MAP_VLM, NEXA_RUN_MODEL_MAP_VOICE, @@ -752,6 +753,8 @@ def pull_model_with_progress(model_path, progress_key, **kwargs): """ Wrapper for pull_model to track download progress using download_file_with_progress. """ + model_path = NEXA_RUN_MODEL_MAP.get(model_path, model_path) + try: # Initialize progress tracking download_progress[progress_key] = 0 @@ -797,7 +800,6 @@ def progress_callback(downloaded_chunks, total_chunks, stage="downloading"): elif stage == "verifying": download_progress[progress_key] = 100 - # Call pull_model to start the download url = f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}" response = requests.head(url) total_size = int(response.headers.get("Content-Length", 0)) @@ -807,7 +809,7 @@ def progress_callback(downloaded_chunks, total_chunks, stage="downloading"): progress_thread = Thread(target=monitor_progress, args=(file_path, total_size)) progress_thread.start() - # Call download_file_with_progress with progress callback + # Call pull_model to start the download result = pull_model( model_path= model_path, hf=kwargs.get("hf", False), @@ -839,12 +841,19 @@ def progress_callback(downloaded_chunks, total_chunks, stage="downloading"): raise ValueError(f"Error in pull_model_with_progress: {e}") @app.get("/v1/check_model_type", tags=["Model"]) -async def check_model(model_path: str): +async def check_model_type(model_path: str): """ Check if the model exists and return its type. """ - model_name, model_version = model_path.split(":") - if model_name in NEXA_OFFICIAL_MODELS_TYPE: + model_name = NEXA_RUN_MODEL_MAP.get(model_path, model_path) + + if ":" in model_name: + model_name = model_name.split(":")[0] + else: + model_name = model_name + + if model_name in NEXA_RUN_MODEL_MAP or NEXA_RUN_CHAT_TEMPLATE_MAP: + model_type = NEXA_OFFICIAL_MODELS_TYPE[model_name].value return { "model_name": model_name,