-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
auto_coder_000000001528_chat_action.yml_1fd4b2b8b7763f98b15d9613a01d2b83
根据历史对话,完成ModelPathFilter的实现.该类放在./src/autocoder/privacy 目录下
- Loading branch information
1 parent
a876b0f
commit 6405d0e
Showing
3 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .model_filter import ModelPathFilter | ||
|
||
__all__ = ["ModelPathFilter"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import re | ||
import yaml | ||
from pathlib import Path | ||
from typing import Dict, List, Optional | ||
from loguru import logger | ||
|
||
|
||
class ModelPathFilter: | ||
def __init__(self, | ||
model_name: str, | ||
config_path: str = "model_filters.yml", | ||
default_forbidden: List[str] = None): | ||
""" | ||
模型路径过滤器 | ||
:param model_name: 当前使用的模型名称 | ||
:param config_path: 过滤规则配置文件路径 | ||
:param default_forbidden: 默认禁止路径规则 | ||
""" | ||
self.model_name = model_name | ||
self.config_path = Path(config_path) | ||
self.default_forbidden = default_forbidden or [] | ||
self._rules_cache: Dict[str, List[re.Pattern]] = {} | ||
self._load_rules() | ||
|
||
def _load_rules(self): | ||
"""加载并编译正则规则""" | ||
if not self.config_path.exists(): | ||
logger.warning(f"Filter config {self.config_path} not found") | ||
return | ||
|
||
with open(self.config_path, 'r') as f: | ||
config = yaml.safe_load(f) | ||
|
||
model_rules = config.get('model_filters', {}).get(self.model_name, {}) | ||
all_rules = model_rules.get('forbidden_paths', []) + self.default_forbidden | ||
|
||
# 预编译正则表达式 | ||
self._rules_cache[self.model_name] = [ | ||
re.compile(rule) for rule in all_rules | ||
] | ||
|
||
def is_accessible(self, file_path: str) -> bool: | ||
""" | ||
检查文件路径是否符合访问规则 | ||
:return: True表示允许访问,False表示禁止 | ||
""" | ||
# 优先使用模型专属规则 | ||
patterns = self._rules_cache.get(self.model_name, []) | ||
|
||
# 回退到默认规则 | ||
if not patterns and self.default_forbidden: | ||
patterns = [re.compile(rule) for rule in self.default_forbidden] | ||
|
||
# 如果路径为空或None,直接返回True | ||
if not file_path: | ||
return True | ||
|
||
return not any(pattern.search(file_path) for pattern in patterns) | ||
|
||
def add_temp_rule(self, rule: str): | ||
""" | ||
添加临时规则 | ||
:param rule: 正则表达式规则 | ||
""" | ||
patterns = self._rules_cache.get(self.model_name, []) | ||
patterns.append(re.compile(rule)) | ||
self._rules_cache[self.model_name] = patterns | ||
|
||
def reload_rules(self): | ||
"""重新加载规则配置""" | ||
self._rules_cache.clear() | ||
self._load_rules() | ||
|
||
@classmethod | ||
def from_model_object(cls, | ||
llm_obj, | ||
config_path: Optional[str] = None, | ||
default_forbidden: Optional[List[str]] = None): | ||
""" | ||
从LLM对象创建过滤器 | ||
:param llm_obj: ByzerLLM实例或类似对象 | ||
:param config_path: 可选的自定义配置文件路径 | ||
:param default_forbidden: 默认禁止路径规则 | ||
""" | ||
model_name = getattr(llm_obj, 'default_model_name', None) | ||
if not model_name: | ||
model_name = "unknown(without default model name)" | ||
|
||
return cls( | ||
model_name=model_name, | ||
config_path=config_path or "model_filters.yml", | ||
default_forbidden=default_forbidden | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import pytest | ||
from pathlib import Path | ||
from autocoder.privacy import ModelPathFilter | ||
|
||
|
||
def test_model_filter_basic(tmp_path): | ||
# 创建临时配置文件 | ||
config_content = """ | ||
model_filters: | ||
test_model: | ||
forbidden_paths: | ||
- "^src/autocoder/index/.*" | ||
- "^tests/.*" | ||
""" | ||
config_file = tmp_path / "test_filters.yml" | ||
config_file.write_text(config_content) | ||
|
||
# 创建过滤器实例 | ||
filter = ModelPathFilter("test_model", str(config_file)) | ||
|
||
# 测试路径检查 | ||
assert filter.is_accessible("src/main.py") is True | ||
assert filter.is_accessible("src/autocoder/index/core.py") is False | ||
assert filter.is_accessible("tests/test_index.py") is False | ||
|
||
|
||
def test_model_filter_empty_config(tmp_path): | ||
# 测试空配置文件 | ||
config_file = tmp_path / "empty_filters.yml" | ||
config_file.write_text("") | ||
|
||
filter = ModelPathFilter("test_model", str(config_file)) | ||
assert filter.is_accessible("any/path.py") is True | ||
|
||
|
||
def test_model_filter_default_rules(): | ||
# 测试默认规则 | ||
default_rules = ["^config/.*", "\\.env$"] | ||
filter = ModelPathFilter( | ||
"test_model", | ||
config_path="non_existent.yml", | ||
default_forbidden=default_rules | ||
) | ||
|
||
assert filter.is_accessible("src/main.py") is True | ||
assert filter.is_accessible("config/settings.py") is False | ||
assert filter.is_accessible(".env") is False | ||
|
||
|
||
def test_model_filter_add_temp_rule(): | ||
# 测试添加临时规则 | ||
filter = ModelPathFilter("test_model", "non_existent.yml") | ||
|
||
# 初始状态应该允许访问 | ||
assert filter.is_accessible("temp/file.py") is True | ||
|
||
# 添加临时规则后应该禁止访问 | ||
filter.add_temp_rule("^temp/.*") | ||
assert filter.is_accessible("temp/file.py") is False | ||
|
||
|
||
def test_model_filter_from_model_object(): | ||
# 模拟LLM对象 | ||
class MockLLM: | ||
default_model_name = "mock_model" | ||
|
||
llm = MockLLM() | ||
filter = ModelPathFilter.from_model_object(llm) | ||
assert filter.model_name == "mock_model" | ||
|
||
|
||
def test_model_filter_reload_rules(tmp_path): | ||
# 测试规则重新加载 | ||
config_file = tmp_path / "reload_filters.yml" | ||
|
||
# 初始配置 | ||
initial_config = """ | ||
model_filters: | ||
test_model: | ||
forbidden_paths: | ||
- "^src/.*" | ||
""" | ||
config_file.write_text(initial_config) | ||
|
||
filter = ModelPathFilter("test_model", str(config_file)) | ||
assert filter.is_accessible("src/main.py") is False | ||
|
||
# 更新配置 | ||
new_config = """ | ||
model_filters: | ||
test_model: | ||
forbidden_paths: | ||
- "^tests/.*" | ||
""" | ||
config_file.write_text(new_config) | ||
|
||
# 重新加载规则 | ||
filter.reload_rules() | ||
assert filter.is_accessible("src/main.py") is True | ||
assert filter.is_accessible("tests/test.py") is False | ||
|
||
|
||
def test_model_filter_empty_path(): | ||
# 测试空路径处理 | ||
filter = ModelPathFilter("test_model", "non_existent.yml") | ||
assert filter.is_accessible("") is True | ||
assert filter.is_accessible(None) is True |