Skip to content

Commit

Permalink
refactor: add thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
aiwantaozi committed Nov 21, 2024
1 parent caa4d83 commit 42b10c4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
2 changes: 1 addition & 1 deletion speech_box/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import httpx

from speech_box import __version__
from speech_box.server.openai_routers import router
from speech_box.server.routers import router


@asynccontextmanager
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import asyncio
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from fastapi.responses import FileResponse

from speech_box.backends.stt.base import STTBackend
from speech_box.backends.tts.base import TTSBackend
from speech_box.server.model import get_model_instance
from concurrent.futures import ThreadPoolExecutor

# Router that collects the HTTP endpoints declared in this module through the
# @router.<method>(...) decorators.
router = APIRouter()

# Module-level thread pool shared by the request handlers. The synchronous
# model_instance.speech / model_instance.transcribe calls are submitted to it
# via loop.run_in_executor, so the handler coroutine awaits the result instead
# of running the call inline. Constructed with the library's default
# worker count (no max_workers given).
executor = ThreadPoolExecutor()


class SpeechRequest(BaseModel):
model: str
Expand All @@ -25,9 +29,17 @@ async def speech(request: SpeechRequest):
return HTTPException(
status_code=400, detail="Model instance does not support speech API"
)
audio_file = model_instance.speech(
request.input, request.voice, request.speed, request.response_format

loop = asyncio.get_event_loop()
audio_file = await loop.run_in_executor(
executor,
model_instance.speech,
request.input,
request.voice,
request.speed,
request.response_format,
)

media_type = get_media_type(request.response_format)
return FileResponse(audio_file, media_type=media_type)
except Exception as e:
Expand Down Expand Up @@ -55,7 +67,11 @@ async def transcribe(request: Request):
status_code=400,
detail="Model instance does not support transcriptions API",
)
data = model_instance.transcribe(

loop = asyncio.get_event_loop()
data = await loop.run_in_executor(
executor,
model_instance.transcribe,
audio_bytes,
language,
prompt,
Expand Down Expand Up @@ -98,11 +114,19 @@ async def get_model_info(model_id: str):
return model_instance.model_info()


@router.get("/voices")
async def get_voice():
    """List the voices reported by the current model instance.

    Returns:
        An empty dict when ``get_model_instance()`` yields ``None``;
        otherwise ``{"voices": [...]}``, where the list is the ``"voices"``
        entry of the instance's ``model_info()`` mapping, or an empty list
        when that entry is missing.
    """
    instance = get_model_instance()
    if instance is not None:
        info = instance.model_info()
        return {"voices": info.get("voices", [])}
    # No model available: respond with an empty object.
    return {}


def get_media_type(response_format) -> str:
if response_format == "mp3":
media_type = "audio/mpeg"
elif response_format == "opus":
media_type = "audio/ogg;codec=opus" # codecs?
media_type = "audio/ogg;codec=opus"
elif response_format == "aac":
media_type = "audio/aac"
elif response_format == "flac":
Expand Down

0 comments on commit 42b10c4

Please sign in to comment.