Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update check_model_type to allowed more input of model path. #339

Merged
merged 1 commit into from
Jan 5, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions nexa/gguf/server/nexa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
NEXA_MODELS_HUB_OFFICIAL_DIR,
NEXA_OFFICIAL_MODELS_TYPE,
NEXA_RUN_CHAT_TEMPLATE_MAP,
NEXA_RUN_MODEL_MAP,
NEXA_RUN_MODEL_MAP_TEXT,
NEXA_RUN_MODEL_MAP_VLM,
NEXA_RUN_MODEL_MAP_VOICE,
Expand Down Expand Up @@ -752,6 +753,8 @@ def pull_model_with_progress(model_path, progress_key, **kwargs):
"""
Wrapper for pull_model to track download progress using download_file_with_progress.
"""
model_path = NEXA_RUN_MODEL_MAP.get(model_path, model_path)

try:
# Initialize progress tracking
download_progress[progress_key] = 0
Expand Down Expand Up @@ -797,7 +800,6 @@ def progress_callback(downloaded_chunks, total_chunks, stage="downloading"):
elif stage == "verifying":
download_progress[progress_key] = 100

# Call pull_model to start the download
url = f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}"
response = requests.head(url)
total_size = int(response.headers.get("Content-Length", 0))
Expand All @@ -807,7 +809,7 @@ def progress_callback(downloaded_chunks, total_chunks, stage="downloading"):
progress_thread = Thread(target=monitor_progress, args=(file_path, total_size))
progress_thread.start()

# Call download_file_with_progress with progress callback
# Call pull_model to start the download
result = pull_model(
model_path= model_path,
hf=kwargs.get("hf", False),
Expand Down Expand Up @@ -839,12 +841,19 @@ def progress_callback(downloaded_chunks, total_chunks, stage="downloading"):
raise ValueError(f"Error in pull_model_with_progress: {e}")

@app.get("/v1/check_model_type", tags=["Model"])
async def check_model(model_path: str):
async def check_model_type(model_path: str):
"""
Check if the model exists and return its type.
"""
model_name, model_version = model_path.split(":")
if model_name in NEXA_OFFICIAL_MODELS_TYPE:
model_name = NEXA_RUN_MODEL_MAP.get(model_path, model_path)

if ":" in model_name:
model_name = model_name.split(":")[0]
else:
model_name = model_name

if model_name in NEXA_RUN_MODEL_MAP or NEXA_RUN_CHAT_TEMPLATE_MAP:

model_type = NEXA_OFFICIAL_MODELS_TYPE[model_name].value
return {
"model_name": model_name,
Expand Down