diff --git a/src/autocoder/privacy/__init__.py b/src/autocoder/privacy/__init__.py
new file mode 100644
index 000000000..e07488555
--- /dev/null
+++ b/src/autocoder/privacy/__init__.py
@@ -0,0 +1,3 @@
+from .model_filter import ModelPathFilter
+
+__all__ = ["ModelPathFilter"]
\ No newline at end of file
diff --git a/src/autocoder/privacy/model_filter.py b/src/autocoder/privacy/model_filter.py
new file mode 100644
index 000000000..368882075
--- /dev/null
+++ b/src/autocoder/privacy/model_filter.py
@@ -0,0 +1,117 @@
+import re
+import yaml
+from pathlib import Path
+from typing import Dict, List, Optional
+from loguru import logger
+
+
+class ModelPathFilter:
+    """Regex deny-list that decides which file paths a model may access.
+
+    Rules come from two places: the ``forbidden_paths`` list stored under
+    ``model_filters.<model_name>`` in a YAML config file, and the
+    ``default_forbidden`` list given to the constructor. A path is
+    accessible only when none of the compiled rules matches it.
+    """
+
+    def __init__(self,
+                 model_name: str,
+                 config_path: str = "model_filters.yml",
+                 default_forbidden: Optional[List[str]] = None):
+        """
+        Build a filter for one model and load its rules immediately.
+
+        :param model_name: name of the model currently in use
+        :param config_path: path of the YAML file holding the filter rules
+        :param default_forbidden: regex rules that are always forbidden,
+            whether or not the config file exists
+        """
+        self.model_name = model_name
+        self.config_path = Path(config_path)
+        self.default_forbidden = list(default_forbidden or [])
+        self._rules_cache: Dict[str, List[re.Pattern]] = {}
+        self._load_rules()
+
+    def _load_rules(self):
+        """Read the config file (if present) and pre-compile every rule.
+
+        The default rules are always compiled into the cache, so they stay
+        in force when the config file is missing or empty, and after
+        temporary rules have been added.
+
+        :raises re.error: if a rule is not a valid regular expression
+        """
+        configured: List[str] = []
+        if self.config_path.exists():
+            with open(self.config_path, 'r', encoding='utf-8') as f:
+                # An empty YAML document parses to None, not to a dict
+                config = yaml.safe_load(f) or {}
+            # Keys that are present but left blank also come back as None
+            model_filters = config.get('model_filters') or {}
+            model_rules = model_filters.get(self.model_name) or {}
+            configured = model_rules.get('forbidden_paths') or []
+        else:
+            logger.warning(f"Filter config {self.config_path} not found")
+
+        all_rules = list(configured) + list(self.default_forbidden)
+
+        # Pre-compile the regular expressions
+        self._rules_cache[self.model_name] = [
+            re.compile(rule) for rule in all_rules
+        ]
+
+    def is_accessible(self, file_path: Optional[str]) -> bool:
+        """
+        Check whether a file path may be accessed by the current model.
+
+        :param file_path: path to check; rules are applied with
+            ``re.search``, so they match anywhere unless anchored
+        :return: True if access is allowed, False if it is forbidden
+        """
+        # An empty or None path is always allowed
+        if not file_path:
+            return True
+
+        patterns = self._rules_cache.get(self.model_name, [])
+        return not any(pattern.search(file_path) for pattern in patterns)
+
+    def add_temp_rule(self, rule: str):
+        """
+        Add a temporary forbidden rule on top of the loaded ones.
+
+        The rule lives only in the in-memory cache and is discarded by
+        ``reload_rules``.
+
+        :param rule: regular expression to forbid
+        :raises re.error: if ``rule`` is not a valid regular expression
+        """
+        compiled = re.compile(rule)
+        self._rules_cache.setdefault(self.model_name, []).append(compiled)
+
+    def reload_rules(self):
+        """Drop all cached rules, temporary ones included, and reload."""
+        self._rules_cache.clear()
+        self._load_rules()
+
+    @classmethod
+    def from_model_object(cls,
+                          llm_obj,
+                          config_path: Optional[str] = None,
+                          default_forbidden: Optional[List[str]] = None):
+        """
+        Create a filter from an LLM object.
+
+        :param llm_obj: object whose ``default_model_name`` attribute names
+            the model; a placeholder name is used if it is missing or empty
+        :param config_path: optional custom config file path
+        :param default_forbidden: regex rules that are always forbidden
+        """
+        model_name = getattr(llm_obj, 'default_model_name', None)
+        if not model_name:
+            model_name = "unknown(without default model name)"
+
+        return cls(
+            model_name=model_name,
+            config_path=config_path or "model_filters.yml",
+            default_forbidden=default_forbidden
+        )
diff --git a/tests/test_privacy.py b/tests/test_privacy.py
new file mode 100644
index 000000000..770fc8921
--- /dev/null
+++ b/tests/test_privacy.py
@@ -0,0 +1,120 @@
+import pytest
+from pathlib import Path
+from autocoder.privacy import ModelPathFilter
+
+
+def test_model_filter_basic(tmp_path):
+    # Create a temporary config file
+    config_content = """
+model_filters:
+  test_model:
+    forbidden_paths:
+      - "^src/autocoder/index/.*"
+      - "^tests/.*"
+    """
+    config_file = tmp_path / "test_filters.yml"
+    config_file.write_text(config_content)
+
+    # Create the filter instance
+    path_filter = ModelPathFilter("test_model", str(config_file))
+
+    # Check individual paths
+    assert path_filter.is_accessible("src/main.py") is True
+    assert path_filter.is_accessible("src/autocoder/index/core.py") is False
+    assert path_filter.is_accessible("tests/test_index.py") is False
+
+
+def test_model_filter_empty_config(tmp_path):
+    # An empty config file must not break the filter
+    config_file = tmp_path / "empty_filters.yml"
+    config_file.write_text("")
+
+    path_filter = ModelPathFilter("test_model", str(config_file))
+    assert path_filter.is_accessible("any/path.py") is True
+
+
+def test_model_filter_default_rules():
+    # Default rules apply when the config file is missing
+    default_rules = ["^config/.*", "\\.env$"]
+    path_filter = ModelPathFilter(
+        "test_model",
+        config_path="non_existent.yml",
+        default_forbidden=default_rules
+    )
+
+    assert path_filter.is_accessible("src/main.py") is True
+    assert path_filter.is_accessible("config/settings.py") is False
+    assert path_filter.is_accessible(".env") is False
+
+
+def test_model_filter_add_temp_rule():
+    # Adding a temporary rule
+    path_filter = ModelPathFilter("test_model", "non_existent.yml")
+
+    # Access is allowed initially
+    assert path_filter.is_accessible("temp/file.py") is True
+
+    # Access is forbidden once the temporary rule is added
+    path_filter.add_temp_rule("^temp/.*")
+    assert path_filter.is_accessible("temp/file.py") is False
+
+
+def test_model_filter_temp_rule_keeps_defaults():
+    # A temporary rule must not switch off the default rules
+    path_filter = ModelPathFilter(
+        "test_model",
+        config_path="non_existent.yml",
+        default_forbidden=["\\.env$"]
+    )
+
+    path_filter.add_temp_rule("^temp/.*")
+    assert path_filter.is_accessible("temp/file.py") is False
+    assert path_filter.is_accessible(".env") is False
+
+
+def test_model_filter_from_model_object():
+    # Mock LLM object
+    class MockLLM:
+        default_model_name = "mock_model"
+
+    llm = MockLLM()
+    path_filter = ModelPathFilter.from_model_object(llm)
+    assert path_filter.model_name == "mock_model"
+
+
+def test_model_filter_reload_rules(tmp_path):
+    # Reloading rules after the config changes
+    config_file = tmp_path / "reload_filters.yml"
+
+    # Initial config
+    initial_config = """
+model_filters:
+  test_model:
+    forbidden_paths:
+      - "^src/.*"
+    """
+    config_file.write_text(initial_config)
+
+    path_filter = ModelPathFilter("test_model", str(config_file))
+    assert path_filter.is_accessible("src/main.py") is False
+
+    # Updated config
+    new_config = """
+model_filters:
+  test_model:
+    forbidden_paths:
+      - "^tests/.*"
+    """
+    config_file.write_text(new_config)
+
+    # Reload the rules
+    path_filter.reload_rules()
+    assert path_filter.is_accessible("src/main.py") is True
+    assert path_filter.is_accessible("tests/test.py") is False
+
+
+def test_model_filter_empty_path():
+    # Empty and None paths are always accessible
+    path_filter = ModelPathFilter("test_model", "non_existent.yml")
+    assert path_filter.is_accessible("") is True
+    assert path_filter.is_accessible(None) is True