Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

refactor commands #320

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mentat/command/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Make sure commands are initialized

from . import commands # noqa: F401 # type: ignore
80 changes: 80 additions & 0 deletions mentat/command/command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from mentat.errors import MentatError
from mentat.session_context import SESSION_CONTEXT


class Command(ABC):
"""
Base Command class. To create a new command, extend this class, provide a command_name,
and import the class in commands.__init__.py so that it is initialized on startup.
"""

# Unfortunately, Command isn't defined here yet, so even with annotations we need quotation marks
_registered_commands = dict[str, type["Command"]]()
hidden = False

def __init_subclass__(cls, command_name: str | None) -> None:
if command_name is not None:
Command._registered_commands[command_name] = cls

@classmethod
def create_command(cls, command_name: str) -> Command:
if command_name not in cls._registered_commands:
return InvalidCommand(command_name)

command_cls = cls._registered_commands[command_name]
return command_cls()

@classmethod
def get_command_names(cls) -> list[str]:
return [
name
for name, command in cls._registered_commands.items()
if not command.hidden
]

@classmethod
def get_command_completions(cls) -> List[str]:
return list(map(lambda name: "/" + name, cls.get_command_names()))

@abstractmethod
async def apply(self, *args: str) -> None:
pass

# TODO: make more robust way to specify arguments for commands
@classmethod
@abstractmethod
def argument_names(cls) -> list[str]:
pass

@classmethod
@abstractmethod
def help_message(cls) -> str:
pass


class InvalidCommand(Command, command_name=None):
def __init__(self, invalid_name: str):
self.invalid_name = invalid_name

async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream

stream.send(
f"{self.invalid_name} is not a valid command. Use /help to see a list of"
" all valid commands",
color="light_yellow",
)

@classmethod
def argument_names(cls) -> list[str]:
raise MentatError("Argument names called on invalid command")

@classmethod
def help_message(cls) -> str:
raise MentatError("Help message called on invalid command")
17 changes: 17 additions & 0 deletions mentat/command/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ruff: noqa: F401
# type: ignore

# Import all of the commands so that they are initialized
from .clear import ClearCommand
from .commit import CommitCommand
from .config import ConfigCommand
from .context import ContextCommand
from .conversation import ConversationCommand
from .exclude import ExcludeCommand
from .help import HelpCommand
from .include import IncludeCommand
from .run import RunCommand
from .screenshot import ScreenshotCommand
from .search import SearchCommand
from .undo import UndoCommand
from .undoall import UndoAllCommand
21 changes: 21 additions & 0 deletions mentat/command/commands/clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from mentat.command.command import Command
from mentat.session_context import SESSION_CONTEXT


class ClearCommand(Command, command_name="clear"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
conversation = session_context.conversation

conversation.clear_messages()
message = "Message history cleared"
stream.send(message, color="green")

@classmethod
def argument_names(cls) -> list[str]:
return []

@classmethod
def help_message(cls) -> str:
return "Clear the current conversation's message history"
20 changes: 20 additions & 0 deletions mentat/command/commands/commit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from mentat.command.command import Command
from mentat.git_handler import commit


class CommitCommand(Command, command_name="commit"):
default_message = "Automatic commit"

async def apply(self, *args: str) -> None:
if args:
commit(args[0])
else:
commit(self.__class__.default_message)

@classmethod
def argument_names(cls) -> list[str]:
return [f"commit_message={cls.default_message}"]

@classmethod
def help_message(cls) -> str:
return "Commits all of your unstaged and staged changes to git"
54 changes: 54 additions & 0 deletions mentat/command/commands/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import attr

from mentat.command.command import Command
from mentat.session_context import SESSION_CONTEXT


class ConfigCommand(Command, command_name="config"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
config = session_context.config
if len(args) == 0:
stream.send("No config option specified", color="yellow")
else:
setting = args[0]
if hasattr(config, setting):
if len(args) == 1:
value = getattr(config, setting)
description = attr.fields_dict(type(config))[setting].metadata.get(
"description"
)
stream.send(f"{setting}: {value}")
if description:
stream.send(f"Description: {description}")
elif len(args) == 2:
value = args[1]
if attr.fields_dict(type(config))[setting].metadata.get(
"no_midsession_change"
):
stream.send(
f"Cannot change {setting} mid-session. Please restart"
" Mentat to change this setting.",
color="yellow",
)
return
try:
setattr(config, setting, value)
stream.send(f"{setting} set to {value}", color="green")
except (TypeError, ValueError):
stream.send(
f"Illegal value for {setting}: {value}", color="red"
)
else:
stream.send("Too many arguments", color="yellow")
else:
stream.send(f"Unrecognized config option: {setting}", color="red")

@classmethod
def argument_names(cls) -> list[str]:
return ["setting", "value"]

@classmethod
def help_message(cls) -> str:
return "Set a configuration option or omit value to see current value."
18 changes: 18 additions & 0 deletions mentat/command/commands/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from mentat.command.command import Command
from mentat.session_context import SESSION_CONTEXT


class ContextCommand(Command, command_name="context"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
code_context = session_context.code_context

code_context.display_context()

@classmethod
def argument_names(cls) -> list[str]:
return []

@classmethod
def help_message(cls) -> str:
return "Shows all files currently in Mentat's context"
30 changes: 30 additions & 0 deletions mentat/command/commands/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import webbrowser

from mentat.command.command import Command
from mentat.session_context import SESSION_CONTEXT
from mentat.transcripts import Transcript, get_transcript_logs
from mentat.utils import create_viewer


class ConversationCommand(Command, command_name="conversation"):
hidden = True

async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
conversation = session_context.conversation

logs = get_transcript_logs()

viewer_path = create_viewer(
[Transcript(timestamp="Current", messages=conversation.literal_messages)]
+ logs
)
webbrowser.open(f"file://{viewer_path.resolve()}")

@classmethod
def argument_names(cls) -> list[str]:
return []

@classmethod
def help_message(cls) -> str:
return "Opens an html page showing the conversation as seen by Mentat so far"
34 changes: 34 additions & 0 deletions mentat/command/commands/exclude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from pathlib import Path

from mentat.command.command import Command
from mentat.include_files import print_invalid_path
from mentat.session_context import SESSION_CONTEXT


class ExcludeCommand(Command, command_name="exclude"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
git_root = session_context.git_root

if len(args) == 0:
stream.send("No files specified", color="yellow")
return
for file_path in args:
excluded_paths, invalid_paths = code_context.exclude_file(
Path(file_path).absolute()
)
for invalid_path in invalid_paths:
print_invalid_path(invalid_path)
for excluded_path in excluded_paths:
rel_path = excluded_path.relative_to(git_root)
stream.send(f"{rel_path} removed from context", color="red")

@classmethod
def argument_names(cls) -> list[str]:
return ["file1", "file2", "..."]

@classmethod
def help_message(cls) -> str:
return "Remove files from the code context"
39 changes: 39 additions & 0 deletions mentat/command/commands/help.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from mentat.command.command import Command
from mentat.session_context import SESSION_CONTEXT

help_message_width = 60


class HelpCommand(Command, command_name="help"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream

if not args:
commands = Command.get_command_names()
else:
commands = args
for command_name in commands:
if command_name not in Command._registered_commands:
stream.send(
f"Error: Command {command_name} does not exist.", color="red"
)
else:
command_class = Command._registered_commands[command_name]
argument_names = command_class.argument_names()
help_message = command_class.help_message()
message = (
" ".join(
[f"/{command_name}"] + [f"<{arg}>" for arg in argument_names]
).ljust(help_message_width)
+ help_message
)
stream.send(message)

@classmethod
def argument_names(cls) -> list[str]:
return ["command"]

@classmethod
def help_message(cls) -> str:
return "Displays this message"
34 changes: 34 additions & 0 deletions mentat/command/commands/include.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from pathlib import Path

from mentat.command.command import Command
from mentat.include_files import print_invalid_path
from mentat.session_context import SESSION_CONTEXT


class IncludeCommand(Command, command_name="include"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
git_root = session_context.git_root

if len(args) == 0:
stream.send("No files specified", color="yellow")
return
for file_path in args:
included_paths, invalid_paths = code_context.include_file(
Path(file_path).absolute()
)
for invalid_path in invalid_paths:
print_invalid_path(invalid_path)
for included_path in included_paths:
rel_path = included_path.relative_to(git_root)
stream.send(f"{rel_path} added to context", color="green")

@classmethod
def argument_names(cls) -> list[str]:
return ["file1", "file2", "..."]

@classmethod
def help_message(cls) -> str:
return "Add files to the code context"
17 changes: 17 additions & 0 deletions mentat/command/commands/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from mentat.command.command import Command
from mentat.session_context import SESSION_CONTEXT


class RunCommand(Command, command_name="run"):
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
conversation = session_context.conversation
await conversation.run_command(list(args))

@classmethod
def argument_names(cls) -> list[str]:
return ["command", "args..."]

@classmethod
def help_message(cls) -> str:
return "Run a shell command and put its output in context."
Loading