Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check dangerous modules instead of dangerous patterns #877

Merged
merged 7 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 29 additions & 41 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections.abc import Mapping
from functools import wraps
from importlib import import_module
from types import ModuleType
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import numpy as np
Expand Down Expand Up @@ -116,25 +116,15 @@ def custom_print(*args):
"complex": complex,
}

DANGEROUS_PATTERNS = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"sys",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
DANGEROUS_FUNCTIONS = [
"builtins.compile",
"builtins.eval",
"builtins.exec",
"builtins.globalsbuiltins.locals",
"builtins.__import__",
"os.popen",
"os.system",
]

DANGEROUS_MODULES = [
"builtins",
Expand Down Expand Up @@ -248,23 +238,32 @@ def _check_return(
result = func(expression, state, static_tools, custom_tools, authorized_imports=authorized_imports)
if "*" not in authorized_imports:
if isinstance(result, ModuleType):
for module in DANGEROUS_MODULES:
for module_name in DANGEROUS_MODULES:
if (
module not in authorized_imports
and result.__name__ == module
module_name not in authorized_imports
and result.__name__ == module_name
# builtins has no __file__ attribute
and getattr(result, "__file__", "") == getattr(import_module(module), "__file__", "")
and getattr(result, "__file__", "") == getattr(import_module(module_name), "__file__", "")
):
raise InterpreterError(f"Forbidden return value: {module}")
raise InterpreterError(f"Forbidden access to module: {module_name}")
elif isinstance(result, dict) and result.get("__name__"):
for module in DANGEROUS_MODULES:
for module_name in DANGEROUS_MODULES:
if (
module not in authorized_imports
and result["__name__"] == module
module_name not in authorized_imports
and result["__name__"] == module_name
# builtins has no __file__ attribute
and result.get("__file__", "") == getattr(import_module(module), "__file__", "")
and result.get("__file__", "") == getattr(import_module(module_name), "__file__", "")
):
raise InterpreterError(f"Forbidden access to module: {module_name}")
elif isinstance(result, (FunctionType, BuiltinFunctionType)):
for qualified_function_name in DANGEROUS_FUNCTIONS:
module_name, function_name = qualified_function_name.rsplit(".", 1)
if (
function_name not in static_tools
and result.__name__ == function_name
and result.__module__ == module_name
):
raise InterpreterError(f"Forbidden return value: {module}")
raise InterpreterError(f"Forbidden access to function: {function_name}")
return result

return _check_return
Expand Down Expand Up @@ -1083,15 +1082,6 @@ def get_safe_module(raw_module, authorized_imports, visited=None):

# Copy all attributes by reference, recursively checking modules
for attr_name in dir(raw_module):
# Skip dangerous patterns at any level
if any(
pattern in raw_module.__name__.split(".") + [attr_name]
and not check_module_authorized(pattern, authorized_imports)
for pattern in DANGEROUS_PATTERNS
):
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
continue

try:
attr_value = getattr(raw_module, attr_name)
except (ImportError, AttributeError) as e:
Expand All @@ -1114,8 +1104,6 @@ def check_module_authorized(module_name, authorized_imports):
return True
else:
module_path = module_name.split(".")
if any([module in DANGEROUS_PATTERNS and module not in authorized_imports for module in module_path]):
return False
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)
Expand Down
138 changes: 112 additions & 26 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,22 +944,6 @@ def test_fix_final_answer_code(self):
Got: {result}
"""

def test_dangerous_subpackage_access_blocked(self):
# Direct imports with dangerous patterns should fail
code = "import random._os"
with pytest.raises(InterpreterError):
evaluate_python_code(code)

# Import of whitelisted modules should succeed but dangerous submodules should not exist
code = "import random;random._os.system('echo bad command passed')"
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "AttributeError: module 'random' has no attribute '_os'" in str(e)

code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["doctest"])

def test_close_matches_subscript(self):
code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]'
with pytest.raises(Exception) as e:
Expand Down Expand Up @@ -1395,7 +1379,7 @@ def test_len(self):
("AnyModule", ["*"], True),
("os", ["os"], True),
("AnyModule", ["AnyModule"], True),
("Module.os", ["Module"], False),
("Module.os", ["Module"], True),
("Module.os", ["Module", "os"], True),
("os.path", ["os"], True),
("os", ["os.path"], False),
Expand Down Expand Up @@ -1446,6 +1430,63 @@ def test_vulnerability_builtins(self, additional_authorized_imports, expected_er
):
executor("import builtins")

@pytest.mark.parametrize(
"additional_authorized_imports, expected_error",
[([], InterpreterError("Import of builtins is not allowed")), (["builtins"], None)],
)
def test_vulnerability_builtins_safe_functions(self, additional_authorized_imports, expected_error):
executor = LocalPythonExecutor(additional_authorized_imports)
with (
pytest.raises(type(expected_error), match=f".*{expected_error}")
if isinstance(expected_error, Exception)
else does_not_raise()
):
executor("import builtins; builtins.print(1)")

@pytest.mark.parametrize(
"additional_authorized_imports, additional_tools, expected_error",
[
([], [], InterpreterError("Import of builtins is not allowed")),
(["builtins"], [], InterpreterError("Forbidden access to function: exec")),
(["builtins"], ["exec"], None),
],
)
def test_vulnerability_builtins_dangerous_functions(
self, additional_authorized_imports, additional_tools, expected_error
):
executor = LocalPythonExecutor(additional_authorized_imports)
if additional_tools:
from builtins import exec

executor.send_tools({"exec": exec})
with (
pytest.raises(type(expected_error), match=f".*{expected_error}")
if isinstance(expected_error, Exception)
else does_not_raise()
):
executor("import builtins; builtins.exec")

@pytest.mark.parametrize(
"additional_authorized_imports, additional_tools, expected_error",
[
([], [], InterpreterError("Import of os is not allowed")),
(["os"], [], InterpreterError("Forbidden access to function: popen")),
(["os"], ["popen"], None),
],
)
def test_vulnerability_dangerous_functions(self, additional_authorized_imports, additional_tools, expected_error):
executor = LocalPythonExecutor(additional_authorized_imports)
if additional_tools:
from os import popen

executor.send_tools({"popen": popen})
with (
pytest.raises(type(expected_error), match=f".*{expected_error}")
if isinstance(expected_error, Exception)
else does_not_raise()
):
executor("import os; os.popen")

@pytest.mark.parametrize(
"additional_authorized_imports, expected_error",
[([], InterpreterError("Import of sys is not allowed")), (["os", "sys"], None)],
Expand All @@ -1468,7 +1509,7 @@ def test_vulnerability_via_sys(self, additional_authorized_imports, expected_err

@pytest.mark.parametrize(
"additional_authorized_imports, expected_error",
[(["importlib"], InterpreterError("Forbidden return value: os")), (["importlib", "os"], None)],
[(["importlib"], InterpreterError("Forbidden access to module: os")), (["importlib", "os"], None)],
)
def test_vulnerability_via_importlib(self, additional_authorized_imports, expected_error):
executor = LocalPythonExecutor(additional_authorized_imports)
Expand All @@ -1486,19 +1527,58 @@ def test_vulnerability_via_importlib(self, additional_authorized_imports, expect
)
)

@pytest.mark.parametrize(
"code, additional_authorized_imports, expected_error",
[
# os submodule
("import queue; queue.threading._os.system(':')", [], InterpreterError("Forbidden access to module: os")),
("import random; random._os.system(':')", [], InterpreterError("Forbidden access to module: os")),
(
"import random; random.__dict__['_os'].system(':')",
[],
InterpreterError("Forbidden access to module: os"),
),
(
"import doctest; doctest.inspect.os.system(':')",
["doctest"],
InterpreterError("Forbidden access to module: os"),
),
# subprocess submodule
(
"import asyncio; asyncio.base_events.events.subprocess",
["asyncio"],
InterpreterError("Forbidden access to module: subprocess"),
),
# sys submodule
(
"import queue; queue.threading._sys.modules['os'].system(':')",
[],
InterpreterError("Forbidden access to module: sys"),
),
# Allowed
("import pandas; pandas.io", ["pandas"], None),
],
)
def test_vulnerability_via_submodules(self, code, additional_authorized_imports, expected_error):
executor = LocalPythonExecutor(additional_authorized_imports)
with (
pytest.raises(type(expected_error), match=f".*{expected_error}")
if isinstance(expected_error, Exception)
else does_not_raise()
):
executor(code)

@pytest.mark.parametrize(
"additional_authorized_imports, additional_tools, expected_error",
[
([], [], InterpreterError("Import of sys is not allowed")),
(["sys"], [], InterpreterError("Forbidden return value: builtins")),
(["sys"], [], InterpreterError("Forbidden access to module: builtins")),
(
["sys", "builtins"],
[],
InterpreterError(
"Invoking a builtin function that has not been explicitly added as a tool is not allowed"
),
InterpreterError("Forbidden access to function: __import__"),
),
(["sys", "builtins"], ["__import__"], InterpreterError("Forbidden return value: os")),
(["sys", "builtins"], ["__import__"], InterpreterError("Forbidden access to module: os")),
(["sys", "builtins", "os"], ["__import__"], None),
],
)
Expand Down Expand Up @@ -1528,7 +1608,10 @@ def test_vulnerability_builtins_via_sys(self, additional_authorized_imports, add
@pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None
@pytest.mark.parametrize(
"additional_authorized_imports, additional_tools, expected_error",
[([], [], InterpreterError("Forbidden return value: builtins")), (["builtins", "os"], ["__import__"], None)],
[
([], [], InterpreterError("Forbidden access to module: builtins")),
(["builtins", "os"], ["__import__"], None),
],
)
def test_vulnerability_builtins_via_traceback(
self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch
Expand Down Expand Up @@ -1562,7 +1645,10 @@ def test_vulnerability_builtins_via_traceback(
@pytest.mark.parametrize("patch_builtin_import_module", [False, True]) # builtins_import.__module__ = None
@pytest.mark.parametrize(
"additional_authorized_imports, additional_tools, expected_error",
[([], [], InterpreterError("Forbidden return value: builtins")), (["builtins", "os"], ["__import__"], None)],
[
([], [], InterpreterError("Forbidden access to module: builtins")),
(["builtins", "os"], ["__import__"], None),
],
)
def test_vulnerability_builtins_via_class_catch_warnings(
self, patch_builtin_import_module, additional_authorized_imports, additional_tools, expected_error, monkeypatch
Expand Down Expand Up @@ -1597,7 +1683,7 @@ def test_vulnerability_builtins_via_class_catch_warnings(
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.parametrize(
"additional_authorized_imports, expected_error",
[([], InterpreterError("Forbidden return value: os")), (["os"], None)],
[([], InterpreterError("Forbidden access to module: os")), (["os"], None)],
)
def test_vulnerability_load_module_via_builtin_importer(self, additional_authorized_imports, expected_error):
executor = LocalPythonExecutor(additional_authorized_imports)
Expand Down