Skip to content

Commit

Permalink
refactor: estimate model insequence since file lock takes more time
Browse files Browse the repository at this point in the history
  • Loading branch information
aiwantaozi committed Dec 18, 2024
1 parent 1a5673b commit 1c8166d
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 41 deletions.
30 changes: 23 additions & 7 deletions vox_box/downloader/downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,48 @@
logger = logging.getLogger(__name__)


def download_model(
def download_file(
huggingface_repo_id: Optional[str] = None,
huggingface_filename: Optional[str] = None,
model_scope_model_id: Optional[str] = None,
model_scope_file_path: Optional[str] = None,
cache_dir: Optional[str] = None,
huggingface_token: Optional[str] = None,
) -> str:
result_path = None
key = None

if huggingface_repo_id is not None:
return HfDownloader.download(
key = (
f"huggingface:{huggingface_repo_id}"
if huggingface_filename is None
else f"huggingface:{huggingface_repo_id}:{huggingface_filename}"
)
logger.debug(f"Downloading {key}")

result_path = HfDownloader.download(
repo_id=huggingface_repo_id,
filename=huggingface_filename,
token=huggingface_token,
cache_dir=os.path.join(cache_dir, "huggingface"),
)
elif model_scope_model_id is not None:
return ModelScopeDownloader.download(
key = (
f"modelscope:{model_scope_model_id}"
if model_scope_file_path is None
else f"modelscope:{model_scope_model_id}:{model_scope_file_path}"
)
logger.debug(f"Downloading {key}")

result_path = ModelScopeDownloader.download(
model_id=model_scope_model_id,
file_path=model_scope_file_path,
cache_dir=os.path.join(cache_dir, "model_scope"),
)

logger.debug(f"Downloaded {key}")
return result_path


def get_file_size(
huggingface_repo_id: Optional[str] = None,
Expand Down Expand Up @@ -150,8 +170,6 @@ def download_file(
if len(matching_files) == 0:
raise ValueError(f"No file found in {repo_id} that match {filename}")

logger.info(f"Downloading model {repo_id}/{filename}")

subfolder, first_filename = (
str(Path(matching_files[0]).parent),
Path(matching_files[0]).name,
Expand Down Expand Up @@ -193,7 +211,6 @@ def _inner_hf_hub_download(repo_file: str):
else:
model_path = os.path.join(local_dir, first_filename)

logger.info(f"Downloaded model {repo_id}/{filename}")
return model_path

def __call__(self):
Expand Down Expand Up @@ -242,7 +259,6 @@ def download(
name = name.replace(".", "___")
lock_filename = os.path.join(cache_dir, group_or_owner, f"{name}.lock")

logger.info("Retriving file lock")
with FileLock(lock_filename):
if file_path is not None:
matching_files = match_model_scope_file_paths(model_id, file_path)
Expand Down
12 changes: 3 additions & 9 deletions vox_box/estimator/bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from typing import Dict
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
from vox_box.downloader.downloaders import download_model
from vox_box.downloader.downloaders import download_file
from vox_box.estimator.base import Estimator
from vox_box.utils.model import create_model_dict

Expand All @@ -18,10 +18,8 @@ def __init__(
self._cfg = cfg
self._required_files = [
"config.json",
"speaker_embeddings_path.json",
]
self._config_json = None
self._speaker_json = None

def model_info(self) -> Dict:
model = (
Expand Down Expand Up @@ -61,17 +59,13 @@ def _check_local_model(self, base_dir: str) -> bool:
if architectures is not None and "BarkModel" in architectures:
supported = True

speaker_path = os.path.join(base_dir, "speaker_embeddings_path.json")
with open(speaker_path, "r", encoding="utf-8") as f:
self._speaker_json = json.load(f)

return supported

def _check_remote_model(self) -> bool:
downloaded_files = []
for f in self._required_files:
try:
downloaded_file_path = download_model(
downloaded_file_path = download_file(
huggingface_repo_id=self._cfg.huggingface_repo_id,
huggingface_filename=f,
model_scope_model_id=self._cfg.model_scope_model_id,
Expand All @@ -80,7 +74,7 @@ def _check_remote_model(self) -> bool:
)
downloaded_files.append(downloaded_file_path)
except Exception as e:
logger.error(f"Failed to download {f}, {e}")
logger.debug(f"File {f} does not exist, {e}")
continue

if len(downloaded_files) != 0:
Expand Down
6 changes: 3 additions & 3 deletions vox_box/estimator/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import Dict

from vox_box.downloader.downloaders import download_model
from vox_box.downloader.downloaders import download_file
from vox_box.estimator.base import Estimator

from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
Expand Down Expand Up @@ -57,15 +57,15 @@ def _check_remote_model(self) -> bool:
downloaded_files = []
for f in self._required_files:
try:
download_file_path = download_model(
download_file_path = download_file(
huggingface_repo_id=self._cfg.huggingface_repo_id,
huggingface_filename=f,
model_scope_model_id=self._cfg.model_scope_model_id,
model_scope_file_path=f,
cache_dir=self._cfg.cache_dir,
)
except Exception as e:
logger.error(f"Failed to download {f}, {e}")
logger.debug(f"File {f} does not exist, {e}")
continue
downloaded_files.append(download_file_path)

Expand Down
22 changes: 8 additions & 14 deletions vox_box/estimator/estimate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Dict, List
from vox_box.config.config import Config
from vox_box.estimator.bark import Bark
Expand All @@ -6,14 +7,15 @@
from vox_box.estimator.faster_whisper import FasterWhisper
from vox_box.estimator.funasr import FunASR
from vox_box.utils.model import create_model_dict
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)


def estimate_model(cfg: Config) -> Dict:
estimators: List[Estimator] = [
CosyVoice(cfg),
FasterWhisper(cfg),
FunASR(cfg),
CosyVoice(cfg),
Bark(cfg),
]

Expand All @@ -23,17 +25,9 @@ def estimate_model(cfg: Config) -> Dict:
supported=False,
)

def get_model_info(estimator: Estimator) -> Dict:
return estimator.model_info()

with ThreadPoolExecutor() as executor:
futures = {executor.submit(get_model_info, e): e for e in estimators}
for future in as_completed(futures):
result = future.result()
if result["supported"]:
for f in futures:
if not f.done():
f.cancel()
return result
for estimator in estimators:
model_info = estimator.model_info()
if model_info["supported"]:
return model_info

return model_info
8 changes: 4 additions & 4 deletions vox_box/estimator/faster_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from typing import Dict, List
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
from vox_box.downloader.downloaders import download_model
from vox_box.downloader.downloaders import download_file
from vox_box.downloader.hub import match_files
from vox_box.estimator.base import Estimator
from vox_box.utils.model import create_model_dict
Expand Down Expand Up @@ -107,22 +107,22 @@ def _check_remote_model(self) -> bool: # noqa: C901
if "model.bin" not in matching_files:
return False
except Exception as e:
logger.error(f"Failed to download model file for estimating, {e}")
logger.debug(f"File model.bin does not exist, {e}")
return False

downloaded_files = []
download_files = ["tokenizer.json", "preprocessor_config.json"]
for f in download_files:
try:
downloaded_file_path = download_model(
downloaded_file_path = download_file(
huggingface_repo_id=self._cfg.huggingface_repo_id,
huggingface_filename=f,
model_scope_model_id=self._cfg.model_scope_model_id,
model_scope_file_path=f,
cache_dir=self._cfg.cache_dir,
)
except Exception as e:
logger.error(f"Failed to download {f} for model estimate, {e}")
logger.debug(f"File {f} does not exist, {e}")
continue

downloaded_files.append(downloaded_file_path)
Expand Down
6 changes: 3 additions & 3 deletions vox_box/estimator/funasr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import yaml
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
from vox_box.downloader.downloaders import download_model
from vox_box.downloader.downloaders import download_file
from vox_box.estimator.base import Estimator
from vox_box.utils.model import create_model_dict

Expand Down Expand Up @@ -91,15 +91,15 @@ def _check_remote_model(self) -> Tuple[bool, str]:
downloaded_files = []
for f in self._optional_files:
try:
download_file_path = download_model(
download_file_path = download_file(
huggingface_repo_id=self._cfg.huggingface_repo_id,
huggingface_filename=f,
model_scope_model_id=self._cfg.model_scope_model_id,
model_scope_file_path=f,
cache_dir=self._cfg.cache_dir,
)
except Exception as e:
logger.error(f"Failed to download {f} for model estimate, {e}")
logger.debug(f"File {f} does not exist, {e}")
continue
downloaded_files.append(download_file_path)

Expand Down
10 changes: 9 additions & 1 deletion vox_box/server/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Union
from vox_box.backends.stt.base import STTBackend
from vox_box.backends.stt.faster_whisper import FasterWhisper
Expand All @@ -11,12 +12,17 @@

_instance = None

logger = logging.getLogger(__name__)


class ModelInstance:
def __init__(self, cfg: Config):
self._cfg = cfg
self._backend_framework = None

logger.info("Estimating model")
self._estimate = estimate_model(cfg)
logger.info("Finished estimating model")
if (
self._estimate is None
or not self._estimate.get("supported", False)
Expand All @@ -30,7 +36,8 @@ def __init__(self, cfg: Config):
or self._cfg.model_scope_model_id is not None
):
try:
mode_path = downloaders.download_model(
logger.info("Downloading model")
mode_path = downloaders.download_file(
huggingface_repo_id=self._cfg.huggingface_repo_id,
model_scope_model_id=self._cfg.model_scope_model_id,
cache_dir=self._cfg.cache_dir,
Expand All @@ -54,6 +61,7 @@ def run(self):

if _instance is None:
try:
logger.info("Loading model")
_instance = self._backend_framework.load()
except Exception as e:
raise Exception(f"Faild to load model, {e}")
Expand Down

0 comments on commit 1c8166d

Please sign in to comment.