Skip to content

Commit

Permalink
refacto: Place message related functions in the model file
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Feb 29, 2024
1 parent f75414a commit af0aa50
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 37 deletions.
34 changes: 5 additions & 29 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
PersonaEnum as MessagePersonaEnum,
StyleEnum as MessageStyleEnum,
ToolModel as MessageToolModel,
extract_message_style,
remove_message_action,
)
from helpers.llm_worker import (
completion_model_sync,
Expand Down Expand Up @@ -144,9 +146,6 @@
)
_logger.info(f"Using call event URL {_CALL_EVENT_URL}")

_MESSAGE_ACTION_R = r"action=([a-z_]*)( .*)?"
_MESSAGE_STYLE_R = r"style=([a-z_]*)( .*)?"


@api.get(
"/health/liveness",
Expand Down Expand Up @@ -795,35 +794,12 @@ async def execute_llm_chat(
_logger.debug("Running LLM chat")
should_user_answer = True

def _remove_message_actions(text: str) -> str:
"""
Remove action from content. AI often adds it by mistake event if explicitly asked not to.
"""
res = re.match(_MESSAGE_ACTION_R, text)
if not res:
return text.strip()
content = res.group(2)
return content.strip() if content else ""

def _extract_message_style(text: str) -> Tuple[Optional[MessageStyleEnum], str]:
"""
Detect the style of a message.
"""
res = re.match(_MESSAGE_STYLE_R, text)
if not res:
return None, text
try:
content = res.group(2)
return MessageStyleEnum(res.group(1)), (content.strip() if content else "")
except ValueError:
return None, text

async def _buffer_user_callback(
buffer: str, style: MessageStyleEnum
) -> MessageStyleEnum:
# Remove tool calls from buffer content and detect style
local_style, local_content = _extract_message_style(
_remove_message_actions(buffer)
local_style, local_content = extract_message_style(
remove_message_action(buffer)
)
new_style = local_style or style
if local_content:
Expand Down Expand Up @@ -927,7 +903,7 @@ async def _tool_cancellation_callback() -> None:
tool_calls = [tool_call for _, tool_call in tool_calls_buffer.items()]

# Get data from full content to be able to store it in the DB
_, content_full = _extract_message_style(_remove_message_actions(content_full))
_, content_full = extract_message_style(remove_message_action(content_full))

_logger.debug(f"Chat response: {content_full}")
_logger.debug(f"Tool calls: {tool_calls}")
Expand Down
38 changes: 30 additions & 8 deletions models/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, validator
from typing import Any, List, Union
from typing import Any, List, Optional, Tuple, Union
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageToolCallParam,
Expand All @@ -11,15 +11,12 @@
from inspect import getmembers, isfunction
from json_repair import repair_json
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from html import unescape
import re
from urllib.parse import unquote


DOUBLE_ESCAPED_UNICODE_R = r"\\\\u([0-9a-fA-F]{4})"
FUNC_NAME_SANITIZER_R = r"[^a-zA-Z0-9_-]"
NON_ESCAPED_UNICODE_R = r"(?<!\\)u[0-9a-fA-F]{4}"
REMOVE_DOUBLE_ESCAPE_R = r"\\u\1"
_FUNC_NAME_SANITIZER_R = r"[^a-zA-Z0-9_-]"
_MESSAGE_ACTION_R = r"action=([a-z_]*)( .*)?"
_MESSAGE_STYLE_R = r"style=([a-z_]*)( .*)?"


class StyleEnum(str, Enum):
Expand Down Expand Up @@ -63,7 +60,7 @@ def to_openai(self) -> ChatCompletionMessageToolCallParam:
"arguments": self.function_arguments,
"name": "-".join(
re.sub(
FUNC_NAME_SANITIZER_R,
_FUNC_NAME_SANITIZER_R,
"-",
self.function_name,
).split("-")
Expand Down Expand Up @@ -183,3 +180,28 @@ def to_openai(
)
)
return res


def remove_message_action(text: str) -> str:
"""
Remove action from content. AI often adds it by mistake event if explicitly asked not to.
"""
res = re.match(_MESSAGE_ACTION_R, text)
if not res:
return text.strip()
content = res.group(2)
return content.strip() if content else ""


def extract_message_style(text: str) -> Tuple[Optional[StyleEnum], str]:
"""
Detect the style of a message.
"""
res = re.match(_MESSAGE_STYLE_R, text)
if not res:
return None, text
try:
content = res.group(2)
return StyleEnum(res.group(1)), (content.strip() if content else "")
except ValueError:
return None, text

0 comments on commit af0aa50

Please sign in to comment.