This repository has been archived by the owner on Jan 7, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 243
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
268 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
You are part of an automated coding system. Your responses must follow the required format so they can be parsed programmatically. | ||
You will be given a unified diff of a recent change made to a code file. Your job is to determine if the change made is syntactically correct, and if it is not, to modify the diff so that it is. | ||
If you are not changing the diff, output an exact copy of the git diff! Do not output anything besides the modified diff or your output will not be parsed correctly! | ||
Additionally, you will be provided with a variety of code files relevant to the diff, as well as the user request that this diff addresses. | ||
Do **NOT** wrap your response in a ```diff tag or it will not be parsed correctly!!! | ||
|
||
Example Input: | ||
|
||
Code Files: | ||
|
||
hello_world.py | ||
1:def hello_world(): | ||
2: pass | ||
|
||
User Request: | ||
Implement the hello_world function. | ||
|
||
Diff: | ||
--- | ||
+++ | ||
@@ -1,4 +1,4 @@ | ||
def hello_world(): | ||
- pass | ||
+ print("Hello, World! | ||
|
||
Example Output: | ||
--- | ||
+++ | ||
@@ -1,4 +1,4 @@ | ||
def hello_world(): | ||
- pass | ||
+ print("Hello, World!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import difflib | ||
from pathlib import Path | ||
from typing import List | ||
|
||
from openai.types.chat import ( | ||
ChatCompletionAssistantMessageParam, | ||
ChatCompletionMessageParam, | ||
ChatCompletionSystemMessageParam, | ||
) | ||
from termcolor import colored | ||
|
||
from mentat.errors import MentatError | ||
from mentat.llm_api_handler import prompt_tokens | ||
from mentat.parsers.change_display_helper import ( | ||
change_delimiter, | ||
get_lexer, | ||
highlight_text, | ||
) | ||
from mentat.parsers.file_edit import FileEdit | ||
from mentat.parsers.git_parser import GitParser | ||
from mentat.prompts.prompts import read_prompt | ||
from mentat.session_context import SESSION_CONTEXT | ||
from mentat.transcripts import ModelMessage | ||
from mentat.utils import get_relative_path | ||
|
||
revisor_prompt_filename = Path("revisor_prompt.txt") | ||
revisor_prompt = read_prompt(revisor_prompt_filename) | ||
|
||
|
||
def _get_stored_lines(file_edit: FileEdit) -> List[str]:
    """Return a copy of the lines currently stored for the edited file.

    An edit that creates its file has no stored content to look up, so an
    empty list is returned for it.
    """
    session = SESSION_CONTEXT.get()

    if not file_edit.is_creation:
        stored = session.code_file_manager.file_lines[file_edit.file_path]
        # Hand back a copy so callers can't modify the manager's own list.
        return stored.copy()
    return []
|
||
|
||
def _file_edit_diff(file_edit: FileEdit) -> str:
    """Render the edit as a unified diff of the stored lines vs. the edited lines."""
    before = _get_stored_lines(file_edit)
    after = file_edit.get_updated_file_lines(before)
    # lineterm="" keeps difflib from appending newlines; we join them ourselves.
    return "\n".join(difflib.unified_diff(before, after, lineterm=""))
|
||
|
||
async def revise_edit(file_edit: FileEdit):
    """Ask the model to review one file edit and apply any revision it returns.

    The edit is rendered as a unified diff and sent to the model together with
    the revisor prompt, the code context message and the most recent user
    message. If the model's reply parses into at least one file edit, the
    replacements of ``file_edit`` are overwritten in place with the revised
    ones, and a diff between the pre- and post-revision file contents is sent
    to the session stream. Deletion edits are returned untouched.

    Args:
        file_edit: The edit to revise; its ``replacements`` may be replaced.

    Raises:
        MentatError: If the revision diff contains a line that is not a
            recognised unified-diff line.
    """
    ctx = SESSION_CONTEXT.get()

    # No point in revising deletion edits
    if file_edit.is_deletion:
        return
    diff = _file_edit_diff(file_edit)
    # There should always be a user_message by the time we're revising
    last_user_message = [
        message
        for message in ctx.conversation.get_messages()
        if message["role"] == "user"
    ][-1]
    # Work on a copy: writing the "User Request:" prefix into the dict handed
    # out by the conversation could alter the stored history, and the prefix
    # would then stack up once per revised file.
    user_message = last_user_message.copy()
    user_message["content"] = f"User Request:\n{last_user_message['content']}"
    messages: List[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(content=revisor_prompt, role="system"),
        user_message,
        ChatCompletionSystemMessageParam(content=f"Diff:\n{diff}", role="system"),
    ]
    # The code message is given the token count of everything else in the prompt.
    code_message = await ctx.code_context.get_code_message(
        prompt_tokens(messages, ctx.config.model)
    )
    # Place the code context directly after the revisor prompt.
    messages.insert(
        1, ChatCompletionSystemMessageParam(content=code_message, role="system")
    )

    ctx.stream.send(
        "\nRevising edits for file"
        f" {get_relative_path(file_edit.file_path, ctx.cwd)}...",
        style="info",
    )
    response = await ctx.llm_api_handler.call_llm_api(
        messages, model=ctx.config.model, stream=False
    )
    message = response.choices[0].message.content or ""
    messages.append(
        ChatCompletionAssistantMessageParam(content=message, role="assistant")
    )
    ctx.conversation.add_transcript_message(
        ModelMessage(message=message, prior_messages=messages, message_type="revisor")
    )

    # Sometimes the model wraps response in a ```diff ``` block
    # I believe new prompt fixes this but this makes sure it never interferes
    if message.startswith("```diff\n"):
        message = message[8:]
    if message.endswith("\n```"):
        message = message[:-4]

    # This makes it more similar to a git diff so that we can use the pre existing git diff parser
    message = "\n".join(message.split("\n")[2:])  # remove leading +++ and ---
    post_diff = (
        "diff --git a/file b/file\nindex 0000000..0000000\n--- a/file\n+++"
        f" b/file\n{message}"
    )
    parsed_response = GitParser().parse_string(post_diff)

    # Only modify the replacements of the current file edit
    # (the new file edit doesn't know about file creation or renaming)
    # Additionally, since we do this one at a time there should only ever be 1 file edit.
    if parsed_response.file_edits:
        stored_lines = _get_stored_lines(file_edit)
        pre_lines = file_edit.get_updated_file_lines(stored_lines)
        file_edit.replacements = parsed_response.file_edits[0].replacements
        post_lines = file_edit.get_updated_file_lines(stored_lines)

        diff_lines = difflib.unified_diff(pre_lines, post_lines, lineterm="")
        diff_diff: List[str] = []
        lexer = get_lexer(file_edit.file_path)
        for index, line in enumerate(diff_lines):
            # unified_diff emits the "---" / "+++" file headers as its first
            # two lines. Matching them by position keeps removed or added
            # content that itself begins with "--" / "++" from being
            # mistaken for a header.
            if index == 0 and line.startswith("---"):
                diff_diff.append(f"{line}{file_edit.file_path}")
            elif index == 1 and line.startswith("+++"):
                new_name = (
                    file_edit.rename_file_path
                    if file_edit.rename_file_path is not None
                    else file_edit.file_path
                )
                diff_diff.append(f"{line}{new_name}")
            elif line.startswith("@@"):
                diff_diff.append(line)
            elif line.startswith("+"):
                diff_diff.append(colored(line, "green"))
            elif line.startswith("-"):
                diff_diff.append(colored(line, "red"))
            elif line.startswith(" "):
                diff_diff.append(highlight_text(line, lexer))
            else:
                raise MentatError("Invalid Diff")
        # An empty diff means the revision changed nothing, so show nothing.
        if diff_diff:
            ctx.stream.send("Revision diff:", style="info")
            ctx.stream.send(change_delimiter)
            ctx.stream.send("\n".join(diff_diff))
            ctx.stream.send(change_delimiter)
    # Report the cost of the revision call whether or not it changed anything.
    ctx.cost_tracker.display_last_api_call()
|
||
|
||
async def revise_edits(file_edits: List[FileEdit]):
    """Run the revisor over every given file edit, one after another."""
    # Deliberately sequential: revising all edits concurrently is possible,
    # but it risks getting rate limited and is probably not worth the effort.
    for edit in file_edits:
        await revise_edit(edit)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from pathlib import Path | ||
from textwrap import dedent | ||
|
||
import pytest | ||
|
||
from mentat.parsers.file_edit import FileEdit, Replacement | ||
from mentat.revisor.revisor import revise_edit | ||
|
||
|
||
@pytest.mark.asyncio
async def test_revision(mock_session_context, mock_call_llm_api):
    """A syntactically broken edit is replaced by the revisor's corrected diff."""
    file_name = Path("file").resolve()
    mock_session_context.conversation.add_user_message("User Request")
    # Stored file content; the final call is deliberately left unclosed.
    mock_session_context.code_file_manager.file_lines[file_name] = dedent("""\
        def hello_world():
            pass
        hello_world(
        """).split("\n")

    # Mocked model reply: a unified diff that fixes both broken lines.
    # Context lines need their leading space and code lines their
    # indentation for this to be a valid diff.
    mock_call_llm_api.set_unstreamed_values(dedent("""\
        ---
        +++
        @@ -1,5 +1,5 @@
         def hello_world():
        -    pass
        +    print("Hello, World!")
        -hello_world(
        +hello_world()
        """))

    # The edit under revision still has an unterminated string literal.
    replacement_text = dedent("""\
            print("Hello, World!
        hello_world()""").split("\n")
    file_edit = FileEdit(file_name, [Replacement(1, 4, replacement_text)], False, False)
    await revise_edit(file_edit)
    assert "\n".join(
        file_edit.get_updated_file_lines(
            mock_session_context.code_file_manager.file_lines[file_name]
        )
    ) == dedent(
        """\
        def hello_world():
            print("Hello, World!")
        hello_world()"""
    )
|
||
|
||
@pytest.mark.asyncio
async def test_skip_deletion(mock_session_context, mock_call_llm_api):
    """Deletion edits pass through the revisor with their replacements unchanged."""
    target = Path("file").resolve()
    mock_session_context.code_file_manager.file_lines[target] = []

    # This will error if not deletion
    deletion_edit = FileEdit(target, [Replacement(1, 4, [])], False, True)
    await revise_edit(deletion_edit)

    expected_replacements = [Replacement(1, 4, [])]
    assert deletion_edit.replacements == expected_replacements