Skip to content

Commit

Permalink
Improvements to OverleafGitPaperRemote (#25)
Browse files Browse the repository at this point in the history
* add test fixture for OverleafGitPaperRemote
    * testing that git can resolve non-conflicting edits using merge (partially addressing #19)
    * testing that we recover gracefully from merge conflicts and other kinds of bad edits
* better git diff handling to test for recent human edits
* handling edits' revision_id in OverleafGitPaperRemote. Snazzy `with paper.rewind(commit_id)` syntax
  • Loading branch information
wrongu authored Jun 22, 2023
1 parent a8eb7c5 commit 7d54b98
Show file tree
Hide file tree
Showing 6 changed files with 533 additions and 49 deletions.
4 changes: 4 additions & 0 deletions llm4papers/config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ class OpenAIConfig(BaseSettings):


class OverleafConfig(BaseSettings):
    """Overleaf account credentials and the git author identity for commits.

    NOTE(review): the "###" defaults look like placeholders that are meant to
    be replaced in a real config file -- confirm against the project README.
    """

    # Username and password for logging into your overleaf account
    username: str = "###"
    password: str = "###"
    # Author name and email that should appear in git history
    git_name: str = "AI assistant"
    git_email: str = "[email protected]"


class Settings(BaseSettings):
Expand Down
210 changes: 161 additions & 49 deletions llm4papers/paper_remote/OverleafGitPaperRemote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,47 @@
import shutil
import datetime
from urllib.parse import quote
from git import Repo # type: ignore

from llm4papers.models import EditTrigger, EditResult, EditType, DocumentID, RevisionID
from git import Repo, GitCommandError # type: ignore
from typing import Iterable
import re

from llm4papers.models import (
EditTrigger,
EditResult,
EditType,
DocumentID,
RevisionID,
LineRange,
)
from llm4papers.paper_remote.MultiDocumentPaperRemote import MultiDocumentPaperRemote
from llm4papers.logger import logger


diff_line_edit_re = re.compile(
r"@{2,}\s*-(?P<old_line>\d+),(?P<old_count>\d+)\s*\+(?P<new_line>\d+),(?P<new_count>\d+)\s*@{2,}"
)


def _diff_to_ranges(diff: str) -> Iterable[LineRange]:
"""Given a git diff, return LineRange object(s) indicating which lines in the
original document were changed.
"""
for match in diff_line_edit_re.finditer(diff):
git_line_number = int(match.group("new_line"))
git_line_count = int(match.group("new_count"))
# Git counts from 1 and gives (start, length), inclusive. LineRange counts from
# 0 and gives start:end indices (exclusive).
zero_index_start = git_line_number - 1
yield zero_index_start, zero_index_start + git_line_count


def _ranges_overlap(a: LineRange, b: LineRange) -> bool:
"""Given two LineRanges, return True if they overlap, False otherwise."""
return not (a[1] < b[0] or b[1] < a[0])


def _too_close_to_human_edits(
repo: Repo, filename: str, line_number: int, last_n: int = 2
repo: Repo, filename: str, line_range: LineRange, last_n: int = 2
) -> bool:
"""
Determine if the line `line_number` of the file `filename` was changed in
Expand All @@ -41,22 +73,19 @@ def _too_close_to_human_edits(
logger.info(f"Last commit was {sec_since_last_commit}s ago, approving edit.")
return False

# Get the diff for HEAD~n:
# Get the diff for HEAD~n. Note that the gitpython DiffIndex and Diff objects
# drop the line number info (!) so we can't use the gitpython object-oriented API
# to do this. Calling repo.git.diff is pretty much a direct pass-through to
# running "git diff HEAD~n -- <filename>" on the command line.
total_diff = repo.git.diff(f"HEAD~{last_n}", filename, unified=0)

# Get the current repo state of that line:
current_line = repo.git.show(f"HEAD:{filename}").split("\n")[line_number]

logger.debug("Diff: " + total_diff)
logger.debug("Current line: " + current_line)

# Match the line in the diff:
if current_line in total_diff:
logger.info(
f"Found current line ({current_line[:10]}...) in diff, rejecting edit."
)
return True

for git_line_range in _diff_to_ranges(total_diff):
if _ranges_overlap(git_line_range, line_range):
logger.info(
f"Line range {line_range} overlaps with git-edited {git_line_range}, "
f"rejecting edit."
)
return True
return False


Expand All @@ -77,15 +106,28 @@ def _add_auth(uri: str):
return uri


def _add_git_user_from_config(repo: Repo) -> None:
    """Write the git ``user.name`` and ``user.email`` settings for `repo`.

    The identity is read from ``OverleafConfig`` in ``llm4papers.config``. If
    that module cannot be imported, a fixed fallback identity is written
    instead.

    Args:
        repo: The gitpython repository whose configuration is updated.
    """
    try:
        from llm4papers.config import OverleafConfig

        config = OverleafConfig()
        name, email = config.git_name, config.git_email
    except ImportError:
        logger.debug("No config file found, assuming public repo.")
        name, email = "no-config", "[email protected]"

    # A single writer handles both values; the context manager releases the
    # config writer (and its lock) even if one of the writes raises.
    with repo.config_writer() as writer:
        writer.set_value("user", "name", name)
        writer.set_value("user", "email", email)


class OverleafGitPaperRemote(MultiDocumentPaperRemote):
"""
Overleaf exposes a git remote for each project. This class handles reading
and writing to Overleaf documents using gitpython, and implements the
PaperRemote protocol for use by the AI editor.
"""

current_revision_id: RevisionID

def __init__(self, git_cached_repo: str):
"""
Saves the git repo to a local temporary directory using gitpython.
Expand All @@ -100,6 +142,10 @@ def __init__(self, git_cached_repo: str):
self._cached_repo: Repo | None = None
self.refresh()

@property
def current_revision_id(self) -> "RevisionID":
    """Hex SHA of the commit that the repo's HEAD currently points to."""
    head_commit = self._get_repo().head.commit
    return head_commit.hexsha

def _get_repo(self) -> Repo:
if self._cached_repo is None:
# TODO - this makes me anxious about race conditions. every time we refresh,
Expand All @@ -119,7 +165,7 @@ def _doc_id_to_path(self, doc_id: DocumentID) -> pathlib.Path:
# so we can cast to a string on this next line:
return pathlib.Path(git_root) / str(doc_id)

def refresh(self):
def refresh(self, retry: bool = True):
"""
This is a fallback method (that likely needs some love) to ensure that
the repo is up to date with the latest upstream changes.
Expand All @@ -134,6 +180,7 @@ def refresh(self):
)

self._cached_repo = Repo(f"/tmp/{self._reposlug}")
_add_git_user_from_config(self._cached_repo)

logger.info(f"Pulling latest from repo {self._reposlug}.")
try:
Expand All @@ -143,15 +190,14 @@ def refresh(self):
f"Latest change at {self._get_repo().head.commit.committed_datetime}"
)
logger.info(f"Repo dirty: {self._get_repo().is_dirty()}")
self.current_revision_id = self._get_repo().head.commit.hexsha
try:
self._get_repo().git.stash("pop")
except Exception as e:
except GitCommandError as e:
# TODO: this just means there was nothing to pop, but
# we should handle this more gracefully.
logger.debug(f"Nothing to pop: {e}")
pass
except Exception as e:
except GitCommandError as e:
logger.error(
f"Error pulling from repo {self._reposlug}: {e}. "
"Falling back on DESTRUCTION!!!"
Expand All @@ -161,7 +207,10 @@ def refresh(self):
self._cached_repo = None
# recursively delete the repo
shutil.rmtree(f"/tmp/{self._reposlug}")
self.refresh()
if retry:
self.refresh(retry=False)
else:
raise e

def list_doc_ids(self) -> list[DocumentID]:
"""
Expand Down Expand Up @@ -196,14 +245,15 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
# want to wait for the user to move on to the next line.
for doc_range in edit.input_ranges + edit.output_ranges:
repo_scoped_file = str(self._doc_id_to_path(doc_range.doc_id))
for i in range(doc_range.selection[0], doc_range.selection[1]):
if _too_close_to_human_edits(self._get_repo(), repo_scoped_file, i):
logger.info(
f"Temporarily skipping edit request in {doc_range.doc_id}"
" at line {i} because it was still in progress"
" in the last commit."
)
return False
if _too_close_to_human_edits(
self._get_repo(), repo_scoped_file, doc_range.selection
):
logger.info(
f"Temporarily skipping edit request in {doc_range.doc_id}"
" at line {i} because it was still in progress"
" in the last commit."
)
return False
return True

def to_dict(self):
Expand All @@ -221,27 +271,30 @@ def perform_edit(self, edit: EditResult) -> bool:
Returns:
True if the edit was successful, False otherwise
"""
if not self._doc_id_to_path(edit.range.doc_id).exists():
logger.error(f"Document {edit.range.doc_id} does not exist.")
return False

logger.info(f"Performing edit {edit} on remote {self._reposlug}")

if edit.type == EditType.replace:
success = self._perform_replace(edit)
elif edit.type == EditType.comment:
success = self._perform_comment(edit)
else:
raise ValueError(f"Unknown edit type {edit.type}")
try:
with self.rewind(edit.range.revision_id, message="AI edit") as paper:
if edit.type == EditType.replace:
success = paper._perform_replace(edit)
elif edit.type == EditType.comment:
success = paper._perform_comment(edit)
else:
raise ValueError(f"Unknown edit type {edit.type}")
except GitCommandError as e:
logger.error(
f"Git error performing edit {edit} on remote {self._reposlug}: {e}"
)
success = False

if success:
# TODO - apply edit relative to the edit.range.revision_id commit and then
# rebase onto HEAD for poor-man's operational transforms
self._get_repo().index.add([self._doc_id_to_path(str(edit.range.doc_id))])
self._get_repo().index.commit("AI edit completed.")
# Instead of just pushing, we need to rebase and then push.
# This is because we want to make sure that the AI edits are always
# on top of the stack.
self._get_repo().git.pull()
# TODO: We could do a better job catching WARNs here and then maybe setting
# success = False
self._get_repo().git.push()
else:
self.refresh()

return success

Expand All @@ -257,6 +310,19 @@ def _perform_replace(self, edit: EditResult) -> bool:
"""
doc_range, text = edit.range, edit.content
try:
num_lines = len(self.get_lines(doc_range.doc_id))
if (
any(i < 0 for i in doc_range.selection)
or doc_range.selection[1] < doc_range.selection[0]
or any(
i > len(self.get_lines(doc_range.doc_id))
for i in doc_range.selection
)
):
raise IndexError(
f"Invalid selection {doc_range.selection} for document "
f"{doc_range.doc_id} with {num_lines} lines."
)
lines = self.get_lines(doc_range.doc_id)
lines = (
lines[: doc_range.selection[0]]
Expand Down Expand Up @@ -284,3 +350,49 @@ def _perform_comment(self, edit: EditResult) -> bool:
# TODO - implement this for real
logger.info(f"Performing comment edit {edit} on remote {self._reposlug}")
return True

def rewind(self, commit: str, message: str):
    """Build a context manager for applying edits on top of `commit`.

    Args:
        commit: Commit the temporary edit branch is created at on entry.
        message: Commit message used for the commit made on exit.

    Returns:
        A ``RewindContext`` bound to this remote.
    """
    return self.RewindContext(remote=self, commit=commit, message=message)

# Create an inner class for "with" semantics so that we can do
# `with remote.rewind(commit)` to rewind to a particular commit and play some edits
# onto it, then merge when the 'with' context exits.
class RewindContext:
    """Context manager that applies edits on a temporary branch at a commit.

    On entry, a branch named "tmp-edit-branch" is created at the requested
    commit and checked out. On exit, every path in the index is re-added and
    committed on that branch, the previously active ref is checked out again,
    and the temporary branch is merged into it. The temporary branch is
    deleted whether or not the merge succeeds.
    """

    # TODO - there are tricks in gitpython where an IndexFile can be used to
    # handle changes to files in-memory without having to call checkout() and
    # (briefly) modify the state of things on disk. This would be an improvement,
    # but would require using the gitpython API more directly inside of
    # perform_edit, such as calling git.IndexFile.write() instead of python's
    # open() and write()

    def __init__(self, remote: "OverleafGitPaperRemote", commit: str, message: str):
        """Store the remote, the commit to rewind to, and the commit message."""
        self._rewind_commit = commit
        self._message = message
        self._remote = remote

    def __enter__(self):
        """Check out a fresh temporary branch at the rewind commit.

        Returns:
            The remote itself, so edits can be performed through it.
        """
        repo = self._remote._get_repo()
        # Remember where HEAD pointed so __exit__ can switch back to it.
        self._restore_ref = repo.head.ref
        temp_branch = repo.create_head("tmp-edit-branch", commit=self._rewind_commit)
        self._new_branch = temp_branch
        temp_branch.checkout()
        return self._remote

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit on the temporary branch, then merge it into the original ref.

        Exceptions raised inside the `with` body are not suppressed.
        NOTE(review): the commit and merge run even when the body raised,
        since the exception arguments are not inspected -- confirm intended.

        Raises:
            GitCommandError: If the merge fails. The original ref is hard-reset
                to its own commit before the error is re-raised.
        """
        repo = self._remote._get_repo()
        assert (
            repo.active_branch == self._new_branch
        ), "Branch changed unexpectedly mid-`with`"
        # Index entries are keyed by (path, stage); re-add every path so that
        # files that changed are picked up.
        tracked_paths = [path for path, _stage in repo.index.entries]
        repo.index.add(tracked_paths)
        repo.index.commit(self._message)
        self._restore_ref.checkout()
        try:
            repo.git.merge("tmp-edit-branch")
        except GitCommandError:
            # Hard reset on failure
            repo.git.reset("--hard", self._restore_ref.commit.hexsha)
            raise
        finally:
            repo.delete_head(self._new_branch, force=True)
Loading

0 comments on commit 7d54b98

Please sign in to comment.