Skip to content

Commit

Permalink
Refactor io_loop.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mikejar committed Jan 11, 2025
1 parent 6a1c3ee commit 7bcf8df
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 28 deletions.
2 changes: 1 addition & 1 deletion assistants/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def cli():
"""
Main function for the Assistant CLI.
Main function (entrypoint) for the Assistant CLI.
"""

# Parse command line arguments, if --help is passed, it will exit here
Expand Down
71 changes: 45 additions & 26 deletions assistants/cli/io_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
This module contains the main input/output loop for interacting with the assistant.
"""
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from openai.types.beta.threads import Message
from prompt_toolkit import prompt
from prompt_toolkit.history import FileHistory
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.styles import Style

from assistants.ai.memory import MemoryMixin
from assistants.ai.openai import Assistant
from assistants.ai.types import AssistantProtocol
from assistants.ai.types import AssistantProtocol, MessageData
from assistants.cli import output
from assistants.cli.commands import COMMAND_MAP, EXIT_COMMANDS, IoEnviron
from assistants.cli.terminal import clear_screen
Expand All @@ -21,19 +22,33 @@
from assistants.log import logger
from assistants.user_data.sqlite_backend.threads import save_thread_data

bindings = KeyBindings()

# Prompt history
history = FileHistory(f"{CONFIG_DIR}/history")
# Constants and Configuration
class PromptStyle(Enum):
USER_INPUT = "ansigreen"
PROMPT_SYMBOL = "ansibrightgreen"


INPUT_CLASSNAME = "input"


@dataclass
class PromptConfig:
style: Style = Style.from_dict(
{
"": PromptStyle.USER_INPUT.value,
INPUT_CLASSNAME: PromptStyle.PROMPT_SYMBOL.value,
}
)
prompt_symbol: str = ">>>"
history_file: str = f"{CONFIG_DIR}/history"


# Styling for the prompt_toolkit prompt
style = Style.from_dict(
{
"": "ansigreen", # green user input
"input": "ansibrightgreen", # bright green prompt symbol
},
)
PROMPT = [("class:input", ">>> ")] # prompt symbol
# Setup
bindings = KeyBindings()
config = PromptConfig()
history = FileHistory(config.history_file)
PROMPT = [(f"class:{INPUT_CLASSNAME}", f"{config.prompt_symbol} ")]


# Bind CTRL+L to clear the screen
Expand All @@ -42,10 +57,15 @@ def _(_event):
clear_screen()


def get_user_input() -> str:
"""Get user input from interactive/styled prompt (prompt_toolkit)."""
return prompt(PROMPT, style=config.style, history=history)


async def io_loop_async(
assistant: AssistantProtocol | MemoryMixin,
initial_input: str = "",
last_message: Optional[Message] = None,
last_message: Optional[MessageData] = None,
thread_id: Optional[str] = None,
):
"""
Expand All @@ -56,16 +76,6 @@ async def io_loop_async(
:param last_message: The last message in the conversation thread.
:param thread_id: The ID of the conversation thread.
"""
user_input = ""

def get_user_input() -> str:
"""
Get user input from the prompt.
:return: The user input as a string.
"""
return prompt(PROMPT, style=style, history=history)

environ = IoEnviron(
assistant=assistant,
last_message=last_message,
Expand Down Expand Up @@ -99,7 +109,7 @@ def get_user_input() -> str:
continue

environ.user_input = user_input
asyncio.run(converse(environ))
await converse(environ)


async def converse(
Expand Down Expand Up @@ -143,6 +153,15 @@ async def converse(
await save_thread_data(
environ.thread_id, assistant.assistant_id, environ.user_input
)
elif not isinstance(assistant, Assistant):
elif isinstance(assistant, MemoryMixin):
await assistant.save_conversation()
environ.thread_id = assistant.conversation_id


def io_loop(
assistant: AssistantProtocol | MemoryMixin,
initial_input: str = "",
last_message: Optional[MessageData] = None,
thread_id: Optional[str] = None,
):
asyncio.run(io_loop_async(assistant, initial_input, last_message, thread_id))
3 changes: 2 additions & 1 deletion assistants/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from assistants.ai.anthropic import Claude
from assistants.ai.openai import Assistant, Completion
from assistants.ai.types import AssistantProtocol
from assistants.config import environment
from assistants.lib.exceptions import ConfigError
from assistants.user_data.sqlite_backend.threads import (
Expand Down Expand Up @@ -71,7 +72,7 @@ def get_text_from_default_editor(initial_text=None):

async def create_assistant_and_thread(
args: Namespace,
) -> tuple[Assistant, Optional[ThreadData]]:
) -> tuple[AssistantProtocol, Optional[ThreadData]]:
if args.code:
if environment.CODE_MODEL == "o1-mini":
# Create a completion model for code reasoning (slower and more expensive)
Expand Down

0 comments on commit 7bcf8df

Please sign in to comment.