Skip to content

Commit

Permalink
(feat): Support for text iterators in AsyncElevenLabs
Browse files Browse the repository at this point in the history
Why:
Allowing the async client to utilize incoming text streams when generating voice.
Very useful when feeding the realtime output of an LLM into the TTS.

Closes elevenlabs#344

What:
1. Copied `RealtimeTextToSpeechClient` and `text_chunker` into `AsyncRealtimeTextToSpeechClient` and `async_text_chunker`
   Most of the logic is intact, aside from async stuff
2. Added `AsyncRealtimeTextToSpeechClient` into `AsyncElevenLabs` just like `RealtimeTextToSpeechClient` is in `ElevenLabs`
3. Added rudimentary testing

The code is basically a copy-paste of what I found in the repo. We can rewrite it to be more elegant, but I figured parity with the sync code is more important.
  • Loading branch information
nitz-uglabs committed Aug 15, 2024
1 parent 6db2fdd commit b541a57
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 12 deletions.
53 changes: 42 additions & 11 deletions src/elevenlabs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .types import Voice, VoiceSettings, \
PronunciationDictionaryVersionLocator, Model
from .environment import ElevenLabsEnvironment
from .realtime_tts import RealtimeTextToSpeechClient
from .realtime_tts import RealtimeTextToSpeechClient, AsyncRealtimeTextToSpeechClient
from .types import OutputFormat


Expand Down Expand Up @@ -257,6 +257,25 @@ class AsyncElevenLabs(AsyncBaseElevenLabs):
api_key="YOUR_API_KEY",
)
"""
def __init__(
self,
*,
base_url: typing.Optional[str] = None,
environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION,
api_key: typing.Optional[str] = os.getenv("ELEVEN_API_KEY"),
timeout: typing.Optional[float] = None,
follow_redirects: typing.Optional[bool] = True,
httpx_client: typing.Optional[httpx.AsyncClient] = None
):
super().__init__(
base_url=base_url,
environment=environment,
api_key=api_key,
timeout=timeout,
follow_redirects=follow_redirects,
httpx_client=httpx_client,
)
self.text_to_speech = AsyncRealtimeTextToSpeechClient(client_wrapper=self._client_wrapper)

async def clone(
self,
Expand Down Expand Up @@ -383,16 +402,28 @@ async def generate(
model_id = model.model_id

if stream:
return self.text_to_speech.convert_as_stream(
voice_id=voice_id,
model_id=model_id,
voice_settings=voice_settings,
optimize_streaming_latency=optimize_streaming_latency,
output_format=output_format,
text=text,
request_options=request_options,
pronunciation_dictionary_locators=pronunciation_dictionary_locators
)
if isinstance(text, str):
return self.text_to_speech.convert_as_stream(
voice_id=voice_id,
model_id=model_id,
voice_settings=voice_settings,
optimize_streaming_latency=optimize_streaming_latency,
output_format=output_format,
text=text,
request_options=request_options,
pronunciation_dictionary_locators=pronunciation_dictionary_locators
)
elif isinstance(text, AsyncIterator):
return self.text_to_speech.convert_realtime( # type: ignore
voice_id=voice_id,
voice_settings=voice_settings,
output_format=output_format,
text=text,
request_options=request_options,
model_id=model_id
)
else:
raise ApiError(body="Text is neither a string nor an iterator.")
else:
if not isinstance(text, str):
raise ApiError(body="Text must be a string when stream is False.")
Expand Down
122 changes: 121 additions & 1 deletion src/elevenlabs/realtime_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
import json
import base64
import websockets
import asyncio

from websockets.sync.client import connect
from websockets.client import connect as async_connect

from .core.api_error import ApiError
from .core.jsonable_encoder import jsonable_encoder
from .core.remove_none_from_dict import remove_none_from_dict
from .core.request_options import RequestOptions
from .types.voice_settings import VoiceSettings
from .text_to_speech.client import TextToSpeechClient
from .text_to_speech.client import TextToSpeechClient, AsyncTextToSpeechClient
from .types import OutputFormat

# this is used as the default value for optional parameters
Expand All @@ -37,6 +39,22 @@ def text_chunker(chunks: typing.Iterator[str]) -> typing.Iterator[str]:
if buffer != "":
yield buffer + " "

async def async_text_chunker(chunks: typing.AsyncIterator[str]) -> typing.AsyncIterator[str]:
"""Used during input streaming to chunk text blocks and set last char to space"""
splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
buffer = ""
async for text in chunks:
if buffer.endswith(splitters):
yield buffer if buffer.endswith(" ") else buffer + " "
buffer = text
elif text.startswith(splitters):
output = buffer + text[0]
yield output if output.endswith(" ") else output + " "
buffer = text[1:]
else:
buffer += text
if buffer != "":
yield buffer + " "

class RealtimeTextToSpeechClient(TextToSpeechClient):

Expand Down Expand Up @@ -137,3 +155,105 @@ def get_text() -> typing.Iterator[str]:
raise ApiError(body=data, status_code=ce.code)
elif ce.code != 1000:
raise ApiError(body=ce.reason, status_code=ce.code)


class AsyncRealtimeTextToSpeechClient(AsyncTextToSpeechClient):

async def convert_realtime(
self,
voice_id: str,
*,
text: typing.AsyncIterator[str],
model_id: typing.Optional[str] = OMIT,
output_format: typing.Optional[OutputFormat] = "mp3_44100_128",
voice_settings: typing.Optional[VoiceSettings] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[bytes]:
"""
Converts text into speech using a voice of your choice and returns audio.
Parameters:
- voice_id: str. Voice ID to be used, you can use https://api.elevenlabs.io/v1/voices to list all the available voices.
- text: typing.Iterator[str]. The text that will get converted into speech.
- model_id: typing.Optional[str]. Identifier of the model that will be used, you can query them using GET /v1/models. The model needs to have support for text to speech, you can check this using the can_do_text_to_speech property.
- voice_settings: typing.Optional[VoiceSettings]. Voice settings overriding stored setttings for the given voice. They are applied only on the given request.
- request_options: typing.Optional[RequestOptions]. Request-specific configuration.
---
from elevenlabs import PronunciationDictionaryVersionLocator, VoiceSettings
from elevenlabs.client import ElevenLabs
def get_text() -> typing.Iterator[str]:
yield "Hello, how are you?"
yield "I am fine, thank you."
client = ElevenLabs(
api_key="YOUR_API_KEY",
)
client.text_to_speech.convert_realtime(
voice_id="string",
text=get_text(),
model_id="string",
voice_settings=VoiceSettings(
stability=1.1,
similarity_boost=1.1,
style=1.1,
use_speaker_boost=True,
),
)
"""
async with async_connect(
urllib.parse.urljoin(
"wss://api.elevenlabs.io/",
f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input?model_id={model_id}&output_format={output_format}"
),
extra_headers=jsonable_encoder(
remove_none_from_dict(
{
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
)
)
) as socket:
try:
await socket.send(json.dumps(
dict(
text=" ",
try_trigger_generation=True,
voice_settings=voice_settings.dict() if voice_settings else None,
generation_config=dict(
chunk_length_schedule=[50],
),
)
))
except websockets.exceptions.ConnectionClosedError as ce:
raise ApiError(body=ce.reason, status_code=ce.code)

try:
async for text_chunk in async_text_chunker(text):
data = dict(text=text_chunk, try_trigger_generation=True)
await socket.send(json.dumps(data))
try:
async with asyncio.timeout(1e-4):
data = json.loads(await socket.recv())
if "audio" in data and data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except TimeoutError:
pass

await socket.send(json.dumps(dict(text="")))

while True:

data = json.loads(await socket.recv())
if "audio" in data and data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except websockets.exceptions.ConnectionClosed as ce:
if "message" in data:
raise ApiError(body=data, status_code=ce.code)
elif ce.code != 1000:
raise ApiError(body=ce.reason, status_code=ce.code)
17 changes: 17 additions & 0 deletions tests/test_async_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,20 @@ async def main():
if not IN_GITHUB:
play(out)
asyncio.run(main())

def test_generate_stream() -> None:
async def main():
async def text_stream():
yield "Hi there, I'm Eleven "
yield "I'm a text to speech API "

audio_stream = await async_client.generate(
text=text_stream(),
voice="Nicole",
model="eleven_monolingual_v1",
stream=True
)

if not IN_GITHUB:
stream(audio_stream) # type: ignore
asyncio.run(main())

0 comments on commit b541a57

Please sign in to comment.