Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Diff Context never null (#531)
Browse files: browse the repository at this point in the history
  • Loading branch information
jakethekoenig authored Feb 23, 2024
1 parent d6c88bf commit 6219d49
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 55 deletions.
44 changes: 17 additions & 27 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,16 @@ class CodeContext:
def __init__(
self,
stream: SessionStream,
git_root: Optional[Path] = None,
cwd: Path,
diff: Optional[str] = None,
pr_diff: Optional[str] = None,
ignore_patterns: Iterable[Path | str] = [],
):
self.git_root = git_root
self.diff = diff
self.pr_diff = pr_diff
self.ignore_patterns = set(Path(p) for p in ignore_patterns)

self.diff_context = None
if self.git_root:
self.diff_context = DiffContext(
stream, self.git_root, self.diff, self.pr_diff
)
self.diff_context = DiffContext(stream, cwd, self.diff, self.pr_diff)

self.include_files: Dict[Path, List[CodeFeature]] = {}
self.ignore_files: Set[Path] = set()
Expand All @@ -79,9 +74,7 @@ def refresh_context_display(self):
"""
ctx = SESSION_CONTEXT.get()

diff_context_display = None
if self.diff_context and self.diff_context.name:
diff_context_display = self.diff_context.get_display_context()
diff_context_display = self.diff_context.get_display_context()

features = get_consolidated_feature_refs(
[
Expand All @@ -91,12 +84,9 @@ def refresh_context_display(self):
]
)
auto_features = get_consolidated_feature_refs(self.auto_features)
if self.diff_context:
git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]
else:
git_diff_paths = []
git_untracked_paths = []
git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]

messages = ctx.conversation.get_messages()
code_message = get_code_message_from_features(
[
Expand Down Expand Up @@ -151,17 +141,17 @@ async def get_code_message(

# Setup code message metadata
code_message = list[str]()
if self.diff_context:
# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
f' "-" = {self.diff_context.name}',
' "+" = Active Changes',
"",
]

# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
f' "-" = {self.diff_context.name}',
' "+" = Active Changes',
"",
]

code_message += ["Code Files:\n"]
meta_tokens = count_tokens("\n".join(code_message), model, full_message=True)
Expand Down
5 changes: 1 addition & 4 deletions mentat/code_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ def get_code_message(self, standalone: bool = True) -> list[str]:
if standalone:
code_message.append("")

if (
code_context.diff_context is not None
and self.path in code_context.diff_context.diff_files()
):
if self.path in code_context.diff_context.diff_files():
diff = get_diff_for_file(code_context.diff_context.target, self.path)
diff_annotations = parse_diff(diff)
if self.interval.whole_file():
Expand Down
27 changes: 22 additions & 5 deletions mentat/diff_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
check_head_exists,
get_diff_for_file,
get_files_in_diff,
get_git_root_for_path,
get_treeish_metadata,
get_untracked_files,
)
Expand Down Expand Up @@ -96,10 +97,14 @@ class DiffContext:
def __init__(
self,
stream: SessionStream,
git_root: Path,
cwd: Path,
diff: Optional[str] = None,
pr_diff: Optional[str] = None,
):
self.git_root = get_git_root_for_path(cwd, raise_error=False)
if not self.git_root:
return

if diff and pr_diff:
# TODO: Once broadcast queue's unread messages and/or config is moved to client,
# determine if this should quit or not
Expand All @@ -118,7 +123,7 @@ def __init__(
return

name = ""
treeish_type = _get_treeish_type(git_root, target)
treeish_type = _get_treeish_type(self.git_root, target)
if treeish_type is None:
stream.send(f"Invalid treeish: {target}", style="failure")
stream.send("Disabling diff and pr-diff.", style="warning")
Expand All @@ -133,7 +138,7 @@ def __init__(

if pr_diff:
name = f"Merge-base {name}"
target = _git_command(git_root, "merge-base", "HEAD", pr_diff)
target = _git_command(self.git_root, "merge-base", "HEAD", pr_diff)
if not target:
# TODO: Same as above todo
stream.send(
Expand All @@ -145,7 +150,7 @@ def __init__(
self.name = "HEAD (last commit)"
return

meta = get_treeish_metadata(git_root, target)
meta = get_treeish_metadata(self.git_root, target)
name += f'{meta["hexsha"][:8]}: {meta["summary"]}'
if target == "HEAD":
name = "HEAD (last commit)"
Expand All @@ -157,16 +162,22 @@ def __init__(
_untracked_files: List[Path] | None = None

def diff_files(self) -> List[Path]:
# Return the cached list stored in self._diff_files.
# With no git root (self.git_root is falsy) an empty list is returned
# and refresh() is never called.
if not self.git_root:
return []
# A value of None means the cache has not been filled yet; refresh() is
# called to fill it (its body is only partly visible here - presumably it
# assigns self._diff_files; confirm against the full file).
if self._diff_files is None:
self.refresh()
return self._diff_files # pyright: ignore

def untracked_files(self) -> List[Path]:
# Return the cached list stored in self._untracked_files.
# With no git root (self.git_root is falsy) an empty list is returned
# and refresh() is never called.
if not self.git_root:
return []
# A value of None means the cache has not been filled yet; refresh() is
# called to fill it (its body is only partly visible here - presumably it
# assigns self._untracked_files; confirm against the full file).
if self._untracked_files is None:
self.refresh()
return self._untracked_files # pyright: ignore

def refresh(self):
if not self.git_root:
return
ctx = SESSION_CONTEXT.get()

if self.target == "HEAD" and not check_head_exists():
Expand All @@ -181,10 +192,14 @@ def refresh(self):
]

def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]:
# Return the diff annotations for rel_path: the diff obtained from
# get_diff_for_file(self.target, rel_path) is passed through parse_diff.
# With no git root (self.git_root is falsy) an empty list is returned
# and no diff is requested.
if not self.git_root:
return []
diff = get_diff_for_file(self.target, rel_path)
return parse_diff(diff)

def get_display_context(self) -> str:
def get_display_context(self) -> Optional[str]:
if not self.git_root:
return None
diff_files = self.diff_files()
if not diff_files:
return ""
Expand All @@ -202,6 +217,8 @@ def annotate_file_message(
self, rel_path: Path, file_message: list[str]
) -> list[str]:
"""Return file_message annotated with active diff."""
if not self.git_root:
return []
annotations = self.get_annotations(rel_path)
return annotate_file_message(file_message, annotations)

Expand Down
11 changes: 2 additions & 9 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from mentat.cost_tracker import CostTracker
from mentat.ctags import ensure_ctags_installed
from mentat.errors import MentatError, ReturnToUser, SessionExit, UserError
from mentat.git_handler import get_git_root_for_path
from mentat.llm_api_handler import LlmApiHandler, is_test_environment
from mentat.logging_config import setup_logging
from mentat.revisor.revisor import revise_edits
Expand Down Expand Up @@ -69,8 +68,6 @@ def __init__(

# Since we can't set the session_context until after all of the singletons are created,
# any singletons used in the constructor of another singleton must be passed in
git_root = get_git_root_for_path(cwd, raise_error=False)

llm_api_handler = LlmApiHandler()

stream = SessionStream()
Expand All @@ -79,7 +76,7 @@ def __init__(

cost_tracker = CostTracker()

code_context = CodeContext(stream, git_root, diff, pr_diff, ignore_paths)
code_context = CodeContext(stream, cwd, diff, pr_diff, ignore_paths)

code_file_manager = CodeFileManager()

Expand Down Expand Up @@ -116,11 +113,7 @@ def __init__(
config.send_errors_to_stream()
for path in paths:
code_context.include(path, exclude_patterns=exclude_paths)
if (
code_context.diff_context is not None
and len(code_context.include_files) == 0
and (diff or pr_diff)
):
if len(code_context.include_files) == 0 and (diff or pr_diff):
for file in code_context.diff_context.diff_files():
code_context.include(file)
if config.sampler:
Expand Down
2 changes: 1 addition & 1 deletion tests/code_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def func_4(string):

code_context = CodeContext(
mock_session_context.stream,
mock_session_context.code_context.git_root,
mock_session_context.code_context.diff_context.git_root,
)
code_context.include("file_1.py")
mock_session_context.config.auto_context_tokens = 8000
Expand Down
5 changes: 1 addition & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from mentat.config import Config, config_file_name
from mentat.conversation import Conversation
from mentat.cost_tracker import CostTracker
from mentat.git_handler import get_git_root_for_path
from mentat.llm_api_handler import LlmApiHandler
from mentat.parsers.streaming_printer import StreamingPrinter
from mentat.sampler.sampler import Sampler
Expand Down Expand Up @@ -205,8 +204,6 @@ def mock_session_context(temp_testbed):
set by a Session if the test creates a Session.
If you create a Session or Client in your test, do NOT use this SessionContext!
"""
git_root = get_git_root_for_path(temp_testbed, raise_error=False)

stream = SessionStream()

cost_tracker = CostTracker()
Expand All @@ -215,7 +212,7 @@ def mock_session_context(temp_testbed):

llm_api_handler = LlmApiHandler()

code_context = CodeContext(stream, git_root)
code_context = CodeContext(stream, temp_testbed)

code_file_manager = CodeFileManager()
conversation = Conversation()
Expand Down
11 changes: 6 additions & 5 deletions tests/diff_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def test_diff_context_default(temp_testbed, git_history, mock_session_context):

# DiffContext.__init__() (default): active code vs last commit
diff_context = DiffContext(
mock_session_context.stream, mock_session_context.code_context.git_root
mock_session_context.stream,
temp_testbed,
)
assert diff_context.target == "HEAD"
assert diff_context.name == "HEAD (last commit)"
Expand Down Expand Up @@ -99,7 +100,7 @@ async def test_diff_context_commit(temp_testbed, git_history, mock_session_conte
).strip()
diff_context = DiffContext(
mock_session_context.stream,
mock_session_context.code_context.git_root,
temp_testbed,
diff=last_commit,
)
assert diff_context.target == last_commit
Expand All @@ -119,7 +120,7 @@ async def test_diff_context_commit(temp_testbed, git_history, mock_session_conte
async def test_diff_context_branch(temp_testbed, git_history, mock_session_context):
diff_context = DiffContext(
mock_session_context.stream,
mock_session_context.code_context.git_root,
temp_testbed,
diff="test_branch",
)
abs_path = Path(temp_testbed) / "multifile_calculator" / "operations.py"
Expand All @@ -142,7 +143,7 @@ async def test_diff_context_branch(temp_testbed, git_history, mock_session_conte
async def test_diff_context_relative(temp_testbed, git_history, mock_session_context):
diff_context = DiffContext(
mock_session_context.stream,
mock_session_context.code_context.git_root,
temp_testbed,
diff="HEAD~2",
)
abs_path = Path(temp_testbed) / "multifile_calculator" / "operations.py"
Expand All @@ -168,7 +169,7 @@ async def test_diff_context_pr(temp_testbed, git_history, mock_session_context):
subprocess.run(["git", "checkout", "test_branch"], cwd=temp_testbed)
diff_context = DiffContext(
mock_session_context.stream,
mock_session_context.code_context.git_root,
temp_testbed,
pr_diff="master",
)

Expand Down

0 comments on commit 6219d49

Please sign in to comment.