diff --git a/assistants/ai/memory.py b/assistants/ai/memory.py index 7fa2a94..8887f02 100644 --- a/assistants/ai/memory.py +++ b/assistants/ai/memory.py @@ -59,9 +59,10 @@ async def load_conversation(self, conversation_id: Optional[str] = None): self.memory = json.loads(conversation.conversation) if conversation else [] self.conversation_id = conversation.id if conversation else uuid.uuid4().hex - async def save_conversation(self): + async def save_conversation_state(self) -> str: """ Save the current conversation to the database. + :return: The conversation ID. """ if not self.memory: return @@ -76,6 +77,7 @@ async def save_conversation(self): last_updated=datetime.now(), ) ) + return self.conversation_id def get_last_message(self, thread_id: str) -> Optional[MessageData]: """ diff --git a/assistants/ai/openai.py b/assistants/ai/openai.py index 6f461c5..13f2929 100644 --- a/assistants/ai/openai.py +++ b/assistants/ai/openai.py @@ -21,6 +21,7 @@ from assistants.config import environment from assistants.lib.exceptions import ConfigError, NoResponseError from assistants.log import logger +from assistants.user_data import threads_table from assistants.user_data.sqlite_backend.assistants import ( get_assistant_data, save_assistant_id, @@ -76,6 +77,7 @@ def __init__( # pylint: disable=too-many-arguments self._config_hash = None self.assistant = None self.last_message = None + self.last_prompt = None async def start(self): """ @@ -238,6 +240,7 @@ async def prompt(self, prompt: str, thread_id: Optional[str] = None) -> Run: :param thread_id: Optional ID of the thread to continue. :return: The run object. """ + self.last_prompt = prompt if thread_id is None: thread = self.start_thread(prompt) run = self.run_thread(thread) @@ -261,6 +264,7 @@ async def image_prompt(self, prompt: str) -> Optional[str]: :param prompt: The image prompt. :return: The URL of the generated image. """ + self.last_prompt = prompt response = self.client.images.generate( model=environment.IMAGE_MODEL, prompt=prompt, @@ -313,6 +317,16 @@ async def converse( return self.get_last_message(thread_id) + async def save_conversation_state(self) -> str: + """ + Save the state of the conversation. + :return: The thread ID of the conversation. + """ + await threads_table.save_thread( + self.last_message.thread_id, self.assistant_id, self.last_prompt + ) + return self.last_message.thread_id + class Completion(MemoryMixin): """ diff --git a/assistants/ai/types.py b/assistants/ai/types.py index c20d17b..ea9a2d1 100644 --- a/assistants/ai/types.py +++ b/assistants/ai/types.py @@ -63,6 +63,13 @@ def get_last_message(self, thread_id: str) -> Optional[MessageData]: """ ... + def save_conversation_state(self) -> str: + """ + Save the current state of the conversation. + :return: The conversation ID/ Thread ID. + """ + ... + class MessageDict(TypedDict): """ diff --git a/assistants/cli/io_loop.py b/assistants/cli/io_loop.py index e78910e..b78b9f6 100644 --- a/assistants/cli/io_loop.py +++ b/assistants/cli/io_loop.py @@ -13,7 +13,6 @@ from prompt_toolkit.styles import Style from assistants.ai.memory import MemoryMixin -from assistants.ai.openai import Assistant from assistants.ai.types import AssistantProtocol from assistants.cli import output from assistants.cli.commands import COMMAND_MAP, EXIT_COMMANDS, IoEnviron @@ -21,7 +20,6 @@ from assistants.cli.utils import highlight_code_blocks from assistants.config.file_management import CONFIG_DIR from assistants.log import logger -from assistants.user_data.sqlite_backend.threads import threads_table # Constants and Configuration @@ -121,7 +119,7 @@ async def converse( """ assistant = environ.assistant last_message = environ.last_message - thread_id = environ.thread_id + thread_id = environ.thread_id # Could be None; a new thread will be created if so. message = await assistant.converse( environ.user_input, last_message.thread_id if last_message else thread_id @@ -140,20 +138,10 @@ async def converse( output.default(text) output.new_line(2) - environ.last_message = message - if ( - environ.last_message - and not environ.thread_id - and isinstance(assistant, Assistant) - ): - environ.thread_id = environ.last_message.thread_id - await threads_table.save_thread( - environ.thread_id, assistant.assistant_id, environ.user_input - ) - elif isinstance(assistant, MemoryMixin): - await assistant.save_conversation() - environ.thread_id = assistant.conversation_id + # Set and save the new conversation state for future iterations: + environ.last_message = message + environ.thread_id = await assistant.save_conversation_state() def io_loop(