From e884534e170c9ab2ccf684ddc17a9bf37bb4a50e Mon Sep 17 00:00:00 2001 From: David Buchanan Date: Fri, 1 Mar 2024 00:49:52 +0000 Subject: [PATCH] test jwt_monkeypatch --- .vscode/settings.json | 11 ++++++ src/millipds/jwt_monkeypatch.py | 2 +- tests/test_jwt_monkeypatch.py | 59 +++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json create mode 100644 tests/test_jwt_monkeypatch.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..3d83234a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} diff --git a/src/millipds/jwt_monkeypatch.py b/src/millipds/jwt_monkeypatch.py index fc3d2a49..31c2d708 100644 --- a/src/millipds/jwt_monkeypatch.py +++ b/src/millipds/jwt_monkeypatch.py @@ -24,7 +24,7 @@ def _low_s_patched_der_to_raw_signature(der_sig: bytes, curve: ec.EllipticCurve) def _low_s_patched_raw_to_der_signature(raw_sig: bytes, curve: ec.EllipticCurve) -> bytes: der_sig = _orig_raw_to_der_signature(raw_sig, curve) - assert_dss_sig_is_low_s(der_sig) + assert_dss_sig_is_low_s(der_sig, curve) return der_sig algorithms.raw_to_der_signature = _low_s_patched_raw_to_der_signature diff --git a/tests/test_jwt_monkeypatch.py b/tests/test_jwt_monkeypatch.py new file mode 100644 index 00000000..bf48ef8a --- /dev/null +++ b/tests/test_jwt_monkeypatch.py @@ -0,0 +1,59 @@ +import unittest +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import serialization +from cryptography.exceptions import InvalidSignature +import millipds.jwt_monkeypatch as jwt + +class JWTMonkeyPatchTestCase(unittest.TestCase): + def setUp(self): + self.priv_k1_pem = b"""-----BEGIN PRIVATE KEY----- +MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgsPWcRtPwwHBFvujzyRUj +C7JlnnG3yOE1AxNGjbp8vtyhRANCAAQeOzdlRItXy4xfCEwm/FlwViqXzrXlV5r3 +edC3qYgsCwXM9431jxbo4DJSutOrNVvZ2FIdBQWWMjWY9BlJykaV +-----END PRIVATE KEY-----""" + self.priv_k1 = serialization.load_pem_private_key(self.priv_k1_pem, password=None) + self.pub_k1 = self.priv_k1.public_key() + self.pub_k1_pem = self.pub_k1.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + self.priv_r1_pem = b"""-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgRFkAFd7bWnLoITfO +039Z8foMf5HJuV1NWdZ0Uw9A3KOhRANCAATphFD8cTEqoZ3DwSf0ymVZ8LMEz6+i +zVrbeSHCLN+xv33QrqEQj1GO18squ5a15I2NfJrovxap1LlJZFBl3cPL +-----END PRIVATE KEY-----""" + self.priv_r1 = serialization.load_pem_private_key(self.priv_r1_pem, password=None) + self.pub_r1 = self.priv_r1.public_key() + self.pub_r1_pem = self.pub_r1.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + assert(type(self.priv_k1.curve) is ec.SECP256K1) + assert(type(self.priv_r1.curve) is ec.SECP256R1) + + def test_k1_sign_verify(self): + PAYLOAD = {"hello": "world"} + for _ in range(32): + token = jwt.encode(PAYLOAD, self.priv_k1_pem, algorithm="ES256K") + decoded = jwt.decode(token, self.pub_k1_pem, algorithms=["ES256K"]) + self.assertEqual(decoded, PAYLOAD) + + def test_r1_sign_verify(self): + PAYLOAD = {"hello": "world"} + for _ in range(32): + token = jwt.encode(PAYLOAD, self.priv_r1_pem, algorithm="ES256") + decoded = jwt.decode(token, self.pub_r1_pem, algorithms=["ES256"]) + self.assertEqual(decoded, PAYLOAD) + + def test_k1_reject_high_s(self): + high_s_token = "eyJhbGciOiJFUzI1NksiLCJ0eXAiOiJKV1QifQ.eyJoZWxsbyI6IndvcmxkIn0.feGiEa50jQIhP9X_JhjUAAGrKMd4hyWGHRNVJCCoMZ3_OmsCf7NmoK_uqSnzRzazWCCuBUAoU1v5KAbmWoFZYQ" + self.assertRaises(InvalidSignature, jwt.decode, high_s_token, self.pub_k1_pem, algorithms=["ES256K"]) + + def test_r1_reject_high_s(self): + high_s_token = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJoZWxsbyI6IndvcmxkIn0.lsY80RX_GTMO2sNGAPm3s4girlFimMoHSkmnzr1TWbqYMPQjfYZhQTT9K2M6c_9O3qPoH7FCBSssXTGniq4RxQ" + self.assertRaises(InvalidSignature, jwt.decode, high_s_token, self.pub_r1_pem, algorithms=["ES256"]) + +if __name__ == '__main__': + unittest.main(module="tests.test_jwt_monkeypatch")