Skip to content

Commit

Permalink
Make local Python interpreter safer by checking if returns builtins (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova authored Mar 3, 2025
1 parent 4f2aa3e commit 0460614
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 13 deletions.
25 changes: 25 additions & 0 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import math
import re
from collections.abc import Mapping
from functools import wraps
from importlib import import_module
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
Expand Down Expand Up @@ -212,6 +213,29 @@ def fix_final_answer_code(code: str) -> str:
return code


def safer_eval(func: Callable):
"""
Decorator to make the evaluation of a function safer by checking its return value.
Args:
func: Function to make safer.
Returns:
Callable: Safer function with return value check.
"""

@wraps(func)
def _check_return(*args, **kwargs):
result = func(*args, **kwargs)
if (isinstance(result, ModuleType) and result is builtins) or (
isinstance(result, dict) and result == vars(builtins)
):
raise InterpreterError("Forbidden return value: builtins")
return result

return _check_return


def evaluate_unaryop(
expression: ast.UnaryOp,
state: Dict[str, Any],
Expand Down Expand Up @@ -1177,6 +1201,7 @@ def evaluate_delete(
raise InterpreterError(f"Deletion of {type(target).__name__} targets is not supported")


@safer_eval
def evaluate_ast(
expression: ast.AST,
state: Dict[str, Any],
Expand Down
102 changes: 89 additions & 13 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ast
import types
import unittest
from contextlib import nullcontext as does_not_raise
from textwrap import dedent

import numpy as np
Expand Down Expand Up @@ -980,20 +981,12 @@ def test_dangerous_builtins_calls_are_blocked(self):
evaluate_python_code(dangerous_code, static_tools=BASE_PYTHON_TOOLS)

def test_dangerous_builtins_are_callable_if_explicitly_added(self):
dangerous_code = """
compile = callable.__self__.compile
eval = callable.__self__.eval
exec = callable.__self__.exec
eval("1 + 1")
exec(compile("1 + 1", "no filename", "exec"))
teval("1 + 1")
texec(tcompile("1 + 1", "no filename", "exec"))
"""

dangerous_code = dedent("""
eval("1 + 1")
exec(compile("1 + 1", "no filename", "exec"))
""")
evaluate_python_code(
dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS
dangerous_code, static_tools={"compile": compile, "eval": eval, "exec": exec} | BASE_PYTHON_TOOLS
)

def test_can_import_os_if_explicitly_authorized(self):
Expand Down Expand Up @@ -1424,3 +1417,86 @@ def test_call_from_dict(self, code):
executor = LocalPythonExecutor([])
result, _, _ = executor(code)
assert result == 11


class TestLocalPythonExecutorSecurity:
@pytest.mark.parametrize(
"additional_authorized_imports, expectation",
[([], pytest.raises(InterpreterError)), (["os"], does_not_raise())],
)
def test_vulnerability_import(self, additional_authorized_imports, expectation):
executor = LocalPythonExecutor(additional_authorized_imports)
with expectation:
executor("import os")

def test_vulnerability_builtins(self):
executor = LocalPythonExecutor([])
with pytest.raises(InterpreterError):
executor("import builtins")

@pytest.mark.parametrize(
"additional_authorized_imports, expectation",
[([], pytest.raises(InterpreterError)), (["sys"], does_not_raise())],
)
def test_vulnerability_via_sys(self, additional_authorized_imports, expectation):
executor = LocalPythonExecutor(additional_authorized_imports)
with expectation:
executor(
dedent(
"""
import sys
sys.modules["os"].system(":")
"""
)
)

@pytest.mark.parametrize(
"code",
[
dedent(
"""
try:
1 / 0
except Exception as e:
builtins = e.__traceback__.tb_frame.f_back.f_globals["__builtins__"]
builtins_import = builtins["__import__"]
os_module = builtins_import("os")
os_module.system(":")
"""
),
dedent(
"""
try:
1 / 0
except Exception as e:
builtins = e.__traceback__.tb_frame.f_back.f_globals["__builtins__"]
builtins_import = builtins["__import__"]
builtins_import.__module__ = None
os_module = builtins_import("os")
os_module.system(":")
"""
),
],
)
def test_vulnerability_builtins_via_traceback(self, code):
executor = LocalPythonExecutor([])
with pytest.raises(InterpreterError):
executor(code)

def test_vulnerability_builtins_via_class_catch_warnings(self):
executor = LocalPythonExecutor([])
with pytest.raises(InterpreterError):
executor(
dedent(
"""
classes = {}.__class__.__base__.__subclasses__()
for cls in classes:
if cls.__name__ == "catch_warnings":
builtins = cls()._module.__builtins__
builtins_import = builtins["__import__"]
break
os_module = builtins_import('os')
os_module.system(":")
"""
)
)

0 comments on commit 0460614

Please sign in to comment.