Skip to content

Commit

Permalink
fix: LLM handle finely language changes
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Jan 15, 2025
1 parent 7fe3cdf commit c024382
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 10 deletions.
4 changes: 3 additions & 1 deletion app/helpers/call_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ async def on_ivr_recognized(
call=call,
scheduler=scheduler,
):
call.lang = lang.short_code
call.lang_short_code = lang.short_code
call.recognition_retry = 0

await start_audio_streaming(
Expand Down Expand Up @@ -528,6 +528,7 @@ async def on_sms_received(
MessageModel(
action=MessageActionEnum.SMS,
content=message,
lang_short_code=call.lang.short_code,
persona=MessagePersonaEnum.HUMAN,
)
)
Expand Down Expand Up @@ -655,6 +656,7 @@ def _validate(req: str | None) -> tuple[bool, str | None, str | None]:
call=call,
scheduler=scheduler,
):
# Don't store the lang as we aren't sure about the language of the SMS
call.messages.append(
MessageModel(
action=MessageActionEnum.SMS,
Expand Down
10 changes: 9 additions & 1 deletion app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ async def _response_callback(_retry: bool = False) -> None:
call.messages.append(
MessageModel(
content=stt_text,
lang_short_code=call.lang.short_code,
persona=MessagePersonaEnum.HUMAN,
)
)
Expand Down Expand Up @@ -491,6 +492,13 @@ async def _content_callback(buffer: str) -> None:
tools = await plugins.to_openai(frozenset(tool_blacklist))
# logger.debug("Tools: %s", tools)

# Translate messages to avoid LLM hallucinations
# See: https://github.com/microsoft/call-center-ai/issues/260
translated_messages = await asyncio.gather(
*[message.translate(call.lang.short_code) for message in call.messages]
)
logger.debug("Translated messages: %s", translated_messages)

# Execute LLM inference
content_buffer_pointer = 0
last_buffered_tool_id = None
Expand All @@ -500,7 +508,7 @@ async def _content_callback(buffer: str) -> None:
# Consume the completion stream
async for delta in completion_stream(
max_tokens=160, # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens
messages=call.messages,
messages=translated_messages,
system=system,
tools=tools,
):
Expand Down
1 change: 1 addition & 0 deletions app/helpers/call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ async def _store_assistant_message(
call.messages.append(
MessageModel(
content=text,
lang_short_code=call.lang.short_code,
persona=MessagePersonaEnum.ASSISTANT,
style=style,
)
Expand Down
3 changes: 2 additions & 1 deletion app/helpers/llm_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ async def send_sms(
MessageModel(
action=MessageActionEnum.SMS,
content=message,
lang_short_code=self.call.lang.short_code,
persona=MessagePersonaEnum.ASSISTANT,
)
)
Expand Down Expand Up @@ -519,7 +520,7 @@ async def speech_lang(

# Update lang
initial_lang = self.call.lang.short_code
self.call.lang = lang
self.call.lang_short_code = lang

# LLM confirmation
return f"Voice language set to {lang} (was {initial_lang})"
2 changes: 1 addition & 1 deletion app/helpers/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def translate_text(text: str, source_lang: str, target_lang: str) -> str |
"""
Translate text from source language to target language.
Catch errors for a maximum of 3 times.
If the source and target languages are the same, the original text is returned. Catch errors for a maximum of 3 times.
"""
# No need to translate
if source_lang == target_lang:
Expand Down
4 changes: 0 additions & 4 deletions app/models/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ def lang(self) -> LanguageEntryModel: # pyright: ignore
)
return default

@lang.setter
def lang(self, short_code: str) -> None:
self.lang_short_code = short_code

async def trainings(self, cache_only: bool = True) -> list[TrainingModel]:
"""
Get the trainings from the last messages.
Expand Down
28 changes: 28 additions & 0 deletions app/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,38 @@ class MessageModel(BaseModel):
# Editable fields
action: ActionEnum = ActionEnum.TALK
content: str
lang_short_code: str | None = None
persona: PersonaEnum
style: StyleEnum = StyleEnum.NONE
tool_calls: list[ToolModel] = []

async def translate(self, target_short_code: str) -> "MessageModel":
    """
    Return a copy of this message translated to the target language.

    The original message is never mutated. The copy keeps the original
    content and language code when the message has no language set, or
    when the translation helper returns an empty or missing result.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import — confirm
    from app.helpers.translation import translate_text

    # Always hand back a copy so the stored model is left untouched
    translated_message = self.model_copy()

    # Only messages with a known source language can be translated
    if self.lang_short_code:
        translated_content = await translate_text(
            source_lang=self.lang_short_code,
            target_lang=target_short_code,
            text=self.content,
        )
        # Update the copy only when a usable translation came back
        if translated_content:
            translated_message.content = translated_content
            translated_message.lang_short_code = target_short_code

    return translated_message

@field_validator("created_at")
@classmethod
def _validate_created_at(cls, created_at: datetime) -> datetime:
Expand Down
3 changes: 2 additions & 1 deletion tests/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def _play_media_callback(text: str) -> None:
call_client = automation_client.get_call_connection()

# Mock call
call.lang = lang
call.lang_short_code = lang

async with Scheduler() as scheduler:

Expand Down Expand Up @@ -312,6 +312,7 @@ async def _training_callback(_call: CallStateModel) -> None:
call.messages.append(
MessageModel(
content=speech,
lang_short_code=call.lang.short_code,
persona=MessagePersonaEnum.HUMAN,
)
)
Expand Down
1 change: 1 addition & 0 deletions tests/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ async def _training_callback(_call: CallStateModel) -> None:
call.messages.append(
MessageModel(
content=speech,
lang_short_code=call.lang.short_code,
persona=MessagePersonaEnum.HUMAN,
)
)
Expand Down
3 changes: 2 additions & 1 deletion tests/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,14 @@ async def test_relevancy( # noqa: PLR0913
Test is repeated 10 times to catch multi-threading and concurrency issues.
"""
# Set call language
call.lang = lang
call.lang_short_code = lang

# Fill call with messages
for speech in speeches:
call.messages.append(
MessageModel(
content=speech,
lang_short_code=call.lang.short_code,
persona=MessagePersonaEnum.HUMAN,
)
)
Expand Down

0 comments on commit c024382

Please sign in to comment.