diff --git a/app/helpers/call_events.py b/app/helpers/call_events.py index c154039c..6076dbfc 100644 --- a/app/helpers/call_events.py +++ b/app/helpers/call_events.py @@ -469,7 +469,7 @@ async def on_ivr_recognized( call=call, scheduler=scheduler, ): - call.lang = lang.short_code + call.lang_short_code = lang.short_code call.recognition_retry = 0 await start_audio_streaming( @@ -528,6 +528,7 @@ async def on_sms_received( MessageModel( action=MessageActionEnum.SMS, content=message, + lang_short_code=call.lang.short_code, persona=MessagePersonaEnum.HUMAN, ) ) @@ -655,6 +656,7 @@ def _validate(req: str | None) -> tuple[bool, str | None, str | None]: call=call, scheduler=scheduler, ): + # Don't store the lang as we aren't sure about the language of the SMS call.messages.append( MessageModel( action=MessageActionEnum.SMS, diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 0f4bbecb..07e36eea 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -196,6 +196,7 @@ async def _response_callback(_retry: bool = False) -> None: call.messages.append( MessageModel( content=stt_text, + lang_short_code=call.lang.short_code, persona=MessagePersonaEnum.HUMAN, ) ) @@ -491,6 +492,13 @@ async def _content_callback(buffer: str) -> None: tools = await plugins.to_openai(frozenset(tool_blacklist)) # logger.debug("Tools: %s", tools) + # Translate messages to avoid LLM hallucinations + # See: https://github.com/microsoft/call-center-ai/issues/260 + translated_messages = await asyncio.gather( + *[message.translate(call.lang.short_code) for message in call.messages] + ) + logger.debug("Translated messages: %s", translated_messages) + # Execute LLM inference content_buffer_pointer = 0 last_buffered_tool_id = None @@ -500,7 +508,7 @@ async def _content_callback(buffer: str) -> None: # Consume the completion stream async for delta in completion_stream( max_tokens=160, # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 
tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens - messages=call.messages, + messages=translated_messages, system=system, tools=tools, ): diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index 1eed49f3..87b06673 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -309,6 +309,7 @@ async def _store_assistant_message( call.messages.append( MessageModel( content=text, + lang_short_code=call.lang.short_code, persona=MessagePersonaEnum.ASSISTANT, style=style, ) diff --git a/app/helpers/llm_tools.py b/app/helpers/llm_tools.py index 8a1dd0a8..08bbf61c 100644 --- a/app/helpers/llm_tools.py +++ b/app/helpers/llm_tools.py @@ -427,6 +427,7 @@ async def send_sms( MessageModel( action=MessageActionEnum.SMS, content=message, + lang_short_code=self.call.lang.short_code, persona=MessagePersonaEnum.ASSISTANT, ) ) @@ -519,7 +520,7 @@ async def speech_lang( # Update lang initial_lang = self.call.lang.short_code - self.call.lang = lang + self.call.lang_short_code = lang # LLM confirmation return f"Voice language set to {lang} (was {initial_lang})" diff --git a/app/helpers/translation.py b/app/helpers/translation.py index 260af052..9a3f33b8 100644 --- a/app/helpers/translation.py +++ b/app/helpers/translation.py @@ -29,7 +29,7 @@ async def translate_text(text: str, source_lang: str, target_lang: str) -> str | """ Translate text from source language to target language. - Catch errors for a maximum of 3 times. + If the source and target languages are the same, the original text is returned. Catch errors for a maximum of 3 times. 
""" # No need to translate if source_lang == target_lang: diff --git a/app/models/call.py b/app/models/call.py index 182c6bfe..c34c9945 100644 --- a/app/models/call.py +++ b/app/models/call.py @@ -120,10 +120,6 @@ def lang(self) -> LanguageEntryModel: # pyright: ignore ) return default - @lang.setter - def lang(self, short_code: str) -> None: - self.lang_short_code = short_code - async def trainings(self, cache_only: bool = True) -> list[TrainingModel]: """ Get the trainings from the last messages. diff --git a/app/models/message.py b/app/models/message.py index c2fc1084..c1f682c7 100644 --- a/app/models/message.py +++ b/app/models/message.py @@ -115,10 +115,38 @@ class MessageModel(BaseModel): # Editable fields action: ActionEnum = ActionEnum.TALK content: str + lang_short_code: str | None = None persona: PersonaEnum style: StyleEnum = StyleEnum.NONE tool_calls: list[ToolModel] = [] + async def translate(self, target_short_code: str) -> "MessageModel": + """ + Translate the message to a target language. + + A copy of the model is returned with the translated content. 
+ """ + from app.helpers.translation import translate_text + + # Work on a copy to avoid modifying the original model in the database + copy = self.model_copy() + + # Skip if no language is set + if not self.lang_short_code: + return copy + + # Apply translation + translation = await translate_text( + source_lang=self.lang_short_code, + target_lang=target_short_code, + text=self.content, + ) + if translation: + copy.content = translation + copy.lang_short_code = target_short_code + + return copy + @field_validator("created_at") @classmethod def _validate_created_at(cls, created_at: datetime) -> datetime: diff --git a/tests/llm.py b/tests/llm.py index 0dc17c26..89644a21 100644 --- a/tests/llm.py +++ b/tests/llm.py @@ -274,7 +274,7 @@ def _play_media_callback(text: str) -> None: call_client = automation_client.get_call_connection() # Mock call - call.lang = lang + call.lang_short_code = lang async with Scheduler() as scheduler: @@ -312,6 +312,7 @@ async def _training_callback(_call: CallStateModel) -> None: call.messages.append( MessageModel( content=speech, + lang_short_code=call.lang.short_code, persona=MessagePersonaEnum.HUMAN, ) ) diff --git a/tests/local.py b/tests/local.py index 90fe63a1..fac8a180 100644 --- a/tests/local.py +++ b/tests/local.py @@ -95,6 +95,7 @@ async def _training_callback(_call: CallStateModel) -> None: call.messages.append( MessageModel( content=speech, + lang_short_code=call.lang.short_code, persona=MessagePersonaEnum.HUMAN, ) ) diff --git a/tests/search.py b/tests/search.py index 57867d2c..907b8d2e 100644 --- a/tests/search.py +++ b/tests/search.py @@ -139,13 +139,14 @@ async def test_relevancy( # noqa: PLR0913 Test is repeated 10 times to catch multi-threading and concurrency issues. """ # Set call language - call.lang = lang + call.lang_short_code = lang # Fill call with messages for speech in speeches: call.messages.append( MessageModel( content=speech, + lang_short_code=call.lang.short_code, persona=MessagePersonaEnum.HUMAN, ) )