Skip to content

Commit

Permalink
refactor: add api params validate
Browse files Browse the repository at this point in the history
  • Loading branch information
aiwantaozi committed Nov 26, 2024
1 parent 1010f29 commit 28ca493
Showing 1 changed file with 74 additions and 5 deletions.
79 changes: 74 additions & 5 deletions vox_box/server/routers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, HTTPException, Request, UploadFile
from pydantic import BaseModel
from fastapi.responses import FileResponse

Expand All @@ -12,18 +12,36 @@

executor = ThreadPoolExecutor()

ALLOWED_SPEECH_OUTPUT_AUDIO_TYPES = {
"mp3",
"opus",
"aac",
"flac",
"wav",
"pcm",
}


class SpeechRequest(BaseModel):
model: str
input: str
voice: str = "alloy"
response_format: str = "wav"
voice: str
response_format: str = "mp3"
speed: float = 1.0


@router.post("/v1/audio/speech")
async def speech(request: SpeechRequest):
try:
if (
request.response_format
and request.response_format not in ALLOWED_SPEECH_OUTPUT_AUDIO_TYPES
):
return HTTPException(
status_code=400,
detail=f"Unsupported audio format: {request.response_format}",
)

model_instance: TTSBackend = get_model_instance()
if not isinstance(model_instance, TTSBackend):
return HTTPException(
Expand All @@ -46,6 +64,38 @@ async def speech(request: SpeechRequest):
return HTTPException(status_code=500, detail=f"Failed to generate speech, {e}")


# ref: https://github.com/LMS-Community/slimserver/blob/public/10.0/types.conf
ALLOWED_TRANSCRIPTIONS_INPUT_AUDIO_FORMATS = {
# flac
"audio/flac",
"audio/x-flac",
# mp3
"audio/mpeg",
"audio/x-mpeg",
"audio/mp3",
"audio/mp3s",
"audio/mpeg3",
"audio/mpg",
# mp4
"audio/m4a",
"audio/x-m4a",
"audio/mp4",
# mpeg
"audio/mpga",
# ogg
"audio/ogg",
"audio/x-ogg",
# wav
"audio/wav",
"audio/x-wav",
"audio/wave",
# webm
"audio/webm",
}

ALLOWED_TRANSCRIPTIONS_OUTPUT_FORMATS = {"json", "text", "srt", "vtt", "verbose_json"}


@router.post("/v1/audio/transcriptions")
async def transcribe(request: Request):
try:
Expand All @@ -54,12 +104,31 @@ async def transcribe(request: Request):
if "file" not in keys:
return HTTPException(status_code=400, detail="Field file is required")

audio_bytes = await form["file"].read()
file: UploadFile = form[
"file"
] # flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm
file_content_type = file.content_type
if file_content_type not in ALLOWED_TRANSCRIPTIONS_INPUT_AUDIO_FORMATS:
return HTTPException(
status_code=400,
detail=f"Unsupported file format: {file_content_type}",
)

audio_bytes = await file.read()
language = form.get("language")
prompt = form.get("prompt")
temperature = float(form.get("temperature", 0.2))
temperature = float(form.get("temperature", 0))
if not (0 <= temperature <= 1):
return HTTPException(
status_code=400, detail="Temperature must be between 0 and 1"
)

timestamp_granularities = form.getlist("timestamp_granularities")
response_format = form.get("response_format", "json")
if response_format not in ALLOWED_TRANSCRIPTIONS_OUTPUT_FORMATS:
return HTTPException(
status_code=400, detail="Unsupported response_format: {response_format}"
)

model_instance: STTBackend = get_model_instance()
if not isinstance(model_instance, STTBackend):
Expand Down

0 comments on commit 28ca493

Please sign in to comment.