Skip to content

Commit

Permalink
Single save conversation API between assistant types
Browse files Browse the repository at this point in the history
  • Loading branch information
mikejar committed Jan 12, 2025
1 parent 9cfedcd commit 127b506
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
4 changes: 3 additions & 1 deletion assistants/ai/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ async def load_conversation(self, conversation_id: Optional[str] = None):
self.memory = json.loads(conversation.conversation) if conversation else []
self.conversation_id = conversation.id if conversation else uuid.uuid4().hex

async def save_conversation(self):
async def save_conversation_state(self) -> str:
"""
Save the current conversation to the database.
:return: The conversation ID.
"""
if not self.memory:
return
Expand All @@ -76,6 +77,7 @@ async def save_conversation(self):
last_updated=datetime.now(),
)
)
return self.conversation_id

def get_last_message(self, thread_id: str) -> Optional[MessageData]:
"""
Expand Down
14 changes: 14 additions & 0 deletions assistants/ai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from assistants.config import environment
from assistants.lib.exceptions import ConfigError, NoResponseError
from assistants.log import logger
from assistants.user_data import threads_table
from assistants.user_data.sqlite_backend.assistants import (
get_assistant_data,
save_assistant_id,
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__( # pylint: disable=too-many-arguments
self._config_hash = None
self.assistant = None
self.last_message = None
self.last_prompt = None

async def start(self):
"""
Expand Down Expand Up @@ -238,6 +240,7 @@ async def prompt(self, prompt: str, thread_id: Optional[str] = None) -> Run:
:param thread_id: Optional ID of the thread to continue.
:return: The run object.
"""
self.last_prompt = prompt
if thread_id is None:
thread = self.start_thread(prompt)
run = self.run_thread(thread)
Expand All @@ -261,6 +264,7 @@ async def image_prompt(self, prompt: str) -> Optional[str]:
:param prompt: The image prompt.
:return: The URL of the generated image.
"""
self.last_prompt = prompt
response = self.client.images.generate(
model=environment.IMAGE_MODEL,
prompt=prompt,
Expand Down Expand Up @@ -313,6 +317,16 @@ async def converse(

return self.get_last_message(thread_id)

async def save_conversation_state(self) -> str:
"""
Save the state of the conversation.
:return: The thread ID of the conversation.
"""
await threads_table.save_thread(
self.last_message.thread_id, self.assistant_id, self.last_prompt
)
return self.last_message.thread_id


class Completion(MemoryMixin):
"""
Expand Down
7 changes: 7 additions & 0 deletions assistants/ai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def get_last_message(self, thread_id: str) -> Optional[MessageData]:
"""
...

def save_conversation_state(self) -> str:
"""
Save the current state of the conversation.
:return: The conversation ID/ Thread ID.
"""
...


class MessageDict(TypedDict):
"""
Expand Down
20 changes: 4 additions & 16 deletions assistants/cli/io_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
from prompt_toolkit.styles import Style

from assistants.ai.memory import MemoryMixin
from assistants.ai.openai import Assistant
from assistants.ai.types import AssistantProtocol
from assistants.cli import output
from assistants.cli.commands import COMMAND_MAP, EXIT_COMMANDS, IoEnviron
from assistants.cli.terminal import clear_screen
from assistants.cli.utils import highlight_code_blocks
from assistants.config.file_management import CONFIG_DIR
from assistants.log import logger
from assistants.user_data.sqlite_backend.threads import threads_table


# Constants and Configuration
Expand Down Expand Up @@ -121,7 +119,7 @@ async def converse(
"""
assistant = environ.assistant
last_message = environ.last_message
thread_id = environ.thread_id
thread_id = environ.thread_id # Could be None; a new thread will be created if so.

message = await assistant.converse(
environ.user_input, last_message.thread_id if last_message else thread_id
Expand All @@ -140,20 +138,10 @@ async def converse(

output.default(text)
output.new_line(2)
environ.last_message = message

if (
environ.last_message
and not environ.thread_id
and isinstance(assistant, Assistant)
):
environ.thread_id = environ.last_message.thread_id
await threads_table.save_thread(
environ.thread_id, assistant.assistant_id, environ.user_input
)
elif isinstance(assistant, MemoryMixin):
await assistant.save_conversation()
environ.thread_id = assistant.conversation_id
# Set and save the new conversation state for future iterations:
environ.last_message = message
environ.thread_id = await assistant.save_conversation_state()


def io_loop(
Expand Down

0 comments on commit 127b506

Please sign in to comment.