Skip to content

Commit

Permalink
Update download_progress API
Browse files Browse the repository at this point in the history
  • Loading branch information
MaokunZhang committed Jan 3, 2025
1 parent db0c615 commit c72f1d4
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 38 deletions.
8 changes: 7 additions & 1 deletion nexa/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,14 +397,20 @@ def download_file_with_progress(

# Call the progress callback with the current progress
if progress_callback:
progress_callback(downloaded_size, file_size)
progress_callback(downloaded_size, file_size, stage="downloading")

except Exception as e:
print(f"Error downloading chunk {chunk_number}: {e}")

progress_bar.close()

if all(completed_chunks):
# Transition to Verifying phase
try:
if progress_callback:
progress_callback(downloaded_size, file_size, stage="verifying")
except Exception as e:
print(f"Error Verifying chunk {chunk_number}: {e}")
# Create a new progress bar for combining chunks
combine_progress = tqdm(
total=file_size,
Expand Down
83 changes: 46 additions & 37 deletions nexa/gguf/server/nexa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,72 +756,80 @@ def pull_model_with_progress(model_path, progress_key, **kwargs):
try:
# Initialize progress tracking
download_progress[progress_key] = 0

current_file_completed = False

# Extract local download path
local_download_path = kwargs.get('local_download_path')
base_download_dir = Path(local_download_path) if local_download_path else NEXA_MODELS_HUB_OFFICIAL_DIR
model_name, model_version = model_path.split(":")
file_extension = ".zip" if kwargs.get("model_type") in ["onnx", "bin"] else ".gguf"
filename = f"{model_version}{file_extension}"
file_path = base_download_dir / model_name / filename


# Record expected file details
expected_files = [
{
"path": file_path,
"size": int(requests.head(f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}").headers.get("Content-Length", 0))
}
]

# Progress tracker
def monitor_progress(file_path, total_size):
"""
Monitor file size growth to estimate progress.
"""
while not os.path.exists(file_path):
nonlocal current_file_completed
while not current_file_completed:
# Update downloading progress
time.sleep(0.5)

while True:
try:
current_size = os.path.getsize(file_path)
progress = min(int((current_size / total_size) * 100), 99)
download_progress[progress_key] = progress

# Break the loop if the file size stops growing
if current_size >= total_size:
break
time.sleep(0.5) # Adjust frequency of checking
except FileNotFoundError:
# Handle cases where the file gets deleted during download
break

def progress_callback(downloaded_chunks, total_chunks):
def progress_callback(downloaded_chunks, total_chunks, stage="downloading"):
"""
Callback to update progress based on downloaded chunks.
"""
if total_chunks > 0:
progress = int((downloaded_chunks / total_chunks) * 100)
download_progress[progress_key] = progress

nonlocal current_file_completed
if stage == "downloading":
if total_chunks > 0:
progress = int((downloaded_chunks / total_chunks) * 100)
download_progress[progress_key] = min(progress, 99)
if downloaded_chunks == total_chunks:
current_file_completed = True # Mark file as completed
elif stage == "verifying":
download_progress[progress_key] = 100

# Call pull_model to start the download
url = f"{NEXA_OFFICIAL_BUCKET}/{model_name}/{filename}"
response = requests.head(url)
total_size = int(response.headers.get("Content-Length", 0))

# Start monitoring progress in a background thread
from threading import Thread
progress_thread = Thread(target=monitor_progress, args=(file_path, total_size))
progress_thread.start()

# Call download_file_with_progress with progress callback
result = pull_model(
model_path= model_path,
hf=kwargs.get("hf", False),
ms=kwargs.get("ms", False),
progress_callback=progress_callback,
progress_callback=lambda downloaded, total, stage: progress_callback(downloaded, total, stage),
**kwargs,
)

if not result or len(result) != 2:
raise ValueError("Invalid response from pull_model.")

final_file_path, run_type = result

if not final_file_path or not run_type:
raise ValueError("Failed to download model or invalid response from pull_model.")

# Extract model type from the returned file path or extension
model_type = Path(final_file_path).suffix.strip(".") or "undefined"


download_progress[progress_key] = 100

return {
"local_path": str(final_file_path),
"model_type": model_type,
Expand Down Expand Up @@ -858,7 +866,7 @@ async def download_model(request: DownloadModelRequest):
# Initialize progress tracking
progress_key = request.model_path
download_progress[progress_key] = 0

def perform_download():
"""
Perform the download process with progress tracking.
Expand Down Expand Up @@ -924,14 +932,14 @@ def perform_download():
logging.error(f"Error during download: {e}")
download_progress[progress_key] = -1 # Mark download as failed
raise

# Execute the download in a background thread
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, perform_download)

# Return the result of the download
return result

except Exception as e:
# Log error and raise HTTP exception
logging.error(f"Error downloading model: {e}")
Expand All @@ -951,12 +959,13 @@ async def progress_generator(model_path: str):
if progress == -1:
yield f"data: {{\"error\": \"Download failed or invalid model path.\"}}\n\n"
break

yield f"data: {{\"model_name\": \"{model_path}\", \"progress\": {progress}}}\n\n"

if progress >= 100:

if progress == 100:
# if model_path in download_progress and download_progress[model_path] == 100:
break

await asyncio.sleep(1)
except Exception as e:
yield f"data: {{\"error\": \"Error streaming progress: {str(e)}\"}}\n\n"
Expand All @@ -969,7 +978,7 @@ async def get_download_progress(model_path: str):
"""
if model_path not in download_progress:
raise HTTPException(status_code=404, detail="No download progress found for the specified model.")

# Return a StreamingResponse
return StreamingResponse(progress_generator(model_path), media_type="text/event-stream")

Expand Down

0 comments on commit c72f1d4

Please sign in to comment.