From af0aa5008f4178a5e46464b7c8cbee98dc8d7bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 1 Mar 2024 00:37:20 +0100 Subject: [PATCH] refacto: Place message related functions in the model file --- main.py | 34 +++++----------------------------- models/message.py | 38 ++++++++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 37 deletions(-) diff --git a/main.py b/main.py index 96806fb2..e598c272 100644 --- a/main.py +++ b/main.py @@ -61,6 +61,8 @@ PersonaEnum as MessagePersonaEnum, StyleEnum as MessageStyleEnum, ToolModel as MessageToolModel, + extract_message_style, + remove_message_action, ) from helpers.llm_worker import ( completion_model_sync, @@ -144,9 +146,6 @@ ) _logger.info(f"Using call event URL {_CALL_EVENT_URL}") -_MESSAGE_ACTION_R = r"action=([a-z_]*)( .*)?" -_MESSAGE_STYLE_R = r"style=([a-z_]*)( .*)?" - @api.get( "/health/liveness", @@ -795,35 +794,12 @@ async def execute_llm_chat( _logger.debug("Running LLM chat") should_user_answer = True - def _remove_message_actions(text: str) -> str: - """ - Remove action from content. AI often adds it by mistake event if explicitly asked not to. - """ - res = re.match(_MESSAGE_ACTION_R, text) - if not res: - return text.strip() - content = res.group(2) - return content.strip() if content else "" - - def _extract_message_style(text: str) -> Tuple[Optional[MessageStyleEnum], str]: - """ - Detect the style of a message. - """ - res = re.match(_MESSAGE_STYLE_R, text) - if not res: - return None, text - try: - content = res.group(2) - return MessageStyleEnum(res.group(1)), (content.strip() if content else "") - except ValueError: - return None, text - async def _buffer_user_callback( buffer: str, style: MessageStyleEnum ) -> MessageStyleEnum: # Remove tool calls from buffer content and detect style - local_style, local_content = _extract_message_style( - _remove_message_actions(buffer) + local_style, local_content = extract_message_style( + remove_message_action(buffer) ) new_style = local_style or style if local_content: @@ -927,7 +903,7 @@ async def _tool_cancellation_callback() -> None: tool_calls = [tool_call for _, tool_call in tool_calls_buffer.items()] # Get data from full content to be able to store it in the DB - _, content_full = _extract_message_style(_remove_message_actions(content_full)) + _, content_full = extract_message_style(remove_message_action(content_full)) _logger.debug(f"Chat response: {content_full}") _logger.debug(f"Tool calls: {tool_calls}") diff --git a/models/message.py b/models/message.py index 4c6c785b..9544ed44 100644 --- a/models/message.py +++ b/models/message.py @@ -1,7 +1,7 @@ from datetime import datetime from enum import Enum from pydantic import BaseModel, Field, validator -from typing import Any, List, Union +from typing import Any, List, Optional, Tuple, Union from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionMessageToolCallParam, @@ -11,15 +11,12 @@ from inspect import getmembers, isfunction from json_repair import repair_json from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall -from html import unescape import re -from urllib.parse import unquote -DOUBLE_ESCAPED_UNICODE_R = r"\\\\u([0-9a-fA-F]{4})" -FUNC_NAME_SANITIZER_R = r"[^a-zA-Z0-9_-]" -NON_ESCAPED_UNICODE_R = r"(? ChatCompletionMessageToolCallParam: "arguments": self.function_arguments, "name": "-".join( re.sub( - FUNC_NAME_SANITIZER_R, + _FUNC_NAME_SANITIZER_R, "-", self.function_name, ).split("-") @@ -183,3 +180,28 @@ def to_openai( ) ) return res + + +def remove_message_action(text: str) -> str: + """ + Remove action from content. AI often adds it by mistake event if explicitly asked not to. + """ + res = re.match(_MESSAGE_ACTION_R, text) + if not res: + return text.strip() + content = res.group(2) + return content.strip() if content else "" + + +def extract_message_style(text: str) -> Tuple[Optional[StyleEnum], str]: + """ + Detect the style of a message. + """ + res = re.match(_MESSAGE_STYLE_R, text) + if not res: + return None, text + try: + content = res.group(2) + return StyleEnum(res.group(1)), (content.strip() if content else "") + except ValueError: + return None, text