Skip to content

Commit

Permalink
Show last message in thread when loading thread
Browse files Browse the repository at this point in the history
  • Loading branch information
mikejar committed Jan 8, 2025
1 parent 6cf53b7 commit 7368ec1
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 32 deletions.
4 changes: 2 additions & 2 deletions assistants/ai/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from anthropic import AsyncAnthropic

from assistants.ai.memory import MemoryMixin
from assistants.ai.types import AssistantProtocol, MessageData
from assistants.ai.types import MessageData
from assistants.config.environment import ANTHROPIC_API_KEY
from assistants.lib.exceptions import ConfigError


class Claude(AssistantProtocol, MemoryMixin):
class Claude(MemoryMixin):
"""
Claude class encapsulates interactions with the Anthropic API.
Expand Down
18 changes: 16 additions & 2 deletions assistants/ai/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import datetime
from typing import Optional

from assistants.ai.types import MessageDict
from assistants.ai.types import MessageDict, MessageData
from assistants.user_data.sqlite_backend import conversations_table
from assistants.user_data.sqlite_backend.conversations import Conversation

Expand Down Expand Up @@ -75,4 +75,18 @@ async def save_conversation(self):
conversation=json.dumps(self.memory),
last_updated=datetime.now(),
)
)
)

def get_last_message(self, thread_id: str) -> Optional[MessageData]:
"""
Get the last message from the conversation or None if no message exists.
Conversation must have already been loaded.
:param thread_id: Not used; required by protocol
:return: MessageData with the last message and current conversation_id.
"""
if not self.memory:
return None
return MessageData(
text_content=self.memory[-1]["content"], thread_id=self.conversation_id
)
53 changes: 30 additions & 23 deletions assistants/ai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from openai.types.chat import ChatCompletionMessage

from assistants.ai.memory import MemoryMixin
from assistants.ai.types import AssistantProtocol, MessageData, MessageDict
from assistants.ai.types import MessageData, MessageDict
from assistants.config import environment
from assistants.lib.exceptions import ConfigError, NoResponseError
from assistants.log import logger
Expand All @@ -26,12 +26,11 @@
)


class Assistant(AssistantProtocol): # pylint: disable=too-many-instance-attributes
class Assistant: # pylint: disable=too-many-instance-attributes
"""
Encapsulates interactions with the OpenAI Assistants API.
Inherits from:
- AssistantProtocol: Protocol defining the interface for assistant classes.
Fits AssistantProtocol: Protocol defining the interface for assistant classes.
Attributes:
name (str): The name of the assistant.
Expand All @@ -42,7 +41,7 @@ class Assistant(AssistantProtocol): # pylint: disable=too-many-instance-attribu
client (openai.OpenAI): Client for interacting with the OpenAI API.
_config_hash (Optional[str]): Hash of the current configuration.
assistant (Optional[object]): The assistant object.
last_message_id (Optional[str]): ID of the last message in the thread.
last_message (Optional[str]): ID of the last message in the thread.
"""

def __init__( # pylint: disable=too-many-arguments
Expand Down Expand Up @@ -75,15 +74,15 @@ def __init__( # pylint: disable=too-many-arguments
self.name = name
self._config_hash = None
self.assistant = None
self.last_message_id = None
self.last_message = None

async def start(self):
"""
Load the assistant_id from DB if exists or create a new assistant.
"""
if not self.__dict__.get("assistant"):
self.assistant = await self.load_or_create_assistant()
self.last_message_id = None
self.last_message = None

def __getattribute__(self, item):
"""
Expand Down Expand Up @@ -270,6 +269,28 @@ async def image_prompt(self, prompt: str) -> Optional[str]:
)
return response.data[0].url

def get_last_message(self, thread_id: str) -> Optional[MessageData]:
messages = self.client.beta.threads.messages.list(
thread_id=thread_id,
order="asc",
after=self.last_message.id if self.last_message else NOT_GIVEN,
).data

last_message_in_thread = messages[-1]

if not last_message_in_thread:
return None

if self.last_message and last_message_in_thread.id == self.last_message.id:
raise NoResponseError

self.last_message = last_message_in_thread

return MessageData(
text_content=last_message_in_thread.content[0].text.value,
thread_id=thread_id,
)

async def converse(
self, user_input: str, thread_id: Optional[str] = None
) -> Optional[MessageData]:
Expand All @@ -289,24 +310,10 @@ async def converse(
else:
await self.prompt(user_input, thread_id)

messages = self.client.beta.threads.messages.list(
thread_id=thread_id, order="asc", after=self.last_message_id or NOT_GIVEN
).data

last_message_in_thread = messages[-1]

if last_message_in_thread.id == self.last_message_id:
raise NoResponseError

self.last_message_id = last_message_in_thread.id

return MessageData(
text_content=last_message_in_thread.content[0].text.value,
thread_id=thread_id,
)
return self.get_last_message(thread_id)


class Completion(AssistantProtocol, MemoryMixin):
class Completion(MemoryMixin):
"""
Encapsulates interactions with the OpenAI Chat Completion API.
Expand Down
9 changes: 9 additions & 0 deletions assistants/ai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ async def converse(
"""
...

def get_last_message(self, thread_id: str) -> Optional[MessageData]:
"""
Get the last message in the thread.
:param thread_id: the ID of the thread to continue.
:return: last message in the thread if one exists.
"""
...


class MessageDict(TypedDict):
"""
Expand Down
17 changes: 13 additions & 4 deletions assistants/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import pyperclip

from assistants.ai.memory import MemoryMixin
from assistants.ai.openai import Assistant
from assistants.ai.types import AssistantProtocol, MessageData
from assistants.cli import output
from assistants.cli.constants import IO_INSTRUCTIONS
from assistants.cli.selector import TerminalSelector
from assistants.cli.terminal import clear_screen
from assistants.cli.utils import get_text_from_default_editor
from assistants.cli.utils import get_text_from_default_editor, highlight_code_blocks
from assistants.user_data import threads_table
from assistants.user_data.sqlite_backend import conversations_table

Expand All @@ -23,7 +24,7 @@ class IoEnviron:
Environment variables for the input/output loop.
"""

assistant: AssistantProtocol | MemoryMixin
assistant: AssistantProtocol | MemoryMixin | Assistant
last_message: Optional[MessageData] = None
thread_id: Optional[str] = None

Expand Down Expand Up @@ -235,8 +236,16 @@ def __call__(self, environ: IoEnviron) -> None:
else:
asyncio.run(environ.assistant.start())

environ.last_message = None
output.inform(f"Selected thread {thread_id}")
output.inform(f"Selected thread '{thread_id}'")

last_message = environ.assistant.get_last_message(thread_id)
environ.last_message = last_message

if last_message:
output.default(highlight_code_blocks(last_message.text_content))
output.new_line(2)
else:
output.warn("No last message found in thread")


select_thread: Command = SelectThread()
Expand Down
2 changes: 1 addition & 1 deletion assistants/telegram_ui/tg_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def deauthorise_user(update: Update, context: ContextTypes.DEFAULT_TYPE):
@restricted_access
async def new_thread(update: Update, context: ContextTypes.DEFAULT_TYPE):
await user_data.clear_last_thread_id(update.effective_chat.id)
assistant.last_message_id = None
assistant.last_message = None
await context.bot.send_message(
update.effective_chat.id, "Conversation history cleared."
)
Expand Down

0 comments on commit 7368ec1

Please sign in to comment.