feat: tts
Fedir Zadniprovskyi committed Oct 31, 2024
1 parent 39d5220 commit d837657
Showing 7 changed files with 571 additions and 66 deletions.
17 changes: 13 additions & 4 deletions src/faster_whisper_server/dependencies.py
@@ -4,7 +4,7 @@
 from fastapi import Depends
 
 from faster_whisper_server.config import Config
-from faster_whisper_server.model_manager import ModelManager
+from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
 
 @lru_cache
@@ -16,9 +16,18 @@ def get_config() -> Config:
 
 
 @lru_cache
-def get_model_manager() -> ModelManager:
+def get_model_manager() -> WhisperModelManager:
     config = get_config()  # HACK
-    return ModelManager(config.whisper)
+    return WhisperModelManager(config.whisper)
 
 
-ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
+ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manager)]
+
+
+@lru_cache
+def get_piper_model_manager() -> PiperModelManager:
+    config = get_config()  # HACK
+    return PiperModelManager(config.whisper.ttl)  # HACK
+
+
+PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
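Aside (not part of the commit): the new PiperModelManagerDependency follows the same Annotated[..., Depends(...)] pattern as the existing Whisper dependency, so any route can receive the shared manager just by annotating a parameter with it. A minimal sketch using only what this diff shows; the route path below is made up for illustration:

from fastapi import APIRouter

from faster_whisper_server.dependencies import PiperModelManagerDependency

router = APIRouter()


@router.get("/debug/piper-manager")  # hypothetical route, for illustration only
def piper_manager_info(piper_model_manager: PiperModelManagerDependency) -> dict[str, str]:
    # FastAPI resolves the Annotated dependency by calling get_piper_model_manager();
    # thanks to @lru_cache, every request shares a single PiperModelManager instance.
    return {"manager": type(piper_model_manager).__name__}

Because both factory functions are wrapped in @lru_cache, the managers are effectively process-wide singletons.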
163 changes: 161 additions & 2 deletions src/faster_whisper_server/hf_utils.py
@@ -1,9 +1,16 @@
 from collections.abc import Generator
+from functools import lru_cache
+import json
 import logging
+from pathlib import Path
 import typing
+from typing import Any, Literal
 
 import huggingface_hub
+from huggingface_hub.constants import HF_HUB_CACHE
+from pydantic import BaseModel
 
 from faster_whisper_server.api_models import Model
 
 logger = logging.getLogger(__name__)

@@ -12,10 +19,36 @@
 
 
 def does_local_model_exist(model_id: str) -> bool:
-    return any(model_id == model.repo_id for model, _ in list_local_models())
+    return any(model_id == model.repo_id for model, _ in list_local_whisper_models())
 
 
+def list_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
+
+
-def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]:
+def list_local_whisper_models() -> (
+    Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]
+):
     hf_cache = huggingface_hub.scan_cache_dir()
     hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"]
     for model in hf_models:
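Aside (not part of the commit): the renamed helpers above split remote and local discovery. list_whisper_models queries the Hugging Face Hub for CTranslate2 ASR models, while list_local_whisper_models only scans the local HF cache. A quick sketch of how they might be exercised; the model id is just an example:

from faster_whisper_server.hf_utils import (
    does_local_model_exist,
    list_local_whisper_models,
    list_whisper_models,
)

# Remote: the most-downloaded CTranslate2 ASR model on the Hub (requires network access).
most_downloaded = next(list_whisper_models())
print(most_downloaded.id, most_downloaded.language)

# Local: (CachedRepoInfo, ModelCardData) pairs found in the HF cache directory.
for repo, card_data in list_local_whisper_models():
    print(repo.repo_id, card_data.language)

print(does_local_model_exist("Systran/faster-whisper-small"))  # example repo id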
@@ -36,3 +69,129 @@ def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggi
             and TASK_NAME in model_card_data.tags
         ):
             yield model, model_card_data
+
+
+def get_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
+
+
+class PiperModel(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: Literal["rhasspy"] = "rhasspy"
+    path: Path
+    config_path: Path
+
+
+def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
+    if cache_dir is None:
+        cache_dir = HF_HUB_CACHE
+
+    cache_dir = Path(cache_dir).expanduser().resolve()
+    if not cache_dir.exists():
+        raise huggingface_hub.CacheNotFound(
+            f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",  # noqa: E501
+            cache_dir=cache_dir,
+        )
+
+    if cache_dir.is_file():
+        raise ValueError(
+            f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."  # noqa: E501
+        )
+
+    for repo_path in cache_dir.iterdir():
+        if not repo_path.is_dir():
+            continue
+        if repo_path.name == ".locks":  # skip './.locks/' folder
+            continue
+        repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
+        repo_type = repo_type[:-1]  # "models" -> "model"
+        repo_id = repo_id.replace("--", "/")  # google--fleurs -> "google/fleurs"
+        if repo_type != "model":
+            continue
+        if model_id == repo_id:
+            return repo_path
+
+    return None
+
+
+def list_model_files(
+    model_id: str, glob_pattern: str = "**/*", *, cache_dir: str | Path | None = None
+) -> Generator[Path, None, None]:
+    repo_path = get_model_path(model_id, cache_dir=cache_dir)
+    if repo_path is None:
+        return None
+    snapshots_path = repo_path / "snapshots"
+    if not snapshots_path.exists():
+        return None
+    yield from list(snapshots_path.glob(glob_pattern))
+
+
+def list_piper_models() -> Generator[PiperModel, None, None]:
+    model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
+    for model_weights_file in model_weights_files:
+        model_config_file = model_weights_file.with_suffix(".json")
+        yield PiperModel(
+            id=model_weights_file.name,
+            created=int(model_weights_file.stat().st_mtime),
+            path=model_weights_file,
+            config_path=model_config_file,
+        )
+
+
+# NOTE: It's debatable whether caching should be done here or by the caller. Should be revisited.
+
+
+@lru_cache
+def read_piper_voices_config() -> dict[str, Any]:
+    voices_file = next(list_model_files("rhasspy/piper-voices", glob_pattern="**/voices.json"), None)
+    if voices_file is None:
+        raise FileNotFoundError("Could not find voices.json file")
+    return json.loads(voices_file.read_text())
+
+
+@lru_cache
+def get_piper_voice_model_file(voice: str) -> Path:
+    model_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx"), None)
+    if model_file is None:
+        raise FileNotFoundError(f"Could not find model file for '{voice}' voice")
+    return model_file
+
+
+class PiperVoiceConfigAudio(BaseModel):
+    sample_rate: int
+    quality: int
+
+
+class PiperVoiceConfig(BaseModel):
+    audio: PiperVoiceConfigAudio
+    # NOTE: there are more fields in the config, but we don't care about them
+
+
+@lru_cache
+def read_piper_voice_config(voice: str) -> PiperVoiceConfig:
+    model_config_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx.json"), None)
+    if model_config_file is None:
+        raise FileNotFoundError(f"Could not find config file for '{voice}' voice")
+    return PiperVoiceConfig.model_validate_json(model_config_file.read_text())
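Aside (not part of the commit): taken together, the Piper helpers above resolve a voice name to its ONNX weights and per-voice config, scanning the HF cache's snapshots directory rather than hitting the network. A usage sketch, assuming the rhasspy/piper-voices repo has already been downloaded into the HF cache; the snapshot_download call and voice name are illustrative:

import huggingface_hub

from faster_whisper_server.hf_utils import (
    get_piper_voice_model_file,
    list_piper_models,
    read_piper_voice_config,
)

# One-time: populate the HF cache with the voice weights and configs.
huggingface_hub.snapshot_download(
    "rhasspy/piper-voices",
    allow_patterns=["**/*.onnx", "**/*.onnx.json", "voices.json"],
)

voice = "en_US-amy-medium"
model_file = get_piper_voice_model_file(voice)  # .../en_US-amy-medium.onnx
voice_config = read_piper_voice_config(voice)   # parsed from en_US-amy-medium.onnx.json
print(model_file, voice_config.audio.sample_rate)

# Enumerate every voice present in the cache.
for piper_model in list_piper_models():
    print(piper_model.id, piper_model.path)

Note that the read functions are wrapped in @lru_cache, so repeated lookups for the same voice hit the filesystem only once.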
4 changes: 4 additions & 0 deletions src/faster_whisper_server/main.py
@@ -17,6 +17,9 @@
 from faster_whisper_server.routers.misc import (
     router as misc_router,
 )
+from faster_whisper_server.routers.speech import (
+    router as speech_router,
+)
 from faster_whisper_server.routers.stt import (
     router as stt_router,
 )
@@ -46,6 +49,7 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
 app.include_router(stt_router)
 app.include_router(list_models_router)
 app.include_router(misc_router)
+app.include_router(speech_router)
 
 if config.allow_origins is not None:
     app.add_middleware(
(The remaining 4 changed files, including the new src/faster_whisper_server/routers/speech.py, are not loaded in this view.)
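Aside (not part of the commit shown above): the speech router itself is among the files not loaded here, so its routes are not visible. Assuming it mirrors OpenAI's /v1/audio/speech endpoint, as the project's OpenAI-compatible API suggests, a client call might look like this; the path, payload fields, port, and output format are all assumptions:

from pathlib import Path

import httpx

response = httpx.post(
    "http://localhost:8000/v1/audio/speech",  # assumed route, mirroring OpenAI's TTS API
    json={
        "model": "piper",              # assumed model identifier
        "input": "Hello from Piper!",
        "voice": "en_US-amy-medium",
    },
)
response.raise_for_status()
Path("speech.wav").write_bytes(response.content)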
