Skip to content

Commit

Permalink
Merge pull request #146 from NexaAI/perry/concurrency-download
Browse files Browse the repository at this point in the history
Fixed multiprocessing incompatibility with Streamlit on macOS and Windows
  • Loading branch information
zhiyuan8 authored Oct 4, 2024
2 parents 6f30ff0 + 012e881 commit 5b02a84
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 23 deletions.
37 changes: 27 additions & 10 deletions nexa/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import os
from tqdm import tqdm
import platform

from nexa.constants import (
NEXA_API_URL,
Expand Down Expand Up @@ -105,7 +106,7 @@ def get_user_info(token):
return None


def pull_model(model_path, hf = False):
def pull_model(model_path, hf = False, **kwargs):
model_path = NEXA_RUN_MODEL_MAP.get(model_path, model_path)

try:
Expand All @@ -118,9 +119,9 @@ def pull_model(model_path, hf = False):
return location, run_type

if "/" in model_path:
result = pull_model_from_hub(model_path)
result = pull_model_from_hub(model_path, **kwargs)
else:
result = pull_model_from_official(model_path)
result = pull_model_from_official(model_path, **kwargs)

if result["success"]:
add_model_to_list(model_path, result["local_path"], result["model_type"], result["run_type"])
Expand All @@ -134,7 +135,7 @@ def pull_model(model_path, hf = False):
return None, "NLP"


def pull_model_from_hub(model_path):
def pull_model_from_hub(model_path, **kwargs):
NEXA_MODELS_HUB_DIR.mkdir(parents=True, exist_ok=True)

token = ""
Expand Down Expand Up @@ -174,7 +175,7 @@ def pull_model_from_hub(model_path):
for file_path, presigned_link in presigned_links.items():
try:
download_path = NEXA_MODELS_HUB_DIR / file_path
download_file_with_progress(presigned_link, download_path, use_processes=True)
download_file_with_progress(presigned_link, download_path, **kwargs)

if local_path is None:
if model_type == "onnx" or model_type == "bin":
Expand All @@ -196,7 +197,7 @@ def pull_model_from_hub(model_path):
}


def pull_model_from_official(model_path):
def pull_model_from_official(model_path, **kwargs):
NEXA_MODELS_HUB_OFFICIAL_DIR.mkdir(parents=True, exist_ok=True)

if "onnx" in model_path:
Expand All @@ -208,7 +209,7 @@ def pull_model_from_official(model_path):

run_type = get_run_type_from_model_path(model_path)
run_type_str = run_type.value if isinstance(run_type, ModelType) else str(run_type)
success, location = download_model_from_official(model_path, model_type)
success, location = download_model_from_official(model_path, model_type, **kwargs)

return {
"success": success,
Expand Down Expand Up @@ -288,14 +289,30 @@ def download_chunk(url, start, end, output_file, chunk_number):
if attempt == max_retries - 1:
raise
time.sleep(2 ** attempt) # Exponential backoff


def default_use_processes():
    """
    Return the platform-appropriate default for process-based parallel downloads.

    Multiprocessing is only used by default on Linux (fork start method).
    Windows and macOS (Darwin) use the spawn start method, which is
    incompatible with callers such as Streamlit, so threading is the
    default there — and also for any unrecognized platform.

    Returns:
        bool: True to use processes (Linux only), False to use threads.
    """
    # Collapses the original if/elif ladder: every branch except Linux
    # returned False, including the fallback.
    return platform.system() == "Linux"


def download_file_with_progress(
url: str,
file_path: Path,
chunk_size: int = 40 * 1024 * 1024,
max_workers: int = 20,
use_processes: bool = False
use_processes: bool = default_use_processes(),
**kwargs
):
file_path.parent.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -376,7 +393,7 @@ def download_file_with_progress(
os.remove(file_path)


def download_model_from_official(model_path, model_type):
def download_model_from_official(model_path, model_type, **kwargs):
try:
model_name, model_version = model_path.split(":")
file_extension = ".zip" if model_type == "onnx" or model_type == "bin" else ".gguf"
Expand All @@ -388,7 +405,7 @@ def download_model_from_official(model_path, model_type):
download_url = f"{NEXA_OFFICIAL_BUCKET}{filepath}"

full_path.parent.mkdir(parents=True, exist_ok=True)
download_file_with_progress(download_url, full_path, use_processes=True)
download_file_with_progress(download_url, full_path, **kwargs)

if model_type == "onnx" or model_type == "bin":
unzipped_folder = full_path.parent / model_version
Expand Down
8 changes: 4 additions & 4 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, model_path, local_path=None, **kwargs):

# Download base model if not provided
if self.downloaded_path is None:
self.downloaded_path, _ = pull_model(self.model_path)
self.downloaded_path, _ = pull_model(self.model_path, **kwargs)
if self.downloaded_path is None:
logging.error(
f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
Expand All @@ -92,11 +92,11 @@ def __init__(self, model_path, local_path=None, **kwargs):
self.clip_l_path = FLUX_CLIP_L_PATH

if self.t5xxl_path:
self.t5xxl_downloaded_path, _ = pull_model(self.t5xxl_path)
self.t5xxl_downloaded_path, _ = pull_model(self.t5xxl_path, **kwargs)
if self.ae_path:
self.ae_downloaded_path, _ = pull_model(self.ae_path)
self.ae_downloaded_path, _ = pull_model(self.ae_path, **kwargs)
if self.clip_l_path:
self.clip_l_downloaded_path, _ = pull_model(self.clip_l_path)
self.clip_l_downloaded_path, _ = pull_model(self.clip_l_path, **kwargs)
if "lcm-dreamshaper" in self.model_path or "flux" in self.model_path:
self.params = DEFAULT_IMG_GEN_PARAMS_LCM.copy() # both lcm-dreamshaper and flux use the same params
elif "sdxl-turbo" in self.model_path:
Expand Down
2 changes: 1 addition & 1 deletion nexa/gguf/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
self.top_logprobs = kwargs.get('top_logprobs', None)

if self.downloaded_path is None:
self.downloaded_path, _ = pull_model(self.model_path)
self.downloaded_path, _ = pull_model(self.model_path, **kwargs)

if self.downloaded_path is None:
logging.error(
Expand Down
6 changes: 3 additions & 3 deletions nexa/gguf/nexa_inference_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
if self.downloaded_path is not None:
if model_path in NEXA_RUN_MODEL_MAP_VLM:
self.projector_path = NEXA_RUN_PROJECTOR_MAP[model_path]
self.projector_downloaded_path, _ = pull_model(self.projector_path)
self.projector_downloaded_path, _ = pull_model(self.projector_path, **kwargs)
elif model_path in NEXA_RUN_MODEL_MAP_VLM:
self.model_path = NEXA_RUN_MODEL_MAP_VLM[model_path]
self.projector_path = NEXA_RUN_PROJECTOR_MAP[model_path]
self.downloaded_path, _ = pull_model(self.model_path)
self.projector_downloaded_path, _ = pull_model(self.projector_path)
self.downloaded_path, _ = pull_model(self.model_path, **kwargs)
self.projector_downloaded_path, _ = pull_model(self.projector_path, **kwargs)
elif Path(model_path).parent.exists():
local_dir = Path(model_path).parent
model_name = Path(model_path).name
Expand Down
2 changes: 1 addition & 1 deletion nexa/gguf/nexa_inference_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, model_path, local_path=None, **kwargs):
self.params = DEFAULT_VOICE_GEN_PARAMS

if self.downloaded_path is None:
self.downloaded_path, _ = pull_model(self.model_path)
self.downloaded_path, _ = pull_model(self.model_path, **kwargs)

if self.downloaded_path is None:
logging.error(
Expand Down
2 changes: 1 addition & 1 deletion nexa/onnx/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, model_path, local_path=None, **kwargs):
def run(self):

if self.download_onnx_folder is None:
self.download_onnx_folder, run_type = pull_model(self.model_path)
self.download_onnx_folder, run_type = pull_model(self.model_path, **kwargs)

if self.download_onnx_folder is None:
logging.error(
Expand Down
2 changes: 1 addition & 1 deletion nexa/onnx/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def run(self):
self.run_streamlit()
else:
if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, run_type = pull_model(self.model_path)
self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
Expand Down
2 changes: 1 addition & 1 deletion nexa/onnx/nexa_inference_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, model_path, local_path=None, **kwargs):
self.downloaded_onnx_folder = local_path

if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, run_type = pull_model(self.model_path)
self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
Expand Down
2 changes: 1 addition & 1 deletion nexa/onnx/nexa_inference_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, model_path, local_path=None, **kwargs):

def run(self):
if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, run_type = pull_model(self.model_path)
self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
Expand Down

0 comments on commit 5b02a84

Please sign in to comment.