Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Piper stream audio #45

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 136 additions & 75 deletions programs/tts/piper/bin/piper_server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,52 @@
#!/usr/bin/env python3
import argparse
import json
import io
import logging
import os
import selectors
import socket
import subprocess
import tempfile
import sys
import wave
from pathlib import Path
from rhasspy3.audio import AudioChunk, AudioStart, AudioStop
from rhasspy3.event import read_event, write_event
from rhasspy3.tts import Synthesize

_FILE = Path(__file__)
_DIR = _FILE.parent
_LOGGER = logging.getLogger(_FILE.stem)


def get_voice_config(model) -> dict:
"""Generate sample wav to get samplerate, samplewidth, and channels of the voice."""
command = [
str(_DIR / "piper"),
"--model",
str(model),
"--output_file",
"-",
]
proc = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
)
try:
wav_str, _ = proc.communicate(b"\n", timeout=10)
except subprocess.TimeoutExpired:
proc.kill()
wav_str, _ = proc.communicate()

with io.BytesIO(wav_str) as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "rb")
with wav_file:
rate = wav_file.getframerate()
width = wav_file.getsampwidth()
channels = wav_file.getnchannels()
return {"rate": rate, "width": width, "channels": channels}


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("model", help="Path to model file (.onnx)")
Expand All @@ -23,11 +56,21 @@ def main() -> None:
parser.add_argument(
"--auto-punctuation", default=".?!", help="Automatically add punctuation"
)
parser.add_argument("--config", help="Path to model config file (default: model path + .json)")
parser.add_argument("--speaker", type=int, help="ID of speaker (default: 0)")
parser.add_argument("--noise_scale", type=float, help="Generator noise (default: 0.667)")
parser.add_argument("--length_scale", type=float, help="Phoneme length (default: 1.0)")
parser.add_argument("--noise_w", type=float, help="Phoneme width noise (default: 0.8)")
parser.add_argument("--sentence_silence", type=float, help="Seconds of silence after each sentence (default: 0.2)")
parser.add_argument("--tashkeel_model", help="Path to libtashkeel onnx model (arabic)")
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

voice_config = get_voice_config(args.model)
_LOGGER.debug("voice_config: %s", voice_config)

# Need to unlink socket if it exists
try:
os.unlink(args.socketfile)
Expand All @@ -40,51 +83,68 @@ def main() -> None:
sock.bind(args.socketfile)
sock.listen()

with tempfile.TemporaryDirectory() as temp_dir:
command = [
str(_DIR / "piper"),
"--model",
str(args.model),
"--output_dir",
temp_dir,
]
with subprocess.Popen(
command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
universal_newlines=True,
) as proc:
_LOGGER.info("Ready")

# Listen for connections
while True:
try:
connection, client_address = sock.accept()
_LOGGER.debug("Connection from %s", client_address)
handle_connection(connection, proc, args)
except KeyboardInterrupt:
break
except Exception:
_LOGGER.exception("Error communicating with socket client")
command = [
str(_DIR / "piper"),
"--model",
str(args.model),
"--output_raw",
]
if args.config is not None:
command.append(["--config", args.config])
if args.speaker is not None:
command.append(["--speaker", args.speaker])
if args.noise_scale is not None:
command.append(["--noise_scale", args.noise_scale])
if args.length_scale is not None:
command.append(["--length_scale", args.length_scale])
if args.noise_w is not None:
command.append(["--noise_w", args.noise_w])
if args.sentence_silence is not None:
command.append(["--sentence_silence", args.sentence_silence])
if args.tashkeel_model is not None:
command.append(["--tashkeel_model", args.tashkeel_model])

with subprocess.Popen(
command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
_LOGGER.info("Ready")

# Listen for connections
while True:
try:
connection, client_address = sock.accept()
_LOGGER.debug("Connection from %s", client_address)
handle_connection(connection, proc, args, voice_config)
except KeyboardInterrupt:
break
except Exception:
_LOGGER.exception("Error communicating with socket client")
finally:
os.unlink(args.socketfile)


def handle_connection(
connection: socket.socket, proc: subprocess.Popen, args: argparse.Namespace
connection: socket.socket,
proc: subprocess.Popen,
args: argparse.Namespace,
voice_config: dict,
) -> None:
assert proc.stdin is not None
assert proc.stdout is not None

with connection, connection.makefile(mode="rwb") as conn_file:
while True:
event_info = json.loads(conn_file.readline())
event_type = event_info["type"]
event = read_event(conn_file) # type: ignore
if event is None:
continue

if event_type != "synthesize":
if not Synthesize.is_type(event.type):
continue

raw_text = event_info["data"]["text"]
raw_text = Synthesize.from_event(event).text
text = raw_text.strip()
if args.auto_punctuation and text:
has_punctuation = False
Expand All @@ -99,48 +159,49 @@ def handle_connection(
_LOGGER.debug("synthesize: raw_text=%s, text='%s'", raw_text, text)

# Text in, file path out
print(text.strip(), file=proc.stdin, flush=True)
output_path = proc.stdout.readline().strip()
_LOGGER.debug(output_path)

wav_file: wave.Wave_read = wave.open(output_path, "rb")
with wav_file:
data = {
"rate": wav_file.getframerate(),
"width": wav_file.getsampwidth(),
"channels": wav_file.getnchannels(),
}

conn_file.write(
(
json.dumps(
{"type": "audio-start", "data": data}, ensure_ascii=False
)
+ "\n"
).encode()
)

# Audio
audio_bytes = wav_file.readframes(wav_file.getnframes())
conn_file.write(
(
json.dumps(
{
"type": "audio-chunk",
"data": data,
"payload_length": len(audio_bytes),
},
ensure_ascii=False,
)
+ "\n"
).encode()
)
conn_file.write(audio_bytes)

conn_file.write(
(json.dumps({"type": "audio-stop"}, ensure_ascii=False) + "\n").encode()
)
os.unlink(output_path)
proc.stdin.write(bytes(text.strip() + "\n", "utf8"))
proc.stdin.flush()

sel = selectors.DefaultSelector()
sel.register(proc.stdout, selectors.EVENT_READ)
sel.register(proc.stderr, selectors.EVENT_READ)

audio_started = False
audio_stopped = False
while True:
# Wait for stdout or stderr output from the process (blocking).
# If we already got a message on stderr that the synthesizing has finished,
# then just poll (non-blocking) until stdout is empty. We will know that
# when the non-blocking select(timeout=0) returns an empty set
rlist = sel.select(timeout=0 if audio_stopped else None)
if not rlist:
break
for key, _ in rlist:
output = key.fileobj.read1()
if not output:
break
if key.fileobj is proc.stderr:
sys.stderr.buffer.write(output)
if "Real-time factor" in output.decode():
audio_stopped = True
continue

if not audio_started:
write_event(AudioStart(
rate=voice_config["rate"],
width=voice_config["width"],
channels=voice_config["channels"],
).event(), conn_file) # type: ignore
audio_started = True
# Audio
write_event(AudioChunk(
rate=voice_config["rate"],
width=voice_config["width"],
channels=voice_config["channels"],
audio=output,
).event(), conn_file) # type: ignore

write_event(AudioStop().event(), conn_file) # type: ignore
break


Expand Down
6 changes: 3 additions & 3 deletions programs/tts/piper/script/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@
_DIR = Path(__file__).parent
_LOGGER = logging.getLogger("setup")

PLATFORMS = {"x86_64": "amd64", "aarch64": "arm64"}
PLATFORMS = {"x86_64": "amd64", "aarch64": "arm64", "armhf": "armv7"}


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--platform",
help="CPU architecture to download (amd64, arm64)",
help="CPU architecture to download (amd64, arm64, armv7)",
)
parser.add_argument(
"--destination", help="Path to destination directory (default: bin)"
)
parser.add_argument(
"--link-format",
default="https://github.com/rhasspy/piper/releases/download/v0.0.2/piper_{platform}.tar.gz",
default="https://github.com/rhasspy/piper/releases/download/v1.2.0/piper_{platform}.tar.gz",
help="Format string for download URLs",
)
args = parser.parse_args()
Expand Down