This repository has been archived by the owner on Jan 7, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 243
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
515 additions
and
459 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Make sure commands are initialized | ||
|
||
from . import commands # noqa: F401 # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from mentat.errors import MentatError | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class Command(ABC): | ||
""" | ||
Base Command class. To create a new command, extend this class, provide a command_name, | ||
and import the class in commands.__init__.py so that it is initialized on startup. | ||
""" | ||
|
||
# Unfortunately, Command isn't defined here yet, so even with annotations we need quotation marks | ||
_registered_commands = dict[str, type["Command"]]() | ||
hidden = False | ||
|
||
def __init_subclass__(cls, command_name: str | None) -> None: | ||
if command_name is not None: | ||
Command._registered_commands[command_name] = cls | ||
|
||
@classmethod | ||
def create_command(cls, command_name: str) -> Command: | ||
if command_name not in cls._registered_commands: | ||
return InvalidCommand(command_name) | ||
|
||
command_cls = cls._registered_commands[command_name] | ||
return command_cls() | ||
|
||
@classmethod | ||
def get_command_names(cls) -> list[str]: | ||
return [ | ||
name | ||
for name, command in cls._registered_commands.items() | ||
if not command.hidden | ||
] | ||
|
||
@classmethod | ||
def get_command_completions(cls) -> List[str]: | ||
return list(map(lambda name: "/" + name, cls.get_command_names())) | ||
|
||
@abstractmethod | ||
async def apply(self, *args: str) -> None: | ||
pass | ||
|
||
# TODO: make more robust way to specify arguments for commands | ||
@classmethod | ||
@abstractmethod | ||
def argument_names(cls) -> list[str]: | ||
pass | ||
|
||
@classmethod | ||
@abstractmethod | ||
def help_message(cls) -> str: | ||
pass | ||
|
||
|
||
class InvalidCommand(Command, command_name=None): | ||
def __init__(self, invalid_name: str): | ||
self.invalid_name = invalid_name | ||
|
||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
stream = session_context.stream | ||
|
||
stream.send( | ||
f"{self.invalid_name} is not a valid command. Use /help to see a list of" | ||
" all valid commands", | ||
color="light_yellow", | ||
) | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
raise MentatError("Argument names called on invalid command") | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
raise MentatError("Help message called on invalid command") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# ruff: noqa: F401 | ||
# type: ignore | ||
|
||
# Import all of the commands so that they are initialized | ||
from .clear import ClearCommand | ||
from .commit import CommitCommand | ||
from .config import ConfigCommand | ||
from .context import ContextCommand | ||
from .conversation import ConversationCommand | ||
from .exclude import ExcludeCommand | ||
from .help import HelpCommand | ||
from .include import IncludeCommand | ||
from .run import RunCommand | ||
from .screenshot import ScreenshotCommand | ||
from .search import SearchCommand | ||
from .undo import UndoCommand | ||
from .undoall import UndoAllCommand |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from mentat.command.command import Command | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class ClearCommand(Command, command_name="clear"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
stream = session_context.stream | ||
conversation = session_context.conversation | ||
|
||
conversation.clear_messages() | ||
message = "Message history cleared" | ||
stream.send(message, color="green") | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return [] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Clear the current conversation's message history" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from mentat.command.command import Command | ||
from mentat.git_handler import commit | ||
|
||
|
||
class CommitCommand(Command, command_name="commit"): | ||
default_message = "Automatic commit" | ||
|
||
async def apply(self, *args: str) -> None: | ||
if args: | ||
commit(args[0]) | ||
else: | ||
commit(self.__class__.default_message) | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return [f"commit_message={cls.default_message}"] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Commits all of your unstaged and staged changes to git" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import attr | ||
|
||
from mentat.command.command import Command | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class ConfigCommand(Command, command_name="config"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
stream = session_context.stream | ||
config = session_context.config | ||
if len(args) == 0: | ||
stream.send("No config option specified", color="yellow") | ||
else: | ||
setting = args[0] | ||
if hasattr(config, setting): | ||
if len(args) == 1: | ||
value = getattr(config, setting) | ||
description = attr.fields_dict(type(config))[setting].metadata.get( | ||
"description" | ||
) | ||
stream.send(f"{setting}: {value}") | ||
if description: | ||
stream.send(f"Description: {description}") | ||
elif len(args) == 2: | ||
value = args[1] | ||
if attr.fields_dict(type(config))[setting].metadata.get( | ||
"no_midsession_change" | ||
): | ||
stream.send( | ||
f"Cannot change {setting} mid-session. Please restart" | ||
" Mentat to change this setting.", | ||
color="yellow", | ||
) | ||
return | ||
try: | ||
setattr(config, setting, value) | ||
stream.send(f"{setting} set to {value}", color="green") | ||
except (TypeError, ValueError): | ||
stream.send( | ||
f"Illegal value for {setting}: {value}", color="red" | ||
) | ||
else: | ||
stream.send("Too many arguments", color="yellow") | ||
else: | ||
stream.send(f"Unrecognized config option: {setting}", color="red") | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return ["setting", "value"] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Set a configuration option or omit value to see current value." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from mentat.command.command import Command | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class ContextCommand(Command, command_name="context"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
code_context = session_context.code_context | ||
|
||
code_context.display_context() | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return [] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Shows all files currently in Mentat's context" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import webbrowser | ||
|
||
from mentat.command.command import Command | ||
from mentat.session_context import SESSION_CONTEXT | ||
from mentat.transcripts import Transcript, get_transcript_logs | ||
from mentat.utils import create_viewer | ||
|
||
|
||
class ConversationCommand(Command, command_name="conversation"): | ||
hidden = True | ||
|
||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
conversation = session_context.conversation | ||
|
||
logs = get_transcript_logs() | ||
|
||
viewer_path = create_viewer( | ||
[Transcript(timestamp="Current", messages=conversation.literal_messages)] | ||
+ logs | ||
) | ||
webbrowser.open(f"file://{viewer_path.resolve()}") | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return [] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Opens an html page showing the conversation as seen by Mentat so far" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from pathlib import Path | ||
|
||
from mentat.command.command import Command | ||
from mentat.include_files import print_invalid_path | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class ExcludeCommand(Command, command_name="exclude"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
stream = session_context.stream | ||
code_context = session_context.code_context | ||
git_root = session_context.git_root | ||
|
||
if len(args) == 0: | ||
stream.send("No files specified", color="yellow") | ||
return | ||
for file_path in args: | ||
excluded_paths, invalid_paths = code_context.exclude_file( | ||
Path(file_path).absolute() | ||
) | ||
for invalid_path in invalid_paths: | ||
print_invalid_path(invalid_path) | ||
for excluded_path in excluded_paths: | ||
rel_path = excluded_path.relative_to(git_root) | ||
stream.send(f"{rel_path} removed from context", color="red") | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return ["file1", "file2", "..."] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Remove files from the code context" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from mentat.command.command import Command | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
help_message_width = 60 | ||
|
||
|
||
class HelpCommand(Command, command_name="help"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
stream = session_context.stream | ||
|
||
if not args: | ||
commands = Command.get_command_names() | ||
else: | ||
commands = args | ||
for command_name in commands: | ||
if command_name not in Command._registered_commands: | ||
stream.send( | ||
f"Error: Command {command_name} does not exist.", color="red" | ||
) | ||
else: | ||
command_class = Command._registered_commands[command_name] | ||
argument_names = command_class.argument_names() | ||
help_message = command_class.help_message() | ||
message = ( | ||
" ".join( | ||
[f"/{command_name}"] + [f"<{arg}>" for arg in argument_names] | ||
).ljust(help_message_width) | ||
+ help_message | ||
) | ||
stream.send(message) | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return ["command"] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Displays this message" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from pathlib import Path | ||
|
||
from mentat.command.command import Command | ||
from mentat.include_files import print_invalid_path | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class IncludeCommand(Command, command_name="include"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
stream = session_context.stream | ||
code_context = session_context.code_context | ||
git_root = session_context.git_root | ||
|
||
if len(args) == 0: | ||
stream.send("No files specified", color="yellow") | ||
return | ||
for file_path in args: | ||
included_paths, invalid_paths = code_context.include_file( | ||
Path(file_path).absolute() | ||
) | ||
for invalid_path in invalid_paths: | ||
print_invalid_path(invalid_path) | ||
for included_path in included_paths: | ||
rel_path = included_path.relative_to(git_root) | ||
stream.send(f"{rel_path} added to context", color="green") | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return ["file1", "file2", "..."] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Add files to the code context" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from mentat.command.command import Command | ||
from mentat.session_context import SESSION_CONTEXT | ||
|
||
|
||
class RunCommand(Command, command_name="run"): | ||
async def apply(self, *args: str) -> None: | ||
session_context = SESSION_CONTEXT.get() | ||
conversation = session_context.conversation | ||
await conversation.run_command(list(args)) | ||
|
||
@classmethod | ||
def argument_names(cls) -> list[str]: | ||
return ["command", "args..."] | ||
|
||
@classmethod | ||
def help_message(cls) -> str: | ||
return "Run a shell command and put its output in context." |
Oops, something went wrong.