Showing 5 changed files with 116 additions and 4 deletions.
File renamed without changes.
@@ -0,0 +1,108 @@
import asyncio
import base64
import logging
from pathlib import Path

import numpy as np
from openai import AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
from openai.types.beta.realtime.conversation_item_content_param import ConversationItemContentParam
from openai.types.beta.realtime.conversation_item_param import ConversationItemParam
from openai.types.beta.realtime.session_update_event_param import Session, SessionTurnDetection
import pytest
import soundfile as sf
import websockets

from speaches.audio import resample_audio

logger = logging.getLogger(__name__)

SAMPLE_RATE = 24000
SAMPLE_WIDTH = 2
BYTERATE = SAMPLE_RATE * SAMPLE_WIDTH  # like "bitrate" but in bytes

WS_BASE_URL = "ws://localhost:8000/v1"
MODEL = "gpt-4o-mini"

RESPONSE_SESSION = Session(turn_detection=SessionTurnDetection(create_response=True))
NO_RESPONSE_SESSION = Session(turn_detection=SessionTurnDetection(create_response=False))


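# Streams PCM audio to the server in fixed-size chunks, pacing the sends to simulate
# (optionally sped-up) real-time input.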
async def audio_sender(
    conn: AsyncRealtimeConnection, audio_bytes: bytes, chunks_per_second: int = 10, speed: int = 1
) -> None:
    chunk_size = BYTERATE // chunks_per_second
    try:
        async with asyncio.TaskGroup() as tg:
            for i in range(0, len(audio_bytes), chunk_size):
                logger.info(f"Sending audio chunk from {i} to {i + chunk_size} of {len(audio_bytes)}")
                audio_chunk = audio_bytes[i : i + chunk_size]
                tg.create_task(conn.input_audio_buffer.append(audio=base64.b64encode(audio_chunk).decode("utf-8")))
                await asyncio.sleep(1 / chunks_per_second / speed)
    except* websockets.exceptions.ConnectionClosedError:
        logger.info("Connection closed")


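# Prints every event received on the connection, replacing base64 audio deltas with a size summary;
# stops once `final_event` is seen (or when the connection closes).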
async def print_events(conn: AsyncRealtimeConnection, final_event: str | None = None) -> None:
    try:
        async for event in conn:
            if event.type == "response.audio.delta":
                size = len(base64.b64decode(event.delta))
                event.delta = f"base64 encoded audio of size {size} bytes"
            print(event.model_dump_json())
            if final_event is not None and event.type == final_event:
                break
    except websockets.exceptions.ConnectionClosedError:
        logger.info("Connection closed")


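# Test fixture data: the sample WAV is resampled to the 24 kHz 16-bit PCM the Realtime API expects,
# with 3 seconds of trailing silence appended (presumably to give server-side VAD a clear end of speech).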
data, samplerate = sf.read(Path("1_2_3_4_5_6_7_8.wav"), dtype="int16")
pcm_audio_bytes = data.tobytes()
audio_bytes = resample_audio(pcm_audio_bytes, samplerate, SAMPLE_RATE)
quiet_audio = np.zeros(SAMPLE_RATE * 3, dtype=np.int16).tobytes()
audio_bytes = audio_bytes + quiet_audio


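# Streams the sample audio with create_response disabled and waits for the input audio
# transcription to complete.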
@pytest.mark.asyncio
@pytest.mark.requires_openai
async def test_realtime_vad_openai() -> None:
    realtime_client = AsyncOpenAI(websocket_base_url=WS_BASE_URL).beta.realtime
    async with asyncio.TaskGroup() as tg, realtime_client.connect(model=MODEL) as conn:
        print_events_task = tg.create_task(
            print_events(conn, final_event="conversation.item.input_audio_transcription.completed")
        )
        await conn.session.update(session=NO_RESPONSE_SESSION)
        audio_sender_task = tg.create_task(audio_sender(conn, audio_bytes))
        await audio_sender_task
        await print_events_task
        await conn.close()


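# Streams the sample audio with create_response enabled so the server generates a response on its own,
# printing events until the connection closes.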
@pytest.mark.asyncio
@pytest.mark.requires_openai
async def test_realtime_response() -> None:
    realtime_client = AsyncOpenAI(websocket_base_url=WS_BASE_URL).beta.realtime
    async with asyncio.TaskGroup() as tg, realtime_client.connect(model=MODEL) as conn:
        print_events_task = tg.create_task(print_events(conn, final_event=None))
        await conn.session.update(session=RESPONSE_SESSION)
        audio_sender_task = tg.create_task(audio_sender(conn, audio_bytes))
        await audio_sender_task
        await print_events_task
        await conn.close()


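# Creates a text conversation item directly (no audio) and explicitly requests a response,
# waiting for the response.done event.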
@pytest.mark.asyncio
@pytest.mark.requires_openai
async def test_realtime_create_conversation_item() -> None:
    realtime_client = AsyncOpenAI(websocket_base_url=WS_BASE_URL).beta.realtime
    async with asyncio.TaskGroup() as tg, realtime_client.connect(model=MODEL) as conn:
        print_events_task = tg.create_task(print_events(conn, final_event="response.done"))
        await conn.session.update(session=NO_RESPONSE_SESSION)
        await conn.conversation.item.create(
            item=ConversationItemParam(
                role="user", type="message", content=[ConversationItemContentParam(type="input_text", text="Hello")]
            )
        )
        await conn.response.create()
        await print_events_task
        await conn.close()