Skip to content

Commit

Permalink
optimizations for linux
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Nov 16, 2024
1 parent b2fab8b commit 41a6856
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 8 deletions.
7 changes: 6 additions & 1 deletion RealtimeTTS/engines/coqui_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ def __init__(
# Start the worker process
try:
# Only set the start method if it hasn't been set already
if mp.get_start_method(allow_none=True) is None:
# Check the current platform and set the start method
if sys.platform.startswith('linux') or sys.platform == 'darwin': # For Linux or macOS
mp.set_start_method("spawn")
elif mp.get_start_method(allow_none=True) is None:
mp.set_start_method("spawn")
except RuntimeError as e:
print("Start method has already been set. Details:", e)
Expand Down Expand Up @@ -715,6 +718,7 @@ def get_user_data_dir(appname):
for i, chunk in enumerate(chunks):
chunk = postprocess_wave(chunk)
chunk_bytes = chunk.tobytes()

conn.send(("success", chunk_bytes))
chunk_duration = len(chunk_bytes) / (
4 * 24000
Expand Down Expand Up @@ -942,6 +946,7 @@ def synthesize(self, text: str) -> bool:
logging.error(f"Error synthesizing text: {text}")
logging.error(f"Error: {result}")
return False

self.queue.put(result)
status, result = self.parent_synthesize_pipe.recv()

Expand Down
9 changes: 9 additions & 0 deletions RealtimeTTS/engines/parler_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Union
import pyaudio
import torch
import time


class ParlerVoice:
Expand All @@ -24,6 +25,7 @@ def __init__(
torch_dtype=torch.bfloat16,
buffer_duration_s=1.0,
play_steps_in_s=0.5,
print_time_to_first_token=False,
):
"""
Initializes the Parler TTS engine.
Expand All @@ -43,6 +45,7 @@ def __init__(
self.play_steps_in_s = play_steps_in_s
self.voice_parameters = {}
self.buffer_duration_s = buffer_duration_s
self.print_time_to_first_token = print_time_to_first_token

self.initialize_model()

Expand Down Expand Up @@ -91,6 +94,7 @@ def _generate_and_queue_audio(self, text: str):
Args:
text (str): Text to synthesize.
"""
start_time = time.time()
frame_rate = self.model.audio_encoder.config.frame_rate
sampling_rate = self.model.audio_encoder.config.sampling_rate

Expand Down Expand Up @@ -154,6 +158,7 @@ def generate_audio():
self.queue.put(buffered_chunk.tobytes())

# Continue streaming the rest of the audio
first_token = False
while not generation_completed:
try:
new_audio = next(streamer)
Expand All @@ -162,7 +167,11 @@ def generate_audio():
generation_completed = True
break
audio_chunk = new_audio
if not first_token and self.print_time_to_first_token:
end_time = time.time()
print(f"Time to first token: {end_time - start_time:.2f} s")
self.queue.put(audio_chunk.tobytes())
first_token = True
except StopIteration:
generation_completed = True
break
Expand Down
146 changes: 139 additions & 7 deletions RealtimeTTS/stream_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

from pydub import AudioSegment
import numpy as np
import threading
import resampy
import pyaudio
import logging
import queue
Expand Down Expand Up @@ -52,15 +54,91 @@ def __init__(self, config: AudioConfiguration):
self.config = config
self.stream = None
self.pyaudio_instance = pyaudio.PyAudio()
self.actual_sample_rate = 0

def get_supported_sample_rates(self, device_index):
"""
Test which standard sample rates are supported by the specified device.
Args:
device_index (int): The index of the audio device to test
Returns:
list: List of supported sample rates
"""
standard_rates = [8000, 9600, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000]
supported_rates = []

device_info = self.pyaudio_instance.get_device_info_by_index(device_index)
max_channels = device_info.get('maxOutputChannels')

# Test each standard sample rate
for rate in standard_rates:
try:
if self.pyaudio_instance.is_format_supported(
rate,
output_device=device_index,
output_channels=max_channels,
output_format=self.config.format,
):
supported_rates.append(rate)
except:
continue
return supported_rates

def _get_best_sample_rate(self, device_index, desired_rate):
"""
Determines the best available sample rate for the device.
Args:
device_index: Index of the audio device
desired_rate: Preferred sample rate
Returns:
int: Best available sample rate
"""
try:
# First determine the actual device index to use
actual_device_index = (device_index if device_index is not None
else self.pyaudio_instance.get_default_output_device_info()['index'])

# Now use the actual_device_index for getting device info and supported rates
device_info = self.pyaudio_instance.get_device_info_by_index(actual_device_index)
supported_rates = self.get_supported_sample_rates(actual_device_index)

# Check if desired rate is supported
if desired_rate in supported_rates:
return desired_rate

# Find the highest supported rate that's lower than desired_rate
lower_rates = [r for r in supported_rates if r <= desired_rate]
if lower_rates:
return max(lower_rates)

# If no lower rates, get the lowest higher rate
higher_rates = [r for r in supported_rates if r > desired_rate]
if higher_rates:
return min(higher_rates)

# If nothing else works, return device's default rate
return int(device_info.get('defaultSampleRate', 44100))

except Exception as e:
logging.warning(f"Error determining sample rate: {e}")
return 44100 # Safe fallback

def open_stream(self):
"""Opens an audio stream."""

# check for mpeg format
pyChannels = self.config.channels
pySampleRate = self.config.rate
desired_rate = self.config.rate
pyOutput_device_index = self.config.output_device_index

# Determine the best sample rate
best_rate = self._get_best_sample_rate(pyOutput_device_index, desired_rate)
self.actual_sample_rate = best_rate

if self.config.muted:
logging.debug("Muted mode, no opening stream")

Expand All @@ -70,26 +148,42 @@ def open_stream(self):
logging.debug(
"Opening stream for mpeg audio chunks, "
f"pyFormat: {pyFormat}, pyChannels: {pyChannels}, "
f"pySampleRate: {pySampleRate}"
f"pySampleRate: {best_rate}"
)
else:
pyFormat = self.config.format
logging.debug(
"Opening stream for wave audio chunks, "
f"pyFormat: {pyFormat}, pyChannels: {pyChannels}, "
f"pySampleRate: {pySampleRate}"
f"pySampleRate: {best_rate}"
)

try:
self.stream = self.pyaudio_instance.open(
format=pyFormat,
channels=pyChannels,
rate=pySampleRate,
rate=best_rate,
output_device_index=pyOutput_device_index,
output=True,
)
except Exception as e:
print(f"Error opening stream: {e}")

# Get the number of available audio devices
device_count = self.pyaudio_instance.get_device_count()

print("Available Audio Devices:\n")

# Iterate through all devices and print their details
for i in range(device_count):
device_info = self.pyaudio_instance.get_device_info_by_index(i)
print(f"Device Index: {i}")
print(f" Name: {device_info['name']}")
print(f" Sample Rate (Default): {device_info['defaultSampleRate']} Hz")
print(f" Max Input Channels: {device_info['maxInputChannels']}")
print(f" Max Output Channels: {device_info['maxOutputChannels']}")
print(f" Host API: {self.pyaudio_instance.get_host_api_info_by_index(device_info['hostApi'])['name']}")
print("\n")

exit(0)

def start_stream(self):
Expand Down Expand Up @@ -225,13 +319,27 @@ def _play_chunk(self, chunk):
chunk: Chunk of audio data to be played.
"""

sample_width = self.audio_stream.pyaudio_instance.get_sample_size(self.audio_stream.config.format)
channels = self.audio_stream.config.channels

# handle mpeg
if self.audio_stream.config.format == pyaudio.paCustomFormat:
# convert to pcm using pydub
segment = AudioSegment.from_file(io.BytesIO(chunk), format="mp3")
chunk = segment.raw_data

sub_chunk_size = 1024
if self.audio_stream.config.rate != self.audio_stream.actual_sample_rate:
if self.audio_stream.config.format == pyaudio.paFloat32:
audio_data = np.frombuffer(chunk, dtype=np.float32)
resampled_data = resampy.resample(audio_data, self.audio_stream.config.rate, self.audio_stream.actual_sample_rate)
chunk = resampled_data.astype(np.float32).tobytes()
else:
audio_data = np.frombuffer(chunk, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
resampled_data = resampy.resample(audio_data, self.audio_stream.config.rate, self.audio_stream.actual_sample_rate)
chunk = (resampled_data * 32768.0).astype(np.int16).tobytes()

sub_chunk_size = 512

for i in range(0, len(chunk), sub_chunk_size):
sub_chunk = chunk[i : i + sub_chunk_size]
Expand All @@ -241,7 +349,30 @@ def _play_chunk(self, chunk):
self.first_chunk_played = True

if not self.muted:
self.audio_stream.stream.write(sub_chunk)
try:
import time

# Define the timeout duration in seconds
timeout = 0.1

# Record the start time
start_time = time.time()

frames_in_sub_chunk = len(sub_chunk) // (sample_width * channels)

# Wait until there's space in the buffer or the timeout is reached
while self.audio_stream.stream.get_write_available() < frames_in_sub_chunk:
if time.time() - start_time > timeout:
print(f"Wait aborted: Timeout of {timeout}s exceeded. "
f"Buffer availability: {self.audio_stream.stream.get_write_available()}, "
f"Frames in sub-chunk: {frames_in_sub_chunk}")
break
time.sleep(0.001) # Small sleep to let the stream process audio


self.audio_stream.stream.write(sub_chunk)
except Exception as e:
print(f"RealtimeTTS error sending audio data: {e}")

if self.on_audio_chunk:
self.on_audio_chunk(sub_chunk)
Expand All @@ -266,6 +397,7 @@ def _process_buffer(self):
if self.immediate_stop.is_set():
logging.info("Immediate stop requested, aborting playback")
break

if self.on_playback_stop:
self.on_playback_stop()

Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,8 @@ pyaudio==0.2.14
# pydub is used to convert chunks from mp3 to pcm (for openai tts)
pydub==0.25.1

# resampy is used to resample from the tts to the target device sample rate
resampy==0.4.3

# stream2sentence is to quickly convert streamed text into sentences for real-time synthesis
stream2sentence==0.2.7

0 comments on commit 41a6856

Please sign in to comment.