diff --git a/packages/atproto_crypto/algs/__init__.py b/packages/atproto_crypto/algs/__init__.py index ea24996d..1e412580 100644 --- a/packages/atproto_crypto/algs/__init__.py +++ b/packages/atproto_crypto/algs/__init__.py @@ -1,6 +1,11 @@ +import typing as t + from .p256 import P256 from .secp256k1 import Secp256k1 -__all__ = ['P256', 'Secp256k1'] +_ANY_ALG_TYPE = t.Union[t.Type[P256], t.Type[Secp256k1]] + +AVAILABLE_ALGORITHMS: t.List[_ANY_ALG_TYPE] = [P256, Secp256k1] +ALGORITHM_TO_CLASS: t.Dict[str, _ANY_ALG_TYPE] = {alg.NAME: alg for alg in AVAILABLE_ALGORITHMS} -AVAILABLE_ALGORITHMS = [P256, Secp256k1] +__all__ = ['P256', 'Secp256k1', 'ALGORITHM_TO_CLASS'] diff --git a/packages/atproto_crypto/algs/base_alg.py b/packages/atproto_crypto/algs/base_alg.py index acc9b4c5..1b77542a 100644 --- a/packages/atproto_crypto/algs/base_alg.py +++ b/packages/atproto_crypto/algs/base_alg.py @@ -1,5 +1,8 @@ +from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve, EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, EllipticCurve, EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature +from cryptography.hazmat.primitives.hashes import SHA256 from atproto_crypto.exceptions import InvalidCompressedPubkeyError @@ -7,6 +10,8 @@ class AlgBase: """Base class for all algorithms.""" + NAME = None + def __init__(self, curve: EllipticCurve) -> None: self.curve = curve @@ -28,3 +33,29 @@ def decompress_pubkey(self, pubkey: bytes) -> bytes: return self.get_elliptic_curve_public_key(pubkey).public_bytes( encoding=serialization.Encoding.X962, format=serialization.PublicFormat.UncompressedPoint ) + + @staticmethod + def _encode_signature(signature: bytes) -> bytes: + """Encode signature.""" + r = int.from_bytes(signature[:32], 'big') + s = int.from_bytes(signature[32:], 'big') + return encode_dss_signature(r, s) + + def verify_signature(self, pubkey: bytes, signing_input: bytes, signature: bytes) -> bool: + """Verify signature. + + Args: + pubkey: Public key. + signing_input: Signing input (data). + signature: Signature. + + Returns: + :obj:`bool`: True if signature is valid, False otherwise. + """ + try: + self.get_elliptic_curve_public_key(pubkey).verify( + signature=self._encode_signature(signature), data=signing_input, signature_algorithm=ECDSA(SHA256()) + ) + return True + except InvalidSignature: + return False diff --git a/packages/atproto_crypto/algs/p256.py b/packages/atproto_crypto/algs/p256.py index 24d47818..e6fb4164 100644 --- a/packages/atproto_crypto/algs/p256.py +++ b/packages/atproto_crypto/algs/p256.py @@ -1,8 +1,11 @@ from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1 from atproto_crypto.algs.base_alg import AlgBase +from atproto_crypto.consts import P256_JWT_ALG class P256(AlgBase): + NAME = P256_JWT_ALG + def __init__(self) -> None: super().__init__(SECP256R1()) diff --git a/packages/atproto_crypto/algs/secp256k1.py b/packages/atproto_crypto/algs/secp256k1.py index b870f919..07e65cea 100644 --- a/packages/atproto_crypto/algs/secp256k1.py +++ b/packages/atproto_crypto/algs/secp256k1.py @@ -1,8 +1,11 @@ from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1 from atproto_crypto.algs.base_alg import AlgBase +from atproto_crypto.consts import SECP256K1_JWT_ALG class Secp256k1(AlgBase): + NAME = SECP256K1_JWT_ALG + def __init__(self) -> None: super().__init__(SECP256K1()) diff --git a/packages/atproto_crypto/did.py b/packages/atproto_crypto/did.py index c5ba45ba..10e34876 100644 --- a/packages/atproto_crypto/did.py +++ b/packages/atproto_crypto/did.py @@ -11,7 +11,7 @@ SECP256K1_DID_PREFIX, SECP256K1_JWT_ALG, ) -from atproto_crypto.exceptions import IncorrectMultikeyPrefixError, UnsupportedKeyTypeError +from atproto_crypto.exceptions import IncorrectDidKeyPrefixError, IncorrectMultikeyPrefixError, UnsupportedKeyTypeError from atproto_crypto.multibase import bytes_to_multibase, multibase_to_bytes @@ -105,6 +105,24 @@ def format_multikey(jwt_alg: str, key: bytes) -> str: return bytes_to_multibase(BASE58_MULTIBASE_PREFIX, prefixed_bytes) +def parse_did_key(did_key: str) -> Multikey: + """Parse DID key. + + Args: + did_key: DID key. + + Returns: + :obj:`Multikey`: Multikey. + + Raises: + :obj:`IncorrectDidKeyPrefixError`: Incorrect prefix for DID key. + """ + if not did_key.startswith(DID_KEY_PREFIX): + raise IncorrectDidKeyPrefixError(f'Incorrect prefix for DID key {did_key}') + + return parse_multikey(did_key[len(DID_KEY_PREFIX) :]) + + def format_did_key(jwt_alg: str, key: bytes) -> str: """Format DID key. diff --git a/packages/atproto_crypto/exceptions.py b/packages/atproto_crypto/exceptions.py index c7609f8c..c2b795b6 100644 --- a/packages/atproto_crypto/exceptions.py +++ b/packages/atproto_crypto/exceptions.py @@ -13,5 +13,13 @@ class IncorrectMultikeyPrefixError(DidKeyError): ... +class IncorrectDidKeyPrefixError(DidKeyError): + ... + + class UnsupportedKeyTypeError(DidKeyError): ... + + +class UnsupportedSignatureAlgorithmError(AtProtocolError): + ... diff --git a/packages/atproto_crypto/verify.py b/packages/atproto_crypto/verify.py index 028b28e3..3b7dc1ac 100644 --- a/packages/atproto_crypto/verify.py +++ b/packages/atproto_crypto/verify.py @@ -1,13 +1,24 @@ import typing as t -import warnings + +from atproto_crypto.algs import ALGORITHM_TO_CLASS +from atproto_crypto.did import parse_did_key +from atproto_crypto.exceptions import UnsupportedSignatureAlgorithmError def verify_signature(did_key: str, signing_input: t.Union[str, bytes], signature: t.Union[str, bytes]) -> bool: - # TODO(MarshalX): implement - warnings.warn( - 'verify_signature is not implemented yet. Do not trust to this signing_input', - RuntimeWarning, - stacklevel=0, - ) - - return True + """Verify signature. + + Args: + did_key: DID key. + signing_input: Signing input (data). + signature: Signature. + + Returns: + bool: True if signature is valid, False otherwise. + """ + parsed_did_key = parse_did_key(did_key) + if parsed_did_key.jwt_alg not in ALGORITHM_TO_CLASS: + raise UnsupportedSignatureAlgorithmError('Unsupported signature alg') + + algorithm_class = ALGORITHM_TO_CLASS[parsed_did_key.jwt_alg] + return algorithm_class().verify_signature(parsed_did_key.key_bytes, signing_input, signature) diff --git a/packages/atproto_server/auth/jwt.py b/packages/atproto_server/auth/jwt.py index 87178071..2b844e3c 100644 --- a/packages/atproto_server/auth/jwt.py +++ b/packages/atproto_server/auth/jwt.py @@ -207,7 +207,7 @@ def verify_jwt( fresh_signing_key = get_signing_key_callback(payload.iss, True) # get signing key without a cache if fresh_signing_key == signing_key: - raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one') + raise TokenInvalidSignatureError('Invalid signature even with fresh signing key it is equal to the old one)') if _verify_signature(fresh_signing_key, signing_input, signature): return payload @@ -252,7 +252,7 @@ async def verify_jwt_async( fresh_signing_key = await get_signing_key_callback(payload.iss, True) # get signing key without a cache if fresh_signing_key == signing_key: - raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one') + raise TokenInvalidSignatureError('Invalid signature even with fresh signing key it is equal to the old one)') if _verify_signature(fresh_signing_key, signing_input, signature): return payload diff --git a/tests/test_atproto_server/auth/test_jwt.py b/tests/test_atproto_server/auth/test_jwt.py index 26124c8a..84022f4f 100644 --- a/tests/test_atproto_server/auth/test_jwt.py +++ b/tests/test_atproto_server/auth/test_jwt.py @@ -1,3 +1,5 @@ +import typing as t + import pytest from atproto_server.auth.jwt import get_jwt_payload, parse_jwt, validate_jwt_payload, verify_jwt, verify_jwt_async from atproto_server.exceptions import TokenDecodeError, TokenExpiredSignatureError, TokenInvalidAudienceError @@ -7,6 +9,10 @@ _TEST_JWT_INVALID_SIGN = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6cGxjOmt2d3ZjbjVpcWZvb29wbXl6dmI0cXpiYSIsImF1ZCI6ImRpZDp3ZWI6ZmVlZC5hdHByb3RvLmJsdWUiLCJleHAiOjIwMDAwMDAwMDB9.50SlT6vw26HsDXVDM4D2D53_Dvzd6bjp3TDc5EyDVD4ob9i3EEB7fmaKE0XR4egMS9Kf9eMdVqH5gJNCaIah4Q' # noqa: E501 +if t.TYPE_CHECKING: + from _pytest.monkeypatch import MonkeyPatch + + def test_parse_jwt_empty() -> None: with pytest.raises(TokenDecodeError): parse_jwt('') @@ -41,16 +47,26 @@ def test_validate_jwt_payload_valid() -> None: validate_jwt_payload(payload) -def test_verify_jwt() -> None: +def test_verify_jwt_valid_signature(monkeypatch: 'MonkeyPatch') -> None: + def get_signing_key(_: str, __: bool) -> str: + return 'did:key:zQ3shc6V2kvUxn7hNmPy9JMToKT7u2NH27SnKNxGL1GcBcS4j' + + # allow expired token + monkeypatch.setattr('atproto_server.auth.jwt._validate_exp', lambda *_: True) + + verify_jwt(_TEST_JWT_EXPIRED, get_signing_key) + + +def test_verify_jwt_aud_validation(monkeypatch: 'MonkeyPatch') -> None: expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba' expected_aud = 'did:web:feed.atproto.blue' - def get_signing_key(iss: str, force_refresh: bool) -> str: + def get_signing_key(iss: str, _: bool) -> str: assert iss == expected_iss + return 'blabla' - if force_refresh: - return 'refreshedKey' - return 'key' + # allow invalid signature + monkeypatch.setattr('atproto_server.auth.jwt._verify_signature', lambda *_: True) verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key) verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud) @@ -60,16 +76,16 @@ def get_signing_key(iss: str, force_refresh: bool) -> str: @pytest.mark.asyncio -async def test_verify_jwt_async() -> None: +async def test_verify_jwt_aud_validation_async(monkeypatch: 'MonkeyPatch') -> None: expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba' expected_aud = 'did:web:feed.atproto.blue' - async def get_signing_key(iss: str, force_refresh: bool) -> str: + async def get_signing_key(iss: str, _: bool) -> str: assert iss == expected_iss + return 'blabla' - if force_refresh: - return 'refreshedKey' - return 'key' + # allow invalid signature + monkeypatch.setattr('atproto_server.auth.jwt._verify_signature', lambda *_: True) await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key) await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud)