Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
add revisor using git diff (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
PCSwingle authored Jan 16, 2024
1 parent 555e277 commit f02559a
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 14 deletions.
11 changes: 11 additions & 0 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ class Config:
},
converter=converters.optional(converters.to_bool),
)
revisor: bool = attr.field(
default=False,
metadata={
"description": (
"Enables or disables a revisor tweaking model edits after they're made."
" The revisor will use the same model regular edits do."
),
"auto_completions": bool_autocomplete,
},
converter=converters.optional(converters.to_bool),
)

# Context specific settings
file_exclude_glob_list: list[str] = attr.field(
Expand Down
8 changes: 4 additions & 4 deletions mentat/diff_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import attr

from mentat.errors import UserError
from mentat.errors import MentatError
from mentat.git_handler import (
check_head_exists,
get_diff_for_file,
Expand Down Expand Up @@ -34,8 +34,8 @@ def parse_diff(diff: str) -> list[DiffAnnotation]:
annotations: list[DiffAnnotation] = []
active_annotation: Optional[DiffAnnotation] = None
lines = diff.splitlines()
for line in lines[4:]: # Ignore header
if line.startswith(("---", "+++", "//")):
for line in lines:
if line.startswith(("---", "+++", "//", "diff", "index")):
continue
elif line.startswith("@@"):
if active_annotation:
Expand All @@ -48,7 +48,7 @@ def parse_diff(diff: str) -> list[DiffAnnotation]:
active_annotation = DiffAnnotation(start=int(new_start), message=[])
elif line.startswith(("+", "-")):
if not active_annotation:
raise UserError("Invalid diff")
raise MentatError("Invalid diff")
active_annotation.message.append(line)
if active_annotation:
annotations.append(active_annotation)
Expand Down
12 changes: 6 additions & 6 deletions mentat/parsers/change_display_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
change_delimiter = 60 * "="


def _get_lexer(file_path: Path):
def get_lexer(file_path: Path):
try:
lexer: Lexer = get_lexer_for_filename(file_path)
except ClassNotFound:
Expand Down Expand Up @@ -64,7 +64,7 @@ def __attrs_post_init__(self):
ctx = SESSION_CONTEXT.get()

self.line_number_buffer = get_line_number_buffer(self.file_lines)
self.lexer = _get_lexer(self.file_name)
self.lexer = get_lexer(self.file_name)

if self.file_name.is_absolute():
self.file_name = get_relative_path(self.file_name, ctx.cwd)
Expand Down Expand Up @@ -191,9 +191,9 @@ def get_removed_lines(
)


def highlight_text(display_information: DisplayInformation, text: str) -> str:
def highlight_text(text: str, lexer: Lexer) -> str:
# pygments doesn't have type hints on TerminalFormatter
return highlight(text, display_information.lexer, TerminalFormatter(bg="dark")) # type: ignore
return highlight(text, lexer, TerminalFormatter(bg="dark")) # type: ignore


def get_previous_lines(
Expand Down Expand Up @@ -223,7 +223,7 @@ def get_previous_lines(
]

prev = "\n".join(numbered)
return highlight_text(display_information, prev)
return highlight_text(prev, display_information.lexer)


def get_later_lines(
Expand Down Expand Up @@ -253,4 +253,4 @@ def get_later_lines(
]

later = "\n".join(numbered)
return highlight_text(display_information, later)
return highlight_text(later, display_information.lexer)
2 changes: 1 addition & 1 deletion mentat/parsers/file_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ async def filter_replacements(

if not self.is_creation:
new_replacements = list[Replacement]()
for replacement in self.replacements:
for replacement in sorted(self.replacements):
self._display_replacement(replacement, file_lines)
if await _ask_user_change("Keep this change?"):
new_replacements.append(replacement)
Expand Down
2 changes: 1 addition & 1 deletion mentat/parsers/unified_diff_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _code_line_content(
elif cur_line.startswith("-"):
return colored(content, "red")
else:
return highlight_text(display_information, content)
return highlight_text(content, display_information.lexer)

@override
def _could_be_special(self, cur_line: str) -> bool:
Expand Down
32 changes: 32 additions & 0 deletions mentat/resources/prompts/revisor_prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
You are part of an automated coding system. Your responses must follow the required format so they can be parsed programmatically.
You will be given a unified diff of a recent change made to a code file. Your job is to determine if the change made is syntactically correct, and if it is not, to modify the diff so that it is.
If you are not changing the diff, output an exact copy of the git diff! Do not output anything besides the modified diff or your output will not be parsed correctly!
Additionally, you will be provided with a variety of code files relevant to the diff, as well as the user request that this diff addresses.
Do **NOT** wrap your response in a ```diff tag or it will not be parsed correctly!!!

Example Input:

Code Files:

hello_world.py
1:def hello_world():
2: pass

User Request:
Implement the hello_world function.

Diff:
---
+++
@@ -1,4 +1,4 @@
def hello_world():
- pass
+ print("Hello, World!

Example Output:
---
+++
@@ -1,4 +1,4 @@
def hello_world():
- pass
+ print("Hello, World!")
4 changes: 4 additions & 0 deletions mentat/resources/templates/css/transcript.css
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ pre {
background-color: rgb(255, 197, 197);
}

/* Pale yellow background for elements with the "revisor" class
   (presumably revisor messages in the transcript view; confirm against the template). */
.revisor {
background-color: rgb(243, 246, 196);
}

.button-group {
position: absolute;
top: 20px;
Expand Down
Empty file added mentat/revisor/__init__.py
Empty file.
147 changes: 147 additions & 0 deletions mentat/revisor/revisor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import difflib
from pathlib import Path
from typing import List

from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
)
from termcolor import colored

from mentat.errors import MentatError
from mentat.llm_api_handler import prompt_tokens
from mentat.parsers.change_display_helper import (
change_delimiter,
get_lexer,
highlight_text,
)
from mentat.parsers.file_edit import FileEdit
from mentat.parsers.git_parser import GitParser
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT
from mentat.transcripts import ModelMessage
from mentat.utils import get_relative_path

# File name of the revisor prompt, passed to read_prompt below.
revisor_prompt_filename = Path("revisor_prompt.txt")
# Prompt text obtained once at import time; revise_edit sends it as the first system message.
revisor_prompt = read_prompt(revisor_prompt_filename)


def _get_stored_lines(file_edit: FileEdit) -> List[str]:
    """Return a copy of the lines currently stored for the edited file.

    Args:
        file_edit: The edit whose target file should be looked up.

    Returns:
        An empty list when the edit creates the file; otherwise a copy of the
        entry for ``file_edit.file_path`` in the session's code file manager.
    """
    session = SESSION_CONTEXT.get()

    if not file_edit.is_creation:
        # Copy so callers can modify the result without touching the stored lines.
        return session.code_file_manager.file_lines[file_edit.file_path].copy()
    # A file that is being created has no stored contents yet.
    return []


def _file_edit_diff(file_edit: FileEdit) -> str:
    """Build a unified diff between the stored file and the file with the edit applied.

    Args:
        file_edit: The edit to describe.

    Returns:
        The unified diff as one newline-joined string. No file names are
        passed to difflib, so the ``---``/``+++`` header lines carry none.
    """
    before = _get_stored_lines(file_edit)
    after = file_edit.get_updated_file_lines(before)
    # lineterm="" keeps difflib from appending newlines to its control lines;
    # the lines are joined with "\n" here instead.
    return "\n".join(difflib.unified_diff(before, after, lineterm=""))


def _format_revision_diff(
    file_edit: FileEdit, pre_lines: List[str], post_lines: List[str]
) -> List[str]:
    """Render the change made by the revisor as colored display lines.

    Args:
        file_edit: The edit being revised; supplies the file names shown in
            the diff header and the path used to pick a lexer.
        pre_lines: File contents with the edit applied before revision.
        post_lines: File contents with the edit applied after revision.

    Returns:
        One display string per unified-diff line; empty when the revision
        changed nothing.

    Raises:
        MentatError: If a diff line has none of the expected prefixes.
    """
    new_name = (
        file_edit.rename_file_path
        if file_edit.rename_file_path is not None
        else file_edit.file_path
    )
    lexer = get_lexer(file_edit.file_path)

    rendered: List[str] = []
    # The "---"/"+++" file headers only occur before the first hunk. Inside a
    # hunk, a removed line whose text starts with "--" also begins with "---"
    # (likewise "+++" for added "++..." lines) and must be colored as content,
    # not have a file path appended to it.
    in_hunk = False
    for line in difflib.unified_diff(pre_lines, post_lines, lineterm=""):
        if line.startswith("@@"):
            in_hunk = True
            rendered.append(line)
        elif not in_hunk and line.startswith("---"):
            rendered.append(f"{line}{file_edit.file_path}")
        elif not in_hunk and line.startswith("+++"):
            rendered.append(f"{line}{new_name}")
        elif line.startswith("+"):
            rendered.append(colored(line, "green"))
        elif line.startswith("-"):
            rendered.append(colored(line, "red"))
        elif line.startswith(" "):
            rendered.append(highlight_text(line, lexer))
        else:
            raise MentatError("Invalid Diff")
    return rendered


async def revise_edit(file_edit: FileEdit) -> None:
    """Ask the model to revise one file edit and apply its corrected replacements.

    The edit is rendered as a unified diff and sent to the model together
    with the revisor prompt, the code message and the latest user request.
    The model's reply is parsed as a git diff; if it yields a file edit, that
    edit's replacements replace ``file_edit.replacements`` in place and the
    difference between the pre- and post-revision result is shown to the user.
    Deletion edits are returned untouched.

    Args:
        file_edit: The edit to revise; its ``replacements`` may be overwritten.

    Raises:
        MentatError: If the revision diff contains an unrecognised line.
    """
    ctx = SESSION_CONTEXT.get()

    # No point in revising deletion edits
    if file_edit.is_deletion:
        return
    diff = _file_edit_diff(file_edit)

    # There should always be a user_message by the time we're revising.
    # Work on a copy: if get_messages() hands back the conversation's own
    # message dicts, rewriting "content" in place would leak the
    # "User Request:" prefix into the conversation history, once per revised file.
    user_message = [
        message
        for message in ctx.conversation.get_messages()
        if message["role"] == "user"
    ][-1].copy()
    user_message["content"] = f"User Request:\n{user_message['content']}"

    messages: List[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(content=revisor_prompt, role="system"),
        user_message,
        ChatCompletionSystemMessageParam(content=f"Diff:\n{diff}", role="system"),
    ]
    # The token count of the messages built so far is passed to
    # get_code_message; the code message is then slotted in right after the prompt.
    code_message = await ctx.code_context.get_code_message(
        prompt_tokens(messages, ctx.config.model)
    )
    messages.insert(
        1, ChatCompletionSystemMessageParam(content=code_message, role="system")
    )

    ctx.stream.send(
        "\nRevising edits for file"
        f" {get_relative_path(file_edit.file_path, ctx.cwd)}...",
        style="info",
    )
    response = await ctx.llm_api_handler.call_llm_api(
        messages, model=ctx.config.model, stream=False
    )
    message = response.choices[0].message.content or ""
    messages.append(
        ChatCompletionAssistantMessageParam(content=message, role="assistant")
    )
    ctx.conversation.add_transcript_message(
        ModelMessage(message=message, prior_messages=messages, message_type="revisor")
    )

    # Sometimes the model wraps response in a ```diff ``` block
    # I believe new prompt fixes this but this makes sure it never interferes
    message = message.removeprefix("```diff\n").removesuffix("\n```")

    # This makes it more similar to a git diff so that we can use the pre existing git diff parser
    message = "\n".join(message.split("\n")[2:])  # remove leading +++ and ---
    post_diff = (
        "diff --git a/file b/file\nindex 0000000..0000000\n--- a/file\n+++"
        f" b/file\n{message}"
    )
    parsed_response = GitParser().parse_string(post_diff)

    # Only modify the replacements of the current file edit
    # (the new file edit doesn't know about file creation or renaming)
    # Additionally, since we do this one at a time there should only ever be 1 file edit.
    if parsed_response.file_edits:
        stored_lines = _get_stored_lines(file_edit)
        pre_lines = file_edit.get_updated_file_lines(stored_lines)
        file_edit.replacements = parsed_response.file_edits[0].replacements
        post_lines = file_edit.get_updated_file_lines(stored_lines)

        diff_diff = _format_revision_diff(file_edit, pre_lines, post_lines)
        if diff_diff:
            ctx.stream.send("Revision diff:", style="info")
            ctx.stream.send(change_delimiter)
            ctx.stream.send("\n".join(diff_diff))
            ctx.stream.send(change_delimiter)
    # The API call was made whether or not a revision was parsed, so always report it.
    ctx.cost_tracker.display_last_api_call()


async def revise_edits(file_edits: List[FileEdit]):
    """Revise every edit in ``file_edits``, one after another.

    Args:
        file_edits: The edits to hand to ``revise_edit``, in order.
    """
    # Deliberately sequential: we could do all edits asynchronously, but that
    # risks getting rate limited and is probably not worth the effort.
    for edit in file_edits:
        await revise_edit(edit)
8 changes: 6 additions & 2 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from mentat.git_handler import get_git_root_for_path
from mentat.llm_api_handler import LlmApiHandler, is_test_environment
from mentat.logging_config import setup_logging
from mentat.revisor.revisor import revise_edits
from mentat.sampler.sampler import Sampler
from mentat.sentry import sentry_init
from mentat.session_context import SESSION_CONTEXT, SessionContext
Expand Down Expand Up @@ -171,13 +172,16 @@ async def _main(self):
for file_edit in parsed_llm_response.file_edits
if file_edit.is_valid()
]
for file_edit in file_edits:
file_edit.resolve_conflicts()
if file_edits:
if session_context.config.revisor:
await revise_edits(file_edits)

if not agent_handler.agent_enabled:
file_edits, need_user_request = (
await get_user_feedback_on_edits(file_edits)
)
for file_edit in file_edits:
file_edit.resolve_conflicts()

if session_context.sampler and session_context.sampler.active:
try:
Expand Down
56 changes: 56 additions & 0 deletions tests/revisor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from pathlib import Path
from textwrap import dedent

import pytest

from mentat.parsers.file_edit import FileEdit, Replacement
from mentat.revisor.revisor import revise_edit


# Checks that revise_edit replaces a file edit's replacements with the ones
# parsed from the (mocked) model reply.
@pytest.mark.asyncio
async def test_revision(mock_session_context, mock_call_llm_api):
file_name = Path("file").resolve()
# revise_edit picks the most recent user message from the conversation,
# so one has to exist before it is called.
mock_session_context.conversation.add_user_message("User Request")
# Stored file contents; the last line has an unclosed parenthesis.
mock_session_context.code_file_manager.file_lines[file_name] = dedent("""\
def hello_world():
pass
hello_world(
""").split("\n")

# Canned model reply: a unified diff starting with the bare ---/+++ header
# lines, which revise_edit drops before parsing.
mock_call_llm_api.set_unstreamed_values(dedent("""\
---
+++
@@ -1,5 +1,5 @@
def hello_world():
- pass
+ print("Hello, World!")
-hello_world(
+hello_world()
"""))

# The edit as originally proposed: the print call is left unterminated.
replacement_text = dedent("""\
print("Hello, World!
hello_world()""").split("\n")
file_edit = FileEdit(file_name, [Replacement(1, 4, replacement_text)], False, False)
await revise_edit(file_edit)
# Applying the revised edit to the stored lines should give the corrected file.
assert "\n".join(
file_edit.get_updated_file_lines(
mock_session_context.code_file_manager.file_lines[file_name]
)
) == dedent(
"""\
def hello_world():
print("Hello, World!")
hello_world()"""
)


@pytest.mark.asyncio
async def test_skip_deletion(mock_session_context, mock_call_llm_api):
    """A deletion edit goes through revise_edit with its replacements unchanged."""
    target = Path("file").resolve()
    mock_session_context.code_file_manager.file_lines[target] = []

    # This will error if not deletion
    deletion_edit = FileEdit(target, [Replacement(1, 4, [])], False, True)
    await revise_edit(deletion_edit)

    expected_replacements = [Replacement(1, 4, [])]
    assert deletion_edit.replacements == expected_replacements

0 comments on commit f02559a

Please sign in to comment.