diff --git a/mentat/code_context.py b/mentat/code_context.py index 812440f9e..d675123fa 100644 --- a/mentat/code_context.py +++ b/mentat/code_context.py @@ -409,6 +409,8 @@ async def search( all_nodes_sorted = self.daemon.search(query, max_results) all_features_sorted = list[tuple[CodeFeature, float]]() for node in all_nodes_sorted: + if node.get("type") not in {"file", "chunk"}: + continue distance = node["distance"] path, interval = split_intervals_from_path(Path(node["ref"])) intervals = parse_intervals(interval) diff --git a/mentat/code_feature.py b/mentat/code_feature.py index 8de5c9d03..ef760f981 100644 --- a/mentat/code_feature.py +++ b/mentat/code_feature.py @@ -1,16 +1,13 @@ from __future__ import annotations -import asyncio -import logging -from collections import OrderedDict, defaultdict +from collections import defaultdict from pathlib import Path from typing import Optional import attr +from ragdaemon.utils import get_document -from mentat.diff_context import annotate_file_message, parse_diff from mentat.errors import MentatError -from mentat.git_handler import get_diff_for_file from mentat.interval import INTERVAL_FILE_END, Interval from mentat.llm_api_handler import count_tokens from mentat.session_context import SESSION_CONTEXT @@ -62,118 +59,12 @@ def interval_string(self) -> str: def __str__(self, cwd: Optional[Path] = None) -> str: return self.rel_path(cwd) + self.interval_string() - def get_code_message(self, standalone: bool = True) -> list[str]: - """ - Gets this code features code message. - If standalone is true, will include the filename at top and extra newline at the end. - If feature contains entire file, will add inline diff annotations; otherwise, will append them to the end. - """ - if not self.path.exists() or self.path.is_dir(): - return [] - - session_context = SESSION_CONTEXT.get() - code_file_manager = session_context.code_file_manager - parser = session_context.config.parser - code_context = session_context.code_context - - code_message: list[str] = [] - - if standalone: - # We always want to give GPT posix paths - code_message_path = get_relative_path(self.path, session_context.cwd) - code_message.append(str(code_message_path.as_posix())) - - # Get file lines - file_lines = code_file_manager.read_file(self.path) - for i, line in enumerate(file_lines): - if self.interval.contains(i + 1): - if parser.provide_line_numbers(): - code_message.append(f"{i + parser.line_number_starting_index()}:{line}") - else: - code_message.append(f"{line}") - - if standalone: - code_message.append("") - - if self.path in code_context.diff_context.diff_files(): - diff = get_diff_for_file(code_context.diff_context.target, self.path) - diff_annotations = parse_diff(diff) - if self.interval.whole_file(): - code_message = annotate_file_message(code_message, diff_annotations) - else: - for section in diff_annotations: - # TODO: Place diff_annotations inside interval where they belong - if section.start >= self.interval.start and section.start < self.interval.end: - code_message += section.message - return code_message - - def get_checksum(self) -> str: - # TODO: Only update checksum if last modified time of file updates to conserve file system reads - session_context = SESSION_CONTEXT.get() - code_file_manager = session_context.code_file_manager - - return code_file_manager.get_file_checksum(self.path, self.interval) - - def count_tokens(self, model: str) -> int: - code_message = self.get_code_message() - return count_tokens("\n".join(code_message), model, full_message=False) - - -async def count_feature_tokens(features: list[CodeFeature], model: str) -> list[int]: - """Return the number of tokens in each feature.""" - sem = asyncio.Semaphore(10) - - feature_tokens = list[int]() - for feature in features: - async with sem: - tokens = feature.count_tokens(model) - feature_tokens.append(tokens) - return feature_tokens - - -def _get_code_message_from_intervals(features: list[CodeFeature]) -> list[str]: - """ - Merge multiple features for the same file into a single code message. - """ - features_sorted = sorted(features, key=lambda f: f.interval) - posix_path = features_sorted[0].get_code_message()[0] - code_message = [posix_path] - next_line = 1 - for feature in features_sorted: - starting_line = feature.interval.start - if starting_line < next_line: - logging.info(f"Features overlap: {feature}") - if feature.interval.end <= next_line: - continue - feature = CodeFeature( - feature.path, - interval=Interval(next_line, feature.interval.end), - name=feature.name, - ) - elif starting_line > next_line: - code_message += ["..."] - code_message += feature.get_code_message(standalone=False) - next_line = feature.interval.end - return code_message + [""] - - -def get_code_message_from_features(features: list[CodeFeature]) -> list[str]: - """ - Generate a code message from a list of features. - Will automatically handle overlapping intervals. - """ - code_message = list[str]() - features_by_path: dict[Path, list[CodeFeature]] = OrderedDict() - for feature in features: - if feature.path not in features_by_path: - features_by_path[feature.path] = list[CodeFeature]() - features_by_path[feature.path].append(feature) - for path_features in features_by_path.values(): - if len(path_features) == 1: - code_message += path_features[0].get_code_message() - else: - code_message += _get_code_message_from_intervals(path_features) - return code_message + +def count_feature_tokens(feature: CodeFeature, model: str) -> int: + cwd = SESSION_CONTEXT.get().cwd + ref = feature.__str__(cwd) + document = get_document(ref, cwd) + return count_tokens(document, model, full_message=False) def get_consolidated_feature_refs(features: list[CodeFeature]) -> list[str]: diff --git a/mentat/command/commands/search.py b/mentat/command/commands/search.py index d6406ff2c..b804365b6 100644 --- a/mentat/command/commands/search.py +++ b/mentat/command/commands/search.py @@ -2,6 +2,7 @@ from typing_extensions import override +from mentat.code_feature import count_feature_tokens from mentat.command.command import Command, CommandArgument from mentat.errors import UserError from mentat.session_context import SESSION_CONTEXT @@ -60,7 +61,7 @@ async def apply(self, *args: str) -> None: file_interval = feature.interval_string() stream.send(file_interval, color="bright_cyan", end="") - tokens = feature.count_tokens(config.model) + tokens = count_feature_tokens(feature, config.model) cumulative_tokens += tokens tokens_str = f" ({tokens} tokens)" stream.send(tokens_str, color="yellow") diff --git a/mentat/diff_context.py b/mentat/diff_context.py index a8f510f84..96af9f609 100644 --- a/mentat/diff_context.py +++ b/mentat/diff_context.py @@ -2,9 +2,6 @@ from pathlib import Path from typing import List, Literal, Optional -import attr - -from mentat.errors import MentatError from mentat.git_handler import ( check_head_exists, get_diff_for_file, @@ -13,83 +10,10 @@ get_treeish_metadata, get_untracked_files, ) -from mentat.interval import Interval from mentat.session_context import SESSION_CONTEXT from mentat.session_stream import SessionStream -@attr.define(frozen=True) -class DiffAnnotation(Interval): - start: int | float = attr.field() - message: List[str] = attr.field() - end: int | float = attr.field( - default=attr.Factory( - lambda self: self.start + sum(bool(line.startswith("-")) for line in self.message), - takes_self=True, - ) - ) - - -def parse_diff(diff: str) -> list[DiffAnnotation]: - """Parse diff into a list of annotations.""" - annotations: list[DiffAnnotation] = [] - active_annotation: Optional[DiffAnnotation] = None - lines = diff.splitlines() - for line in lines: - if line.startswith(("---", "+++", "//", "diff", "index")): - continue - elif line.startswith("@@"): - if active_annotation: - annotations.append(active_annotation) - _new_index = line.split(" ")[2] - if "," in _new_index: - new_start = _new_index[1:].split(",")[0] - else: - new_start = _new_index[1:] - active_annotation = DiffAnnotation(start=int(new_start), message=[]) - elif line.startswith(("+", "-")): - if not active_annotation: - raise MentatError("Invalid diff") - active_annotation.message.append(line) - if active_annotation: - annotations.append(active_annotation) - annotations.sort(key=lambda a: a.start) - return annotations - - -def annotate_file_message(code_message: list[str], annotations: list[DiffAnnotation]) -> list[str]: - """Return the code_message with annotations inserted.""" - active_index = 0 - annotated_message: list[str] = [] - for annotation in annotations: - # Fill-in lines between annotations - if active_index < annotation.start: - unaffected_lines = code_message[active_index : annotation.start] - annotated_message += unaffected_lines - active_index = annotation.start - if annotation.start == 0: - # Make sure the PATH stays on line 1 - annotated_message.append(code_message[0]) - active_index += 1 - i_minus = None - for line in annotation.message: - sign = line[0] - if sign == "+": - # Add '+' lines in place of code_message lines - annotated_message.append(f"{active_index}:{line}") - active_index += 1 - i_minus = None - elif sign == "-": - # Insert '-' lines at the point they were removed - i_minus = 0 if i_minus is None else i_minus - annotated_message.append(f"{annotation.start + i_minus}:{line}") - i_minus += 1 - if active_index < len(code_message): - annotated_message += code_message[active_index:] - - return annotated_message - - class DiffContext: target: str = "" name: str = "index (last commit)" @@ -184,12 +108,6 @@ def refresh(self): self._diff_files = [(ctx.cwd / f).resolve() for f in get_files_in_diff(self.target)] self._untracked_files = [(ctx.cwd / f).resolve() for f in get_untracked_files(ctx.cwd)] - def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]: - if not self.git_root: - return [] - diff = get_diff_for_file(self.target, rel_path) - return parse_diff(diff) - def get_display_context(self) -> Optional[str]: if not self.git_root: return None @@ -204,13 +122,6 @@ def get_display_context(self) -> Optional[str]: num_lines += len([line for line in diff_lines if line.startswith(("+ ", "- "))]) return f" {self.name} | {num_files} files | {num_lines} lines" - def annotate_file_message(self, rel_path: Path, file_message: list[str]) -> list[str]: - """Return file_message annotated with active diff.""" - if not self.git_root: - return [] - annotations = self.get_annotations(rel_path) - return annotate_file_message(file_message, annotations) - TreeishType = Literal["commit", "branch", "relative", "compare"] diff --git a/scripts/sampler/__main__.py b/scripts/sampler/__main__.py index 601b7b07d..39b22dca8 100644 --- a/scripts/sampler/__main__.py +++ b/scripts/sampler/__main__.py @@ -10,9 +10,7 @@ from pathlib import Path from typing import Any -from add_context import add_context from finetune import generate_finetune -from remove_context import remove_context from validate import validate_sample from mentat.llm_api_handler import count_tokens, prompt_tokens @@ -50,13 +48,6 @@ async def main(): help="Validate samples conform to spec", ) parser.add_argument("--finetune", "-f", action="store_true", help="Generate fine-tuning examples") - parser.add_argument("--add-context", "-a", action="store_true", help="Add extra context to samples") - parser.add_argument( - "--remove-context", - "-r", - action="store_true", - help="Remove context from samples", - ) args = parser.parse_args() sample_files = [] if args.sample_ids: @@ -81,11 +72,6 @@ async def main(): except Exception as e: warn(f"Error loading sample {sample_file}: {e}") continue - if (args.add_context or args.remove_context) and ( - "[ADDED CONTEXT]" in sample.title or "[REMOVED CONTEXT]" in sample.title - ): - warn(f"Skipping {sample.id[:8]}: has already been modified.") - continue if args.validate: is_valid, reason = await validate_sample(sample) status = "\033[92mPASSED\033[0m" if is_valid else f"\033[91mFAILED: {reason}\033[0m" @@ -104,26 +90,6 @@ async def main(): logs.append(example) except Exception as e: warn(f"Error generating finetune example for sample {sample.id[:8]}: {e}") - elif args.add_context: - try: - new_sample = await add_context(sample) - sample_file = SAMPLES_DIR / f"sample_{new_sample.id}.json" - new_sample.save(sample_file) - print(f"Generated new sample with extra context: {sample_file}") - logs.append({"id": new_sample.id, "prototype_id": sample.id}) - except Exception as e: - warn(f"Error adding extra context to sample {sample.id[:8]}: {e}") - elif args.remove_context: - if not sample.context or len(sample.context) == 1: - warn(f"Skipping {sample.id[:8]}: no context to remove.") - continue - try: - new_sample = await remove_context(sample) - new_sample.save(SAMPLES_DIR / f"sample_{new_sample.id}.json") - print(f"Generated new sample with context removed: {sample_file}") - logs.append({"id": new_sample.id, "prototype_id": sample.id}) - except Exception as e: - warn(f"Error removing context from sample {sample.id[:8]}: {e}") else: print(f"Running sample {sample.id[:8]}") print(f" Prompt: {sample.message_prompt}") @@ -161,10 +127,6 @@ async def main(): del log["tokens"] f.write(json.dumps(log) + "\n") print(f"{len(logs)} fine-tuning examples ({tokens} tokens) saved to {fname}.") - elif args.add_context: - print(f"{len(logs)} samples with extra context generated.") - elif args.remove_context: - print(f"{len(logs)} samples with context removed generated.") if __name__ == "__main__": diff --git a/scripts/sampler/add_context.py b/scripts/sampler/add_context.py deleted file mode 100644 index ba49ad2fa..000000000 --- a/scripts/sampler/add_context.py +++ /dev/null @@ -1,41 +0,0 @@ -from pathlib import Path -from uuid import uuid4 - -import attr - -from mentat.code_feature import get_consolidated_feature_refs -from mentat.python_client.client import PythonClient -from mentat.sampler.sample import Sample -from mentat.sampler.utils import setup_repo -from mentat.session_context import SESSION_CONTEXT - - -async def add_context(sample, extra_tokens: int = 5000) -> Sample: - """Return a duplicate sample with extra (auto-context generated) context.""" - # Setup mentat CodeContext with included_files - repo = setup_repo( - url=sample.repo, - commit=sample.merge_base, - diff_merge_base=sample.diff_merge_base, - diff_active=sample.diff_active, - ) - cwd = Path(repo.working_dir) - paths = list[Path]() - for a in sample.context: - paths.append(Path(a)) - python_client = PythonClient(cwd=cwd, paths=paths) - await python_client.startup() - - # Use auto-context to add extra tokens, then copy the resulting features - ctx = SESSION_CONTEXT.get() - ctx.config.auto_context_tokens = extra_tokens - _ = await ctx.code_context.get_code_message(prompt_tokens=0, prompt=sample.message_prompt) - included_features = list(f for fs in ctx.code_context.include_files.values() for f in fs) - all_features = get_consolidated_feature_refs(included_features) - await python_client.shutdown() - - new_sample = Sample(**attr.asdict(sample)) - new_sample.context = [str(f) for f in all_features] - new_sample.id = uuid4().hex - new_sample.title = f"{sample.title} [ADDED CONTEXT]" - return new_sample diff --git a/scripts/sampler/remove_context.py b/scripts/sampler/remove_context.py deleted file mode 100644 index 7df241217..000000000 --- a/scripts/sampler/remove_context.py +++ /dev/null @@ -1,132 +0,0 @@ -import random -from pathlib import Path -from textwrap import dedent -from uuid import uuid4 - -import attr -from openai.types.chat import ( - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, -) - -from mentat.code_feature import CodeFeature, get_code_message_from_features -from mentat.errors import SampleError -from mentat.python_client.client import PythonClient -from mentat.sampler.sample import Sample -from mentat.sampler.utils import setup_repo - - -async def remove_context(sample) -> Sample: - """Return a duplicate sample with one context item removed and a warning message""" - - # Setup the repo and load context files - repo = setup_repo( - url=sample.repo, - commit=sample.merge_base, - diff_merge_base=sample.diff_merge_base, - diff_active=sample.diff_active, - ) - cwd = Path(repo.working_dir) - python_client = PythonClient(cwd=Path("."), paths=[]) - await python_client.startup() - - context = [CodeFeature(cwd / p) for p in sample.context] - i_target = random.randint(0, len(context) - 1) - target = context[i_target] - print("-" * 80) - print("Prompt\n", sample.message_prompt) - print("Context\n", sample.context) - print("Removed:", target) - print("") - - # Build conversation: [rejection_prompt, message_prompt, keep_context, remove_context] - target_context = target.get_code_message(standalone=False) - background_features = context[:i_target] + context[i_target + 1 :] - background_context = "\n".join(get_code_message_from_features(background_features)) - messages = [ - ChatCompletionSystemMessageParam( - role="system", - content=dedent( - """\ - You are part of an LLM Coding Assistant, designed to answer questions and - complete tasks for developers. Specifically, you generate examples of - interactions where the user has not provided enough context to fulfill the - query. You will be shown an example query, some background code which will - be included, and some target code which is NOT be included. - - Pretend you haven't seen the target code, and tell the user what additional - information you'll need in order to fulfill the task. Take a deep breath, - focus, and then complete your task by following this procedure: - - 1. Read the USER QUERY (below) carefully. Consider the steps involved in - completing it. - 2. Read the BACKROUND CONTEXT (below that) carefully. Consider how it - contributes to completing the task. - 3. Read the TARGET CONTEXT (below that) carefully. Consider how it - contributes to completing the task. - 4. Think of a short (1-sentence) explanation of why the TARGET CONTEXT is - required to complete the task. - 5. Return a ~1 paragraph message to the user explaining why the BACKGROUND - CONTEXT is not sufficient to answer the question. - - REMEMBER: - * Don't reference TARGET CONTEXT specifically. Answer as if you've never - seen it, you just know you're missing something essential. - * Return #5 (your response to the user) as a single sentence, without - preamble, notes, extra spacing or additional commentary. - - EXAMPLE - ============= - USER QUERY: "Can you make it so that I can write questions/answers in a - list at the top of the file, and then use that list to populate the - component." - BACKGROUND_CONTEXT: "" - TARGET_CONTEXT: - RESPONSE: "No code files have been included. In order to make the - requested changes, I need to see the context related to \"writing - questions/answers\" and \"populating the component\"." - """ - ), - ), - ChatCompletionUserMessageParam(role="user", content=f"USER QUERY:\n{sample.message_prompt}"), - ChatCompletionSystemMessageParam( - role="system", - content=f"BACKGROUND CONTEXT:\n{background_context}", - ), - ChatCompletionSystemMessageParam( - role="system", - content=f"TARGET CONTEXT:\n{target_context}", - ), - ] - - # Ask gpt-4 to generate rejection prompt - llm_api_handler = python_client.session.ctx.llm_api_handler - llm_api_handler.initialize_client() - llm_response = await llm_api_handler.call_llm_api( - messages=messages, - model=python_client.session.ctx.config.model, - stream=False, - ) - message = (llm_response.choices[0].message.content) or "" - await python_client.shutdown() - - # Ask user to review and accept/reject - print("Generated reason:", message) - print("Press ENTER to accept, 's' to skip this sample, or type a new reason to" " reject.") - response = input() - if response: - if response.lower() == "s": - raise SampleError("Skipping sample.") - message = response - if not message: - raise SampleError("No rejection reason provided. Aborting.") - - # Create and return a duplicate/udpated sample - new_sample = Sample(**attr.asdict(sample)) - new_sample.context = [str(f) for f in background_features] - new_sample.id = uuid4().hex - new_sample.title = f"{sample.title} [REMOVED CONTEXT]" - new_sample.message_edit = message - new_sample.diff_edit = "" - - return new_sample diff --git a/tests/diff_context_test.py b/tests/diff_context_test.py index 0f070f1dc..21e04ef9a 100644 --- a/tests/diff_context_test.py +++ b/tests/diff_context_test.py @@ -76,15 +76,6 @@ def test_diff_context_default(temp_testbed, git_history, mock_session_context): diff_context._diff_files = None # This is usually cached assert diff_context.diff_files() == [abs_path] - # DiffContext.annotate_file_message(): modify file_message with diff - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return commit3", - "14:+ return commit5", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_commit(temp_testbed, git_history, mock_session_context): @@ -101,14 +92,6 @@ async def test_diff_context_commit(temp_testbed, git_history, mock_session_conte assert diff_context.name == f"{last_commit[:8]}: add testbed" assert diff_context.diff_files() == [abs_path] - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return a / b", - "14:+ return commit3", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_branch(temp_testbed, git_history, mock_session_context): @@ -124,14 +107,6 @@ async def test_diff_context_branch(temp_testbed, git_history, mock_session_conte assert diff_context.name.endswith(": commit4") assert diff_context.diff_files() == [abs_path] - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return commit4", - "14:+ return commit3", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_relative(temp_testbed, git_history, mock_session_context): @@ -147,13 +122,6 @@ async def test_diff_context_relative(temp_testbed, git_history, mock_session_con assert diff_context.name.endswith(": add testbed") assert diff_context.diff_files() == [abs_path] - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return a / b", - "14:+ return commit3", - ] - assert annotated_message == expected @pytest.mark.asyncio