Skip to content

Commit

Permalink
@@context_pruner(location: src/autocoder/common/context_pruner.py) 中,…
Browse files Browse the repository at this point in the history
… 优化文件处理逻辑,确保每个文件的 token 计数正确传递。在 notebooks/test_context_prune_v2.py 中添加 file_path_to_source_code 函数,将文件路径转换为 SourceCode 对象,并启用滑动窗口测试。

auto_coder_000000001850_chat_action.yml_ccb129cf78b3bed77e1e0df53ca04a3a
  • Loading branch information
allwefantasy committed Mar 5, 2025
1 parent 2363ea3 commit 62c744a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
23 changes: 17 additions & 6 deletions notebooks/test_context_prune_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pkg_resources
import os
from autocoder.common import SourceCode
from autocoder.rag.token_counter import count_tokens

try:
tokenizer_path = pkg_resources.resource_filename(
Expand Down Expand Up @@ -83,6 +84,11 @@ def factorial(self, n: int) -> int:
f.write(content)
return os.path.abspath(file_path)

def file_path_to_source_code(file_path: str) -> SourceCode:
    """Load a file from disk and wrap it as a SourceCode object.

    The file is read as UTF-8 and a token count is computed up front via
    count_tokens, so downstream consumers (e.g. the context pruner) can
    rely on the `tokens` field instead of re-counting.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        text = handle.read()
    token_total = count_tokens(text)
    return SourceCode(module_name=file_path, source_code=text, tokens=token_total)

def test_extract_large_file():
"""测试从大文件中抽取相关代码片段的功能"""
# 创建测试文件
Expand Down Expand Up @@ -118,7 +124,7 @@ def test_extract_large_file():

# 处理文件
pruned_files = context_pruner.handle_overflow(
[large_file_path],
[file_path_to_source_code(large_file_path)],
conversations,
args.context_prune_strategy
)
Expand Down Expand Up @@ -149,14 +155,19 @@ def operation_{i}(self, a: int, b: int) -> int:
'''Operation {i}'''
result = a + b * {i}
return result
"""
content += """
def add(self, a: int, b: int) -> int:
'''加法'''
return a + b
"""

with open(file_path, "w", encoding="utf-8") as f:
f.write(content)

file_path = os.path.abspath(file_path)

file_source = file_path_to_source_code(file_path)
# 配置参数
args = AutoCoderArgs(
source_dir=".",
Expand Down Expand Up @@ -185,7 +196,7 @@ def operation_{i}(self, a: int, b: int) -> int:

# 处理文件
pruned_files = context_pruner.handle_overflow(
[file_path],
[file_source],
conversations,
args.context_prune_strategy
)
Expand All @@ -207,5 +218,5 @@ def operation_{i}(self, a: int, b: int) -> int:
print("测试1: 提取相关代码片段")
test_extract_large_file()

# print("\n测试2: 滑动窗口处理超大文件")
# test_extract_with_sliding_window()
print("\n测试2: 滑动窗口处理超大文件")
test_extract_with_sliding_window()
11 changes: 7 additions & 4 deletions src/autocoder/common/context_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,12 @@ def extract_code_snippets(conversations: List[Dict[str, str]], content: str, is_
if tokens > self.max_tokens:
self.printer.print_in_terminal(
"file_sliding_window_processing", file_path=file_source.module_name, tokens=tokens)

chunks = self._split_content_with_sliding_window(file_source.source_code,
self.args.context_prune_sliding_window_size,
self.args.context_prune_sliding_window_overlap)
all_snippets = []
for chunk_start, chunk_end, chunk_content in chunks:
for chunk_start, chunk_end, chunk_content in chunks:
extracted = extract_code_snippets.with_llm(self.llm).run(
conversations=conversations,
content=chunk_content,
Expand Down Expand Up @@ -379,13 +380,15 @@ def _count_tokens(self, file_sources: List[SourceCode]) -> int:
for file_source in file_sources:
try:
if file_source.tokens > 0:
tokens = file_source.tokens
total_tokens += file_source.tokens
else:
tokens = count_tokens(file_source.source_code)
sources.append(SourceCode(module_name=file_source.module_name,
source_code=file_source.source_code, tokens=tokens))
tokens = count_tokens(file_source.source_code)
total_tokens += tokens

sources.append(SourceCode(module_name=file_source.module_name,
source_code=file_source.source_code, tokens=tokens))

except Exception as e:
logger.error(f"Failed to count tokens for {file_source.module_name}: {e}")
sources.append(SourceCode(module_name=file_source.module_name,
Expand Down

0 comments on commit 62c744a

Please sign in to comment.