Skip to content

Commit

Permalink
Merge branch 'main' into sphinx_pyscript
Browse files Browse the repository at this point in the history
  • Loading branch information
ericwb authored Feb 20, 2025
2 parents 1b54293 + c58c00a commit 19eeebd
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 52 deletions.
6 changes: 6 additions & 0 deletions bandit/blacklists/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@
| | | - random.uniform | |
| | | - random.triangular | |
| | | - random.randbytes | |
| | | - random.randrange | |
| | | - random.sample | |
| | | - random.getrandbits | |
+------+---------------------+------------------------------------+-----------+
B312: telnetlib
Expand Down Expand Up @@ -515,6 +518,9 @@ def gen_blacklist():
"random.uniform",
"random.triangular",
"random.randbytes",
"random.sample",
"random.randrange",
"random.getrandbits",
],
"Standard pseudo-random generators are not suitable for "
"security/cryptographic purposes.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,26 @@
#
# SPDX-License-Identifier: Apache-2.0
r"""
==========================================
B614: Test for unsafe PyTorch load or save
==========================================
==================================
B614: Test for unsafe PyTorch load
==================================
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution, and
improper use of `torch.save` might expose sensitive data or lead to data
corruption. A safe alternative is to use `torch.load` with the `safetensors`
library from hugingface, which provides a safe deserialization mechanism.
This plugin checks for unsafe use of `torch.load`. Using `torch.load` with
untrusted data can lead to arbitrary code execution. There are two safe
alternatives:
1. Use `torch.load` with `weights_only=True` where only tensor data is
extracted, and no arbitrary Python objects are deserialized
2. Use the `safetensors` library from huggingface, which provides a safe
deserialization mechanism
With `weights_only=True`, PyTorch enforces a strict type check, ensuring
that only torch.Tensor objects are loaded.
:Example:
.. code-block:: none
>> Issue: Use of unsafe PyTorch load or save
>> Issue: Use of unsafe PyTorch load
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
Expand All @@ -42,12 +47,11 @@

@test.checks("Call")
@test.test_id("B614")
def pytorch_load_save(context):
def pytorch_load(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead
to data corruption.
This plugin checks for unsafe use of `torch.load`. Using `torch.load`
with untrusted data can lead to arbitrary code execution. The safe
alternative is to use `weights_only=True` or the safetensors library.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
Expand All @@ -59,14 +63,18 @@ def pytorch_load_save(context):
if all(
[
"torch" in qualname_list,
func in ["load", "save"],
not context.check_call_arg_value("map_location", "cpu"),
func == "load",
]
):
# For torch.load, check if weights_only=True is specified
weights_only = context.get_call_arg_value("weights_only")
if weights_only == "True" or weights_only is True:
return

return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
text="Use of unsafe PyTorch load or save",
text="Use of unsafe PyTorch load",
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b614_pytorch_load.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
------------------
B614: pytorch_load
------------------

.. automodule:: bandit.plugins.pytorch_load
5 changes: 0 additions & 5 deletions doc/source/plugins/b614_pytorch_load_save.rst

This file was deleted.

26 changes: 26 additions & 0 deletions examples/pytorch_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way (should trigger B614)
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Example of loading with weights_only=True (should NOT trigger B614)
safe_model = models.resnet18()
safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

# Example of loading with weights_only=False (should trigger B614)
unsafe_model = models.resnet18()
unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False))

# Example of loading with map_location but no weights_only (should trigger B614)
cpu_model = models.resnet18()
cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

# Example of loading with both map_location and weights_only=True (should NOT trigger B614)
safe_cpu_model = models.resnet18()
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))
21 changes: 0 additions & 21 deletions examples/pytorch_load_save.py

This file was deleted.

3 changes: 3 additions & 0 deletions examples/random_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
bad = random.uniform()
bad = random.triangular()
bad = random.randbytes()
bad = random.sample()
bad = random.randrange()
bad = random.getrandbits()

good = os.urandom()
good = random.SystemRandom()
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save
#bandit/plugins/pytorch_load.py
pytorch_load = bandit.plugins.pytorch_load:pytorch_load

# bandit/plugins/trojansource.py
trojansource = bandit.plugins.trojansource:trojansource
Expand Down
14 changes: 7 additions & 7 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ def test_popen_wrappers(self):
def test_random_module(self):
"""Test for the `random` module."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 9, "MEDIUM": 0, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 9},
"SEVERITY": {"UNDEFINED": 0, "LOW": 12, "MEDIUM": 0, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 12},
}
self.check_example("random_module.py", expect)

Expand Down Expand Up @@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self):
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
def test_pytorch_load(self):
"""Test insecure usage of torch.load."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)
self.check_example("pytorch_load.py", expect)

def test_trojansource(self):
expect = {
Expand Down

0 comments on commit 19eeebd

Please sign in to comment.