-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improvements to OverleafGitPaperRemote (#25)
* add test fixture for OverleafGitPaperRemote * testing that git can resolve non-conflicting edits using merge (partially addressing #19) * testing that we recover gracefully from merge conflicts and other kinds of bad edits * better git diff handling to test for recent human edits * handling edits' revision_id in OverleafGitPaperRemote. Snazzy `with paper.rewind(commit_id)` syntax
- Loading branch information
Showing
6 changed files
with
533 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,8 +7,12 @@ class OpenAIConfig(BaseSettings): | |
|
||
|
||
class OverleafConfig(BaseSettings): | ||
# Username and password for logging into your overleaf account | ||
username: str = "###" | ||
password: str = "###" | ||
# Author name and email that should appear in git history | ||
git_name: str = "AI assistant" | ||
git_email: str = "[email protected]" | ||
|
||
|
||
class Settings(BaseSettings): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,15 +8,47 @@ | |
import shutil | ||
import datetime | ||
from urllib.parse import quote | ||
from git import Repo # type: ignore | ||
|
||
from llm4papers.models import EditTrigger, EditResult, EditType, DocumentID, RevisionID | ||
from git import Repo, GitCommandError # type: ignore | ||
from typing import Iterable | ||
import re | ||
|
||
from llm4papers.models import ( | ||
EditTrigger, | ||
EditResult, | ||
EditType, | ||
DocumentID, | ||
RevisionID, | ||
LineRange, | ||
) | ||
from llm4papers.paper_remote.MultiDocumentPaperRemote import MultiDocumentPaperRemote | ||
from llm4papers.logger import logger | ||
|
||
|
||
diff_line_edit_re = re.compile( | ||
r"@{2,}\s*-(?P<old_line>\d+),(?P<old_count>\d+)\s*\+(?P<new_line>\d+),(?P<new_count>\d+)\s*@{2,}" | ||
) | ||
|
||
|
||
def _diff_to_ranges(diff: str) -> Iterable[LineRange]: | ||
"""Given a git diff, return LineRange object(s) indicating which lines in the | ||
original document were changed. | ||
""" | ||
for match in diff_line_edit_re.finditer(diff): | ||
git_line_number = int(match.group("new_line")) | ||
git_line_count = int(match.group("new_count")) | ||
# Git counts from 1 and gives (start, length), inclusive. LineRange counts from | ||
# 0 and gives start:end indices (exclusive). | ||
zero_index_start = git_line_number - 1 | ||
yield zero_index_start, zero_index_start + git_line_count | ||
|
||
|
||
def _ranges_overlap(a: LineRange, b: LineRange) -> bool: | ||
"""Given two LineRanges, return True if they overlap, False otherwise.""" | ||
return not (a[1] < b[0] or b[1] < a[0]) | ||
|
||
|
||
def _too_close_to_human_edits( | ||
repo: Repo, filename: str, line_number: int, last_n: int = 2 | ||
repo: Repo, filename: str, line_range: LineRange, last_n: int = 2 | ||
) -> bool: | ||
""" | ||
Determine if the line `line_number` of the file `filename` was changed in | ||
|
@@ -41,22 +73,19 @@ def _too_close_to_human_edits( | |
logger.info(f"Last commit was {sec_since_last_commit}s ago, approving edit.") | ||
return False | ||
|
||
# Get the diff for HEAD~n: | ||
# Get the diff for HEAD~n. Note that the gitpython DiffIndex and Diff objects | ||
# drop the line number info (!) so we can't use the gitpython object-oriented API | ||
# to do this. Calling repo.git.diff is pretty much a direct pass-through to | ||
# running "git diff HEAD~n -- <filename>" on the command line. | ||
total_diff = repo.git.diff(f"HEAD~{last_n}", filename, unified=0) | ||
|
||
# Get the current repo state of that line: | ||
current_line = repo.git.show(f"HEAD:{filename}").split("\n")[line_number] | ||
|
||
logger.debug("Diff: " + total_diff) | ||
logger.debug("Current line: " + current_line) | ||
|
||
# Match the line in the diff: | ||
if current_line in total_diff: | ||
logger.info( | ||
f"Found current line ({current_line[:10]}...) in diff, rejecting edit." | ||
) | ||
return True | ||
|
||
for git_line_range in _diff_to_ranges(total_diff): | ||
if _ranges_overlap(git_line_range, line_range): | ||
logger.info( | ||
f"Line range {line_range} overlaps with git-edited {git_line_range}, " | ||
f"rejecting edit." | ||
) | ||
return True | ||
return False | ||
|
||
|
||
|
@@ -77,15 +106,28 @@ def _add_auth(uri: str): | |
return uri | ||
|
||
|
||
def _add_git_user_from_config(repo: Repo) -> None: | ||
try: | ||
from llm4papers.config import OverleafConfig | ||
|
||
config = OverleafConfig() | ||
repo.config_writer().set_value("user", "name", config.git_name).release() | ||
repo.config_writer().set_value("user", "email", config.git_email).release() | ||
except ImportError: | ||
logger.debug("No config file found, assuming public repo.") | ||
repo.config_writer().set_value("user", "name", "no-config").release() | ||
repo.config_writer().set_value( | ||
"user", "email", "[email protected]" | ||
).release() | ||
|
||
|
||
class OverleafGitPaperRemote(MultiDocumentPaperRemote): | ||
""" | ||
Overleaf exposes a git remote for each project. This class handles reading | ||
and writing to Overleaf documents using gitpython, and implements the | ||
PaperRemote protocol for use by the AI editor. | ||
""" | ||
|
||
current_revision_id: RevisionID | ||
|
||
def __init__(self, git_cached_repo: str): | ||
""" | ||
Saves the git repo to a local temporary directory using gitpython. | ||
|
@@ -100,6 +142,10 @@ def __init__(self, git_cached_repo: str): | |
self._cached_repo: Repo | None = None | ||
self.refresh() | ||
|
||
@property | ||
def current_revision_id(self) -> RevisionID: | ||
return self._get_repo().head.commit.hexsha | ||
|
||
def _get_repo(self) -> Repo: | ||
if self._cached_repo is None: | ||
# TODO - this makes me anxious about race conditions. every time we refresh, | ||
|
@@ -119,7 +165,7 @@ def _doc_id_to_path(self, doc_id: DocumentID) -> pathlib.Path: | |
# so we can cast to a string on this next line: | ||
return pathlib.Path(git_root) / str(doc_id) | ||
|
||
def refresh(self): | ||
def refresh(self, retry: bool = True): | ||
""" | ||
This is a fallback method (that likely needs some love) to ensure that | ||
the repo is up to date with the latest upstream changes. | ||
|
@@ -134,6 +180,7 @@ def refresh(self): | |
) | ||
|
||
self._cached_repo = Repo(f"/tmp/{self._reposlug}") | ||
_add_git_user_from_config(self._cached_repo) | ||
|
||
logger.info(f"Pulling latest from repo {self._reposlug}.") | ||
try: | ||
|
@@ -143,15 +190,14 @@ def refresh(self): | |
f"Latest change at {self._get_repo().head.commit.committed_datetime}" | ||
) | ||
logger.info(f"Repo dirty: {self._get_repo().is_dirty()}") | ||
self.current_revision_id = self._get_repo().head.commit.hexsha | ||
try: | ||
self._get_repo().git.stash("pop") | ||
except Exception as e: | ||
except GitCommandError as e: | ||
# TODO: this just means there was nothing to pop, but | ||
# we should handle this more gracefully. | ||
logger.debug(f"Nothing to pop: {e}") | ||
pass | ||
except Exception as e: | ||
except GitCommandError as e: | ||
logger.error( | ||
f"Error pulling from repo {self._reposlug}: {e}. " | ||
"Falling back on DESTRUCTION!!!" | ||
|
@@ -161,7 +207,10 @@ def refresh(self): | |
self._cached_repo = None | ||
# recursively delete the repo | ||
shutil.rmtree(f"/tmp/{self._reposlug}") | ||
self.refresh() | ||
if retry: | ||
self.refresh(retry=False) | ||
else: | ||
raise e | ||
|
||
def list_doc_ids(self) -> list[DocumentID]: | ||
""" | ||
|
@@ -196,14 +245,15 @@ def is_edit_ok(self, edit: EditTrigger) -> bool: | |
# want to wait for the user to move on to the next line. | ||
for doc_range in edit.input_ranges + edit.output_ranges: | ||
repo_scoped_file = str(self._doc_id_to_path(doc_range.doc_id)) | ||
for i in range(doc_range.selection[0], doc_range.selection[1]): | ||
if _too_close_to_human_edits(self._get_repo(), repo_scoped_file, i): | ||
logger.info( | ||
f"Temporarily skipping edit request in {doc_range.doc_id}" | ||
" at line {i} because it was still in progress" | ||
" in the last commit." | ||
) | ||
return False | ||
if _too_close_to_human_edits( | ||
self._get_repo(), repo_scoped_file, doc_range.selection | ||
): | ||
logger.info( | ||
f"Temporarily skipping edit request in {doc_range.doc_id}" | ||
" at line {i} because it was still in progress" | ||
" in the last commit." | ||
) | ||
return False | ||
return True | ||
|
||
def to_dict(self): | ||
|
@@ -221,27 +271,30 @@ def perform_edit(self, edit: EditResult) -> bool: | |
Returns: | ||
True if the edit was successful, False otherwise | ||
""" | ||
if not self._doc_id_to_path(edit.range.doc_id).exists(): | ||
logger.error(f"Document {edit.range.doc_id} does not exist.") | ||
return False | ||
|
||
logger.info(f"Performing edit {edit} on remote {self._reposlug}") | ||
|
||
if edit.type == EditType.replace: | ||
success = self._perform_replace(edit) | ||
elif edit.type == EditType.comment: | ||
success = self._perform_comment(edit) | ||
else: | ||
raise ValueError(f"Unknown edit type {edit.type}") | ||
try: | ||
with self.rewind(edit.range.revision_id, message="AI edit") as paper: | ||
if edit.type == EditType.replace: | ||
success = paper._perform_replace(edit) | ||
elif edit.type == EditType.comment: | ||
success = paper._perform_comment(edit) | ||
else: | ||
raise ValueError(f"Unknown edit type {edit.type}") | ||
except GitCommandError as e: | ||
logger.error( | ||
f"Git error performing edit {edit} on remote {self._reposlug}: {e}" | ||
) | ||
success = False | ||
|
||
if success: | ||
# TODO - apply edit relative to the edit.range.revision_id commit and then | ||
# rebase onto HEAD for poor-man's operational transforms | ||
self._get_repo().index.add([self._doc_id_to_path(str(edit.range.doc_id))]) | ||
self._get_repo().index.commit("AI edit completed.") | ||
# Instead of just pushing, we need to rebase and then push. | ||
# This is because we want to make sure that the AI edits are always | ||
# on top of the stack. | ||
self._get_repo().git.pull() | ||
# TODO: We could do a better job catching WARNs here and then maybe setting | ||
# success = False | ||
self._get_repo().git.push() | ||
else: | ||
self.refresh() | ||
|
||
return success | ||
|
||
|
@@ -257,6 +310,19 @@ def _perform_replace(self, edit: EditResult) -> bool: | |
""" | ||
doc_range, text = edit.range, edit.content | ||
try: | ||
num_lines = len(self.get_lines(doc_range.doc_id)) | ||
if ( | ||
any(i < 0 for i in doc_range.selection) | ||
or doc_range.selection[1] < doc_range.selection[0] | ||
or any( | ||
i > len(self.get_lines(doc_range.doc_id)) | ||
for i in doc_range.selection | ||
) | ||
): | ||
raise IndexError( | ||
f"Invalid selection {doc_range.selection} for document " | ||
f"{doc_range.doc_id} with {num_lines} lines." | ||
) | ||
lines = self.get_lines(doc_range.doc_id) | ||
lines = ( | ||
lines[: doc_range.selection[0]] | ||
|
@@ -284,3 +350,49 @@ def _perform_comment(self, edit: EditResult) -> bool: | |
# TODO - implement this for real | ||
logger.info(f"Performing comment edit {edit} on remote {self._reposlug}") | ||
return True | ||
|
||
def rewind(self, commit: str, message: str): | ||
return self.RewindContext(self, commit, message) | ||
|
||
# Create an inner class for "with" semantics so that we can do | ||
# `with remote.rewind(commit)` to rewind to a particular commit and play some edits | ||
# onto it, then merge when the 'with' context exits. | ||
class RewindContext: | ||
# TODO - there are tricks in gitpython where an IndexFile can be used to | ||
# handle changes to files in-memory without having to call checkout() and | ||
# (briefly) modify the state of things on disk. This would be an improvement, | ||
# but would require using the gitpython API more directly inside of | ||
# perform_edit, such as calling git.IndexFile.write() instead of python's | ||
# open() and write() | ||
|
||
def __init__(self, remote: "OverleafGitPaperRemote", commit: str, message: str): | ||
self._remote = remote | ||
self._message = message | ||
self._rewind_commit = commit | ||
|
||
def __enter__(self): | ||
repo = self._remote._get_repo() | ||
self._restore_ref = repo.head.ref | ||
self._new_branch = repo.create_head( | ||
"tmp-edit-branch", commit=self._rewind_commit | ||
) | ||
self._new_branch.checkout() | ||
return self._remote | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
repo = self._remote._get_repo() | ||
assert ( | ||
repo.active_branch == self._new_branch | ||
), "Branch changed unexpectedly mid-`with`" | ||
# Add files that changed | ||
repo.index.add([_file for (_file, _), _ in repo.index.entries.items()]) | ||
repo.index.commit(self._message) | ||
self._restore_ref.checkout() | ||
try: | ||
repo.git.merge("tmp-edit-branch") | ||
except GitCommandError as e: | ||
# Hard reset on failure | ||
repo.git.reset("--hard", self._restore_ref.commit.hexsha) | ||
raise e | ||
finally: | ||
repo.delete_head(self._new_branch, force=True) |
Oops, something went wrong.