diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 5a48f53fa..5f9277e4d 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -24,7 +24,7 @@ from collections.abc import Mapping from functools import wraps from importlib import import_module -from types import ModuleType +from types import BuiltinFunctionType, FunctionType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, Tuple import numpy as np @@ -116,25 +116,15 @@ def custom_print(*args): "complex": complex, } -DANGEROUS_PATTERNS = ( - "_os", - "os", - "subprocess", - "_subprocess", - "pty", - "system", - "popen", - "spawn", - "shutil", - "sys", - "pathlib", - "io", - "socket", - "compile", - "eval", - "exec", - "multiprocessing", -) +DANGEROUS_FUNCTIONS = [ + "builtins.compile", + "builtins.eval", + "builtins.exec", + "builtins.globalsbuiltins.locals", + "builtins.__import__", + "os.popen", + "os.system", +] DANGEROUS_MODULES = [ "builtins", @@ -248,23 +238,32 @@ def _check_return( result = func(expression, state, static_tools, custom_tools, authorized_imports=authorized_imports) if "*" not in authorized_imports: if isinstance(result, ModuleType): - for module in DANGEROUS_MODULES: + for module_name in DANGEROUS_MODULES: if ( - module not in authorized_imports - and result.__name__ == module + module_name not in authorized_imports + and result.__name__ == module_name # builtins has no __file__ attribute - and getattr(result, "__file__", "") == getattr(import_module(module), "__file__", "") + and getattr(result, "__file__", "") == getattr(import_module(module_name), "__file__", "") ): - raise InterpreterError(f"Forbidden return value: {module}") + raise InterpreterError(f"Forbidden access to module: {module_name}") elif isinstance(result, dict) and result.get("__name__"): - for module in DANGEROUS_MODULES: + for module_name in DANGEROUS_MODULES: if ( - module not in authorized_imports - and result["__name__"] == module + module_name not in authorized_imports + and result["__name__"] == module_name # builtins has no __file__ attribute - and result.get("__file__", "") == getattr(import_module(module), "__file__", "") + and result.get("__file__", "") == getattr(import_module(module_name), "__file__", "") + ): + raise InterpreterError(f"Forbidden access to module: {module_name}") + elif isinstance(result, (FunctionType, BuiltinFunctionType)): + for qualified_function_name in DANGEROUS_FUNCTIONS: + module_name, function_name = qualified_function_name.rsplit(".", 1) + if ( + function_name not in static_tools + and result.__name__ == function_name + and result.__module__ == module_name ): - raise InterpreterError(f"Forbidden return value: {module}") + raise InterpreterError(f"Forbidden access to function: {function_name}") return result return _check_return @@ -1083,15 +1082,6 @@ def get_safe_module(raw_module, authorized_imports, visited=None): # Copy all attributes by reference, recursively checking modules for attr_name in dir(raw_module): - # Skip dangerous patterns at any level - if any( - pattern in raw_module.__name__.split(".") + [attr_name] - and not check_module_authorized(pattern, authorized_imports) - for pattern in DANGEROUS_PATTERNS - ): - logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") - continue - try: attr_value = getattr(raw_module, attr_name) except (ImportError, AttributeError) as e: @@ -1114,8 +1104,6 @@ def check_module_authorized(module_name, authorized_imports): return True else: module_path = module_name.split(".") - if any([module in DANGEROUS_PATTERNS and module not in authorized_imports for module in module_path]): - return False # ["A", "B", "C"] -> ["A", "A.B", "A.B.C"] module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] return any(subpath in authorized_imports for subpath in module_subpaths) diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index 34b56d1fa..fe91d937f 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -944,22 +944,6 @@ def test_fix_final_answer_code(self): Got: {result} """ - def test_dangerous_subpackage_access_blocked(self): - # Direct imports with dangerous patterns should fail - code = "import random._os" - with pytest.raises(InterpreterError): - evaluate_python_code(code) - - # Import of whitelisted modules should succeed but dangerous submodules should not exist - code = "import random;random._os.system('echo bad command passed')" - with pytest.raises(InterpreterError) as e: - evaluate_python_code(code) - assert "AttributeError: module 'random' has no attribute '_os'" in str(e) - - code = "import doctest;doctest.inspect.os.system('echo bad command passed')" - with pytest.raises(InterpreterError): - evaluate_python_code(code, authorized_imports=["doctest"]) - def test_close_matches_subscript(self): code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]' with pytest.raises(Exception) as e: @@ -1395,7 +1379,7 @@ def test_len(self): ("AnyModule", ["*"], True), ("os", ["os"], True), ("AnyModule", ["AnyModule"], True), - ("Module.os", ["Module"], False), + ("Module.os", ["Module"], True), ("Module.os", ["Module", "os"], True), ("os.path", ["os"], True), ("os", ["os.path"], False), @@ -1446,6 +1430,63 @@ def test_vulnerability_builtins(self, additional_authorized_imports, expected_er ): executor("import builtins") + @pytest.mark.parametrize( + "additional_authorized_imports, expected_error", + [([], InterpreterError("Import of builtins is not allowed")), (["builtins"], None)], + ) + def test_vulnerability_builtins_safe_functions(self, additional_authorized_imports, expected_error): + executor = LocalPythonExecutor(additional_authorized_imports) + with ( + pytest.raises(type(expected_error), match=f".*{expected_error}") + if isinstance(expected_error, Exception) + else does_not_raise() + ): + executor("import builtins; builtins.print(1)") + + @pytest.mark.parametrize( + "additional_authorized_imports, additional_tools, expected_error", + [ + ([], [], InterpreterError("Import of builtins is not allowed")), + (["builtins"], [], InterpreterError("Forbidden access to function: exec")), + (["builtins"], ["exec"], None), + ], + ) + def test_vulnerability_builtins_dangerous_functions( + self, additional_authorized_imports, additional_tools, expected_error + ): + executor = LocalPythonExecutor(additional_authorized_imports) + if additional_tools: + from builtins import exec + + executor.send_tools({"exec": exec}) + with ( + pytest.raises(type(expected_error), match=f".*{expected_error}") + if isinstance(expected_error, Exception) + else does_not_raise() + ): + executor("import builtins; builtins.exec") + + @pytest.mark.parametrize( + "additional_authorized_imports, additional_tools, expected_error", + [ + ([], [], InterpreterError("Import of os is not allowed")), + (["os"], [], InterpreterError("Forbidden access to function: popen")), + (["os"], ["popen"], None), + ], + ) + def test_vulnerability_dangerous_functions(self, additional_authorized_imports, additional_tools, expected_error): + executor = LocalPythonExecutor(additional_authorized_imports) + if additional_tools: + from os import popen + + executor.send_tools({"popen": popen}) + with ( + pytest.raises(type(expected_error), match=f".*{expected_error}") + if isinstance(expected_error, Exception) + else does_not_raise() + ): + executor("import os; os.popen") + @pytest.mark.parametrize( "additional_authorized_imports, expected_error", [([], InterpreterError("Import of sys is not allowed")), (["os", "sys"], None)], @@ -1468,7 +1509,7 @@ def test_vulnerability_via_sys(self, additional_authorized_imports, expected_err @pytest.mark.parametrize( "additional_authorized_imports, expected_error", - [(["importlib"], InterpreterError("Forbidden return value: os")), (["importlib", "os"], None)], + [(["importlib"], InterpreterError("Forbidden access to module: os")), (["importlib", "os"], None)], ) def test_vulnerability_via_importlib(self, additional_authorized_imports, expected_error): executor = LocalPythonExecutor(additional_authorized_imports) @@ -1486,19 +1527,58 @@ def test_vulnerability_via_importlib(self, additional_authorized_imports, expect ) ) + @pytest.mark.parametrize( + "code, additional_authorized_imports, expected_error", + [ + # os submodule + ("import queue; queue.threading._os.system(':')", [], InterpreterError("Forbidden access to module: os")), + ("import random; random._os.system(':')", [], InterpreterError("Forbidden access to module: os")), + ( + "import random; random.__dict__['_os'].system(':')", + [], + InterpreterError("Forbidden access to module: os"), + ), + ( + "import doctest; doctest.inspect.os.system(':')", + ["doctest"], + InterpreterError("Forbidden access to module: os"), + ), + # subprocess submodule + ( + "import asyncio; asyncio.base_events.events.subprocess", + ["asyncio"], + InterpreterError("Forbidden access to module: subprocess"), + ), + # sys submodule + ( + "import queue; queue.threading._sys.modules['os'].system(':')", + [], + InterpreterError("Forbidden access to module: sys"), + ), + # Allowed + ("import pandas; pandas.io", ["pandas"], None), + ], + ) + def test_vulnerability_via_submodules(self, code, additional_authorized_imports, expected_error): + executor = LocalPythonExecutor(additional_authorized_imports) + with ( + pytest.raises(type(expected_error), match=f".*{expected_error}") + if isinstance(expected_error, Exception) + else does_not_raise() + ): + executor(code) + @pytest.mark.parametrize( "additional_authorized_imports, additional_tools, expected_error", [ ([], [], InterpreterError("Import of sys is not allowed")), - (["sys"], [], InterpreterError("Forbidden return value: builtins")), + (["sys"], [], InterpreterError("Forbidden access to module: builtins")), ( ["sys", "builtins"], [], - InterpreterError( - "Invoking a builtin function that has not been explicitly added as a tool is not allowed" - ), + InterpreterError("Forbidden access to function: __import__"), ), - (["sys", "builtins"], ["__import__"], InterpreterError("Forbidden return value: os")), + (["sys", "builtins"], ["__import__"], InterpreterError("Forbidden access to module: os")), (["sys", "builtins", "os"], ["__import__"], None), ], ) @@ -1528,7 +1608,10 @@ def test_vulnerability_builtins_via_sys(self, additional_authorized_imports, add @pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None @pytest.mark.parametrize( "additional_authorized_imports, additional_tools, expected_error", - [([], [], InterpreterError("Forbidden return value: builtins")), (["builtins", "os"], ["__import__"], None)], + [ + ([], [], InterpreterError("Forbidden access to module: builtins")), + (["builtins", "os"], ["__import__"], None), + ], ) def test_vulnerability_builtins_via_traceback( self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch @@ -1562,7 +1645,10 @@ def test_vulnerability_builtins_via_traceback( @pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None @pytest.mark.parametrize( "additional_authorized_imports, additional_tools, expected_error", - [([], [], InterpreterError("Forbidden return value: builtins")), (["builtins", "os"], ["__import__"], None)], + [ + ([], [], InterpreterError("Forbidden access to module: builtins")), + (["builtins", "os"], ["__import__"], None), + ], ) def test_vulnerability_builtins_via_class_catch_warnings( self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch @@ -1597,7 +1683,7 @@ def test_vulnerability_builtins_via_class_catch_warnings( @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.parametrize( "additional_authorized_imports, expected_error", - [([], InterpreterError("Forbidden return value: os")), (["os"], None)], + [([], InterpreterError("Forbidden access to module: os")), (["os"], None)], ) def test_vulnerability_load_module_via_builtin_importer(self, additional_authorized_imports, expected_error): executor = LocalPythonExecutor(additional_authorized_imports)