From 8ff25e07e487f143571cc305e56dd0253c60bc7b Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sun, 16 Feb 2025 01:56:56 +0000 Subject: [PATCH 1/2] Pytorch fix (#1231) * Fix pytorch weights check * B614: Fix PyTorch plugin to handle weights_only parameter correctly The PyTorch plugin (B614) has been updated to properly handle the weights_only parameter in torch.load calls. When weights_only=True is specified, PyTorch will only deserialize known safe types, making the operation more secure. I also removed torch.save as there is no certain insecure element as such, saving any file or artifact requires consideration of what it is you are saving. Changes: - Update plugin to only check torch.load calls (not torch.save) - Fix weights_only check to handle both string and boolean True values - Remove map_location check as it doesn't affect security - Update example file to demonstrate both safe and unsafe cases - Update plugin documentation to mention weights_only as a safe alternative The plugin now correctly identifies unsafe torch.load calls while allowing safe usage with weights_only=True to pass without warning. Fixes: #1224 * Fix E501 line too long * Rename files to new test scope * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update doc/source/plugins/b614_pytorch_load.rst Co-authored-by: Eric Brown * Update pytorch_load.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Brown --- .../{pytorch_load_save.py => pytorch_load.py} | 42 +++++++++++-------- doc/source/plugins/b614_pytorch_load.rst | 5 +++ doc/source/plugins/b614_pytorch_load_save.rst | 5 --- examples/pytorch_load.py | 26 ++++++++++++ examples/pytorch_load_save.py | 21 ---------- setup.cfg | 4 +- tests/functional/test_functional.py | 10 ++--- 7 files changed, 63 insertions(+), 50 deletions(-) rename bandit/plugins/{pytorch_load_save.py => pytorch_load.py} (54%) create mode 100644 doc/source/plugins/b614_pytorch_load.rst delete mode 100644 doc/source/plugins/b614_pytorch_load_save.rst create mode 100644 examples/pytorch_load.py delete mode 100644 examples/pytorch_load_save.py diff --git a/bandit/plugins/pytorch_load_save.py b/bandit/plugins/pytorch_load.py similarity index 54% rename from bandit/plugins/pytorch_load_save.py rename to bandit/plugins/pytorch_load.py index 77522da22..8be5e3451 100644 --- a/bandit/plugins/pytorch_load_save.py +++ b/bandit/plugins/pytorch_load.py @@ -2,21 +2,26 @@ # # SPDX-License-Identifier: Apache-2.0 r""" -========================================== -B614: Test for unsafe PyTorch load or save -========================================== +================================== +B614: Test for unsafe PyTorch load +================================== -This plugin checks for the use of `torch.load` and `torch.save`. Using -`torch.load` with untrusted data can lead to arbitrary code execution, and -improper use of `torch.save` might expose sensitive data or lead to data -corruption. A safe alternative is to use `torch.load` with the `safetensors` -library from hugingface, which provides a safe deserialization mechanism. +This plugin checks for unsafe use of `torch.load`. Using `torch.load` with +untrusted data can lead to arbitrary code execution. There are two safe +alternatives: +1. Use `torch.load` with `weights_only=True` where only tensor data is + extracted, and no arbitrary Python objects are deserialized +2. Use the `safetensors` library from huggingface, which provides a safe + deserialization mechanism + +With `weights_only=True`, PyTorch enforces a strict type check, ensuring +that only torch.Tensor objects are loaded. :Example: .. code-block:: none - >> Issue: Use of unsafe PyTorch load or save + >> Issue: Use of unsafe PyTorch load Severity: Medium Confidence: High CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html) Location: examples/pytorch_load_save.py:8 @@ -42,12 +47,11 @@ @test.checks("Call") @test.test_id("B614") -def pytorch_load_save(context): +def pytorch_load(context): """ - This plugin checks for the use of `torch.load` and `torch.save`. Using - `torch.load` with untrusted data can lead to arbitrary code execution, - and improper use of `torch.save` might expose sensitive data or lead - to data corruption. + This plugin checks for unsafe use of `torch.load`. Using `torch.load` + with untrusted data can lead to arbitrary code execution. The safe + alternative is to use `weights_only=True` or the safetensors library. """ imported = context.is_module_imported_exact("torch") qualname = context.call_function_name_qual @@ -59,14 +63,18 @@ def pytorch_load_save(context): if all( [ "torch" in qualname_list, - func in ["load", "save"], - not context.check_call_arg_value("map_location", "cpu"), + func == "load", ] ): + # For torch.load, check if weights_only=True is specified + weights_only = context.get_call_arg_value("weights_only") + if weights_only == "True" or weights_only is True: + return + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, - text="Use of unsafe PyTorch load or save", + text="Use of unsafe PyTorch load", cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA, lineno=context.get_lineno_for_call_arg("load"), ) diff --git a/doc/source/plugins/b614_pytorch_load.rst b/doc/source/plugins/b614_pytorch_load.rst new file mode 100644 index 000000000..808383e6a --- /dev/null +++ b/doc/source/plugins/b614_pytorch_load.rst @@ -0,0 +1,5 @@ +------------------ +B614: pytorch_load +------------------ + +.. automodule:: bandit.plugins.pytorch_load diff --git a/doc/source/plugins/b614_pytorch_load_save.rst b/doc/source/plugins/b614_pytorch_load_save.rst deleted file mode 100644 index dcc1ae3a0..000000000 --- a/doc/source/plugins/b614_pytorch_load_save.rst +++ /dev/null @@ -1,5 +0,0 @@ ------------------------ -B614: pytorch_load_save ------------------------ - -.. automodule:: bandit.plugins.pytorch_load_save diff --git a/examples/pytorch_load.py b/examples/pytorch_load.py new file mode 100644 index 000000000..c5129a035 --- /dev/null +++ b/examples/pytorch_load.py @@ -0,0 +1,26 @@ +import torch +import torchvision.models as models + +# Example of saving a model +model = models.resnet18(pretrained=True) +torch.save(model.state_dict(), 'model_weights.pth') + +# Example of loading the model weights in an insecure way (should trigger B614) +loaded_model = models.resnet18() +loaded_model.load_state_dict(torch.load('model_weights.pth')) + +# Example of loading with weights_only=True (should NOT trigger B614) +safe_model = models.resnet18() +safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True)) + +# Example of loading with weights_only=False (should trigger B614) +unsafe_model = models.resnet18() +unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False)) + +# Example of loading with map_location but no weights_only (should trigger B614) +cpu_model = models.resnet18() +cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) + +# Example of loading with both map_location and weights_only=True (should NOT trigger B614) +safe_cpu_model = models.resnet18() +safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True)) diff --git a/examples/pytorch_load_save.py b/examples/pytorch_load_save.py deleted file mode 100644 index e1f912022..000000000 --- a/examples/pytorch_load_save.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import torchvision.models as models - -# Example of saving a model -model = models.resnet18(pretrained=True) -torch.save(model.state_dict(), 'model_weights.pth') - -# Example of loading the model weights in an insecure way -loaded_model = models.resnet18() -loaded_model.load_state_dict(torch.load('model_weights.pth')) - -# Save the model -torch.save(loaded_model.state_dict(), 'model_weights.pth') - -# Another example using torch.load with more parameters -another_model = models.resnet18() -another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) - -# Save the model -torch.save(another_model.state_dict(), 'model_weights.pth') - diff --git a/setup.cfg b/setup.cfg index 83d57d1ea..e0288e600 100644 --- a/setup.cfg +++ b/setup.cfg @@ -155,8 +155,8 @@ bandit.plugins = #bandit/plugins/tarfile_unsafe_members.py tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members - #bandit/plugins/pytorch_load_save.py - pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save + #bandit/plugins/pytorch_load.py + pytorch_load = bandit.plugins.pytorch_load:pytorch_load # bandit/plugins/trojansource.py trojansource = bandit.plugins.trojansource:trojansource diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index f9fe6956b..660b65f94 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self): } self.check_example("tarfile_extractall.py", expect) - def test_pytorch_load_save(self): - """Test insecure usage of torch.load and torch.save.""" + def test_pytorch_load(self): + """Test insecure usage of torch.load.""" expect = { - "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0}, - "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4}, + "SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0}, + "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3}, } - self.check_example("pytorch_load_save.py", expect) + self.check_example("pytorch_load.py", expect) def test_trojansource(self): expect = { From c58c00a2427012874fadc5379a9c676c98fced85 Mon Sep 17 00:00:00 2001 From: Ari Pollak Date: Wed, 19 Feb 2025 04:53:19 -0500 Subject: [PATCH 2/2] Add more random functions to B311 check (#1235) * Add sample, randrange, and getrandbits to B311 check * Add to bad examples * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_functional.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- bandit/blacklists/calls.py | 6 ++++++ examples/random_module.py | 3 +++ tests/functional/test_functional.py | 4 ++-- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/bandit/blacklists/calls.py b/bandit/blacklists/calls.py index 86f08a61d..024e873a7 100644 --- a/bandit/blacklists/calls.py +++ b/bandit/blacklists/calls.py @@ -205,6 +205,9 @@ | | | - random.uniform | | | | | - random.triangular | | | | | - random.randbytes | | +| | | - random.randrange | | +| | | - random.sample | | +| | | - random.getrandbits | | +------+---------------------+------------------------------------+-----------+ B312: telnetlib @@ -515,6 +518,9 @@ def gen_blacklist(): "random.uniform", "random.triangular", "random.randbytes", + "random.sample", + "random.randrange", + "random.getrandbits", ], "Standard pseudo-random generators are not suitable for " "security/cryptographic purposes.", diff --git a/examples/random_module.py b/examples/random_module.py index 224f2513c..f0b6d010a 100644 --- a/examples/random_module.py +++ b/examples/random_module.py @@ -11,6 +11,9 @@ bad = random.uniform() bad = random.triangular() bad = random.randbytes() +bad = random.sample() +bad = random.randrange() +bad = random.getrandbits() good = os.urandom() good = random.SystemRandom() diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 660b65f94..eaaa06428 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -365,8 +365,8 @@ def test_popen_wrappers(self): def test_random_module(self): """Test for the `random` module.""" expect = { - "SEVERITY": {"UNDEFINED": 0, "LOW": 9, "MEDIUM": 0, "HIGH": 0}, - "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 9}, + "SEVERITY": {"UNDEFINED": 0, "LOW": 12, "MEDIUM": 0, "HIGH": 0}, + "CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 12}, } self.check_example("random_module.py", expect)