From bc775eb9682cee2bfe7bd43135505293fb714d52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Tue, 20 Feb 2024 20:28:42 +0100 Subject: [PATCH] fix: Missing sentences in streamed LLM --- helpers/call.py | 25 +++++++++++++++++-------- main.py | 6 ++++-- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/helpers/call.py b/helpers/call.py index a7884c8a..0552ea4e 100644 --- a/helpers/call.py +++ b/helpers/call.py @@ -22,7 +22,7 @@ _logger = build_logger(__name__) -SENTENCE_R = r"[^\w\s+\-–—’/'\",:;()@=]" +_SENTENCE_PUNCTUATION_R = r"(\. |\.$|[!?;:])" class ContextEnum(str, Enum): @@ -31,15 +31,24 @@ class ContextEnum(str, Enum): TRANSFER_FAILED = "transfer_failed" -def sentence_split(text: str) -> Generator[str, None, None]: +def tts_sentence_split(text: str, include_last: bool) -> Generator[str, None, None]: """ Split a text into sentences. """ - separators = re.findall(SENTENCE_R, text) - splits = re.split(SENTENCE_R, text) - for i, separator in enumerate(separators): - local_content = splits[i] + separator - yield local_content + # Clean and remove extra spaces + text = " ".join(text.split()) + # Split by sentence by punctuation + splits = re.split(_SENTENCE_PUNCTUATION_R, text) + for i, split in enumerate(splits): + if i % 2 == 1: # Skip punctuation + continue + if not split: # Skip empty lines + continue + if i == len(splits) - 1: # Skip last line in case of missing punctuation + if include_last: + yield split + else: # Add punctuation back + yield split + splits[i + 1] # TODO: Disable or lower profanity filter. The filter seems enabled by default, it replaces words like "holes in my roof" by "*** in my roof". This is not acceptable for a call center. @@ -144,7 +153,7 @@ async def handle_play( # Split text in chunks of max 400 characters, separated by sentence chunks = [] chunk = "" - for to_add in sentence_split(text): + for to_add in tts_sentence_split(text, True): if len(chunk) + len(to_add) >= 400: chunks.append(chunk.strip()) # Remove trailing space chunk = "" diff --git a/main.py b/main.py index add17357..1f71a130 100644 --- a/main.py +++ b/main.py @@ -80,7 +80,7 @@ handle_play, handle_recognize_ivr, handle_recognize_text, - sentence_split, + tts_sentence_split, ) from helpers.llm_plugins import LlmPlugins @@ -814,7 +814,9 @@ async def _tool_cancellation_callback() -> None: else: # Store whole content content_full += delta.content - for sentence in sentence_split(content_full[content_buffer_pointer:]): + for sentence in tts_sentence_split( + content_full[content_buffer_pointer:], False + ): content_buffer_pointer += len(sentence) plugins.style = await _buffer_user_callback(sentence, plugins.style)