Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: asynciterable support #477

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ dist/
__pycache__/
poetry.toml
.ruff_cache/
env/
57 changes: 44 additions & 13 deletions src/elevenlabs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import httpx

from typing import Iterator, Optional, Union, \
Optional, AsyncIterator
Optional, AsyncIterable, AsyncIterator

from .base_client import \
BaseElevenLabs, AsyncBaseElevenLabs
from .core import RequestOptions, ApiError
from .types import Voice, VoiceSettings, \
PronunciationDictionaryVersionLocator, Model
from .environment import ElevenLabsEnvironment
from .realtime_tts import RealtimeTextToSpeechClient
from .realtime_tts import RealtimeTextToSpeechClient, AsyncRealtimeTextToSpeechClient
from .types import OutputFormat


Expand Down Expand Up @@ -257,6 +257,25 @@ class AsyncElevenLabs(AsyncBaseElevenLabs):
api_key="YOUR_API_KEY",
)
"""
def __init__(
self,
*,
base_url: typing.Optional[str] = None,
environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION,
api_key: typing.Optional[str] = os.getenv("ELEVENLABS_API_KEY"),
timeout: typing.Optional[float] = None,
follow_redirects: typing.Optional[bool] = True,
httpx_client: typing.Optional[httpx.AsyncClient] = None
):
super().__init__(
base_url=base_url,
environment=environment,
api_key=api_key,
timeout=timeout,
follow_redirects=follow_redirects,
httpx_client=httpx_client,
)
self.text_to_speech = AsyncRealtimeTextToSpeechClient(client_wrapper=self._client_wrapper)

async def clone(
self,
Expand Down Expand Up @@ -299,7 +318,7 @@ async def clone(
async def generate(
self,
*,
text: str,
text: Union[str, AsyncIterable[str]],
voice: Union[VoiceId, VoiceName, Voice] = DEFAULT_VOICE,
voice_settings: typing.Optional[VoiceSettings] = DEFAULT_VOICE.settings,
model: Union[ModelId, Model] = "eleven_multilingual_v2",
Expand Down Expand Up @@ -383,16 +402,28 @@ async def generate(
model_id = model.model_id

if stream:
return self.text_to_speech.convert_as_stream(
voice_id=voice_id,
model_id=model_id,
voice_settings=voice_settings,
optimize_streaming_latency=optimize_streaming_latency,
output_format=output_format,
text=text,
request_options=request_options,
pronunciation_dictionary_locators=pronunciation_dictionary_locators
)
if isinstance(text, AsyncIterable):
return self.text_to_speech.convert_realtime( # type: ignore
voice_id=voice_id,
voice_settings=voice_settings,
output_format=output_format,
text=text,
request_options=request_options,
model_id=model_id
)
elif isinstance(text, str):
return self.text_to_speech.convert_as_stream(
voice_id=voice_id,
model_id=model_id,
voice_settings=voice_settings,
optimize_streaming_latency=optimize_streaming_latency,
output_format=output_format,
text=text,
request_options=request_options,
pronunciation_dictionary_locators=pronunciation_dictionary_locators
)
else:
raise ApiError(body="Text is neither a string nor an iterator.")
else:
if not isinstance(text, str):
raise ApiError(body="Text must be a string when stream is False.")
Expand Down
115 changes: 114 additions & 1 deletion src/elevenlabs/realtime_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
import json
import base64
import websockets
import asyncio

from websockets.sync.client import connect
from websockets.client import connect as async_connect

from .core.api_error import ApiError
from .core.client_wrapper import SyncClientWrapper
from .core.jsonable_encoder import jsonable_encoder
from .core.remove_none_from_dict import remove_none_from_dict
from .core.request_options import RequestOptions
from .types.voice_settings import VoiceSettings
from .text_to_speech.client import TextToSpeechClient
from .text_to_speech.client import TextToSpeechClient, AsyncTextToSpeechClient
from .types import OutputFormat

# this is used as the default value for optional parameters
Expand All @@ -39,6 +41,24 @@ def text_chunker(chunks: typing.Iterator[str]) -> typing.Iterator[str]:
yield buffer + " "


async def async_text_chunker(chunks: typing.AsyncIterable[str]) -> typing.AsyncIterable[str]:
"""Used during input streaming to chunk text blocks and set last char to space"""
splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
buffer = ""
async for text in chunks:
if buffer.endswith(splitters):
yield buffer if buffer.endswith(" ") else buffer + " "
buffer = text
elif text.startswith(splitters):
output = buffer + text[0]
yield output if output.endswith(" ") else output + " "
buffer = text[1:]
else:
buffer += text
if buffer != "":
yield buffer + " "


class RealtimeTextToSpeechClient(TextToSpeechClient):
def __init__(self, *, client_wrapper: SyncClientWrapper):
super().__init__(client_wrapper=client_wrapper)
Expand Down Expand Up @@ -141,3 +161,96 @@ def get_text() -> typing.Iterator[str]:
raise ApiError(body=data, status_code=ce.code)
elif ce.code != 1000:
raise ApiError(body=ce.reason, status_code=ce.code)

class AsyncRealtimeTextToSpeechClient(AsyncTextToSpeechClient):
async def convert_realtime(
self,
voice_id: str,
*,
text: typing.AsyncIterable[str],
model_id: typing.Optional[str] = OMIT,
output_format: typing.Optional[OutputFormat] = "mp3_44100_128",
voice_settings: typing.Optional[VoiceSettings] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[bytes]:
"""
Converts text into speech using a voice of your choice and returns audio.
Parameters:
- voice_id: str. Voice ID to be used, you can use https://api.elevenlabs.io/v1/voices to list all the available voices.

- text: typing.Iterator[str]. The text that will get converted into speech.
- model_id: typing.Optional[str]. Identifier of the model that will be used, you can query them using GET /v1/models. The model needs to have support for text to speech, you can check this using the can_do_text_to_speech property.
- voice_settings: typing.Optional[VoiceSettings]. Voice settings overriding stored setttings for the given voice. They are applied only on the given request.
- request_options: typing.Optional[RequestOptions]. Request-specific configuration.
---
from elevenlabs import PronunciationDictionaryVersionLocator, VoiceSettings
from elevenlabs.client import ElevenLabs
def get_text() -> typing.Iterator[str]:
yield "Hello, how are you?"
yield "I am fine, thank you."
client = ElevenLabs(
api_key="YOUR_API_KEY",
)
client.text_to_speech.convert_realtime(
voice_id="string",
text=get_text(),
model_id="string",
voice_settings=VoiceSettings(
stability=1.1,
similarity_boost=1.1,
style=1.1,
use_speaker_boost=True,
),
)
"""
async with async_connect(
urllib.parse.urljoin(
"wss://api.elevenlabs.io/",
f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input?model_id={model_id}&output_format={output_format}"
),
extra_headers=jsonable_encoder(
remove_none_from_dict(
{
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
)
)
) as socket:
try:
await socket.send(json.dumps(
dict(
text=" ",
try_trigger_generation=True,
voice_settings=voice_settings.dict() if voice_settings else None,
generation_config=dict(
chunk_length_schedule=[50],
),
)
))
except websockets.exceptions.ConnectionClosedError as ce:
raise ApiError(body=ce.reason, status_code=ce.code)

try:
async for text_chunk in async_text_chunker(text):
data = dict(text=text_chunk, try_trigger_generation=True)
await socket.send(json.dumps(data))
try:
async with json.loads(await asyncio.wait_for(socket.recv(), timeout=1e-2)):
if "audio" in data and data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except TimeoutError:
pass

await socket.send(json.dumps(dict(text="")))

while True:

data = json.loads(await socket.recv())
if "audio" in data and data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except websockets.exceptions.ConnectionClosed as ce:
if "message" in data:
raise ApiError(body=data, status_code=ce.code)
elif ce.code != 1000:
raise ApiError(body=ce.reason, status_code=ce.code)
25 changes: 25 additions & 0 deletions tests/test_async_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio
import pytest

from .utils import IN_GITHUB
from elevenlabs import AsyncElevenLabs
from elevenlabs import play

async_client = AsyncElevenLabs()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is entirely new and isn't autogenerated (and so no autogenerated tests) I think this needs way more tests before we can ship it.


def test_generate_stream() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No assertions? This test doesn't appear to do anything if there isn't a human listening in.

async def main():
async def text_stream():
yield "Hi there, I'm Eleven Labs "
yield "I'm an AI Audio Research Company "

audio_stream = await async_client.generate(
text=text_stream(),
voice="Adam",
model="eleven_monolingual_v1",
stream=True
)

if not IN_GITHUB:
stream(audio_stream) # type: ignore
asyncio.run(main())
Loading