From 7bcf8df67ec7a1bbb8ed20c89d5e551ef329c632 Mon Sep 17 00:00:00 2001 From: Michael Jarvis Date: Sat, 11 Jan 2025 10:55:12 +0000 Subject: [PATCH] Refactor io_loop.py --- assistants/cli/cli.py | 2 +- assistants/cli/io_loop.py | 71 +++++++++++++++++++++++++-------------- assistants/cli/utils.py | 3 +- 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/assistants/cli/cli.py b/assistants/cli/cli.py index cc41065..16a85a5 100644 --- a/assistants/cli/cli.py +++ b/assistants/cli/cli.py @@ -21,7 +21,7 @@ def cli(): """ - Main function for the Assistant CLI. + Main function (entrypoint) for the Assistant CLI. """ # Parse command line arguments, if --help is passed, it will exit here diff --git a/assistants/cli/io_loop.py b/assistants/cli/io_loop.py index 02d7393..65028da 100644 --- a/assistants/cli/io_loop.py +++ b/assistants/cli/io_loop.py @@ -2,9 +2,10 @@ This module contains the main input/output loop for interacting with the assistant. """ import asyncio +from dataclasses import dataclass +from enum import Enum from typing import Optional -from openai.types.beta.threads import Message from prompt_toolkit import prompt from prompt_toolkit.history import FileHistory from prompt_toolkit.key_binding import KeyBindings @@ -12,7 +13,7 @@ from assistants.ai.memory import MemoryMixin from assistants.ai.openai import Assistant -from assistants.ai.types import AssistantProtocol +from assistants.ai.types import AssistantProtocol, MessageData from assistants.cli import output from assistants.cli.commands import COMMAND_MAP, EXIT_COMMANDS, IoEnviron from assistants.cli.terminal import clear_screen @@ -21,19 +22,33 @@ from assistants.log import logger from assistants.user_data.sqlite_backend.threads import save_thread_data -bindings = KeyBindings() -# Prompt history -history = FileHistory(f"{CONFIG_DIR}/history") +# Constants and Configuration +class PromptStyle(Enum): + USER_INPUT = "ansigreen" + PROMPT_SYMBOL = "ansibrightgreen" + + +INPUT_CLASSNAME = "input" + + +@dataclass +class PromptConfig: + style: Style = Style.from_dict( + { + "": PromptStyle.USER_INPUT.value, + INPUT_CLASSNAME: PromptStyle.PROMPT_SYMBOL.value, + } + ) + prompt_symbol: str = ">>>" + history_file: str = f"{CONFIG_DIR}/history" + -# Styling for the prompt_toolkit prompt -style = Style.from_dict( - { - "": "ansigreen", # green user input - "input": "ansibrightgreen", # bright green prompt symbol - }, -) -PROMPT = [("class:input", ">>> ")] # prompt symbol +# Setup +bindings = KeyBindings() +config = PromptConfig() +history = FileHistory(config.history_file) +PROMPT = [(f"class:{INPUT_CLASSNAME}", f"{config.prompt_symbol} ")] # Bind CTRL+L to clear the screen @@ -42,10 +57,15 @@ def _(_event): clear_screen() +def get_user_input() -> str: + """Get user input from interactive/styled prompt (prompt_toolkit).""" + return prompt(PROMPT, style=config.style, history=history) + + async def io_loop_async( assistant: AssistantProtocol | MemoryMixin, initial_input: str = "", - last_message: Optional[Message] = None, + last_message: Optional[MessageData] = None, thread_id: Optional[str] = None, ): """ @@ -56,16 +76,6 @@ async def io_loop_async( :param last_message: The last message in the conversation thread. :param thread_id: The ID of the conversation thread. """ - user_input = "" - - def get_user_input() -> str: - """ - Get user input from the prompt. - - :return: The user input as a string. - """ - return prompt(PROMPT, style=style, history=history) - environ = IoEnviron( assistant=assistant, last_message=last_message, @@ -99,7 +109,7 @@ def get_user_input() -> str: continue environ.user_input = user_input - asyncio.run(converse(environ)) + await converse(environ) async def converse( @@ -143,6 +153,15 @@ async def converse( await save_thread_data( environ.thread_id, assistant.assistant_id, environ.user_input ) - elif not isinstance(assistant, Assistant): + elif isinstance(assistant, MemoryMixin): await assistant.save_conversation() environ.thread_id = assistant.conversation_id + + +def io_loop( + assistant: AssistantProtocol | MemoryMixin, + initial_input: str = "", + last_message: Optional[MessageData] = None, + thread_id: Optional[str] = None, +): + asyncio.run(io_loop_async(assistant, initial_input, last_message, thread_id)) diff --git a/assistants/cli/utils.py b/assistants/cli/utils.py index 916c28b..99a6c5d 100644 --- a/assistants/cli/utils.py +++ b/assistants/cli/utils.py @@ -12,6 +12,7 @@ from assistants.ai.anthropic import Claude from assistants.ai.openai import Assistant, Completion +from assistants.ai.types import AssistantProtocol from assistants.config import environment from assistants.lib.exceptions import ConfigError from assistants.user_data.sqlite_backend.threads import ( @@ -71,7 +72,7 @@ def get_text_from_default_editor(initial_text=None): async def create_assistant_and_thread( args: Namespace, -) -> tuple[Assistant, Optional[ThreadData]]: +) -> tuple[AssistantProtocol, Optional[ThreadData]]: if args.code: if environment.CODE_MODEL == "o1-mini": # Create a completion model for code reasoning (slower and more expensive)