Skip to content

Commit

Permalink
Implement signature verification
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX committed Jan 11, 2024
1 parent a88db2e commit 2412083
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 25 deletions.
9 changes: 7 additions & 2 deletions packages/atproto_crypto/algs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import typing as t

from .p256 import P256
from .secp256k1 import Secp256k1

__all__ = ['P256', 'Secp256k1']
_ANY_ALG_TYPE = t.Union[t.Type[P256], t.Type[Secp256k1]]

AVAILABLE_ALGORITHMS: t.List[_ANY_ALG_TYPE] = [P256, Secp256k1]
ALGORITHM_TO_CLASS: t.Dict[str, _ANY_ALG_TYPE] = {alg.NAME: alg for alg in AVAILABLE_ALGORITHMS}

AVAILABLE_ALGORITHMS = [P256, Secp256k1]
__all__ = ['P256', 'Secp256k1', 'ALGORITHM_TO_CLASS']
33 changes: 32 additions & 1 deletion packages/atproto_crypto/algs/base_alg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve, EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, EllipticCurve, EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature
from cryptography.hazmat.primitives.hashes import SHA256

from atproto_crypto.exceptions import InvalidCompressedPubkeyError


class AlgBase:
"""Base class for all algorithms."""

NAME = None

def __init__(self, curve: EllipticCurve) -> None:
self.curve = curve

Expand All @@ -28,3 +33,29 @@ def decompress_pubkey(self, pubkey: bytes) -> bytes:
return self.get_elliptic_curve_public_key(pubkey).public_bytes(
encoding=serialization.Encoding.X962, format=serialization.PublicFormat.UncompressedPoint
)

@staticmethod
def _encode_signature(signature: bytes) -> bytes:
"""Encode signature."""
r = int.from_bytes(signature[:32], 'big')
s = int.from_bytes(signature[32:], 'big')
return encode_dss_signature(r, s)

def verify_signature(self, pubkey: bytes, signing_input: bytes, signature: bytes) -> bool:
"""Verify signature.
Args:
pubkey: Public key.
signing_input: Signing input (data).
signature: Signature.
Returns:
:obj:`bool`: True if signature is valid, False otherwise.
"""
try:
self.get_elliptic_curve_public_key(pubkey).verify(
signature=self._encode_signature(signature), data=signing_input, signature_algorithm=ECDSA(SHA256())
)
return True
except InvalidSignature:
return False
3 changes: 3 additions & 0 deletions packages/atproto_crypto/algs/p256.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1

from atproto_crypto.algs.base_alg import AlgBase
from atproto_crypto.consts import P256_JWT_ALG


class P256(AlgBase):
NAME = P256_JWT_ALG

def __init__(self) -> None:
super().__init__(SECP256R1())
3 changes: 3 additions & 0 deletions packages/atproto_crypto/algs/secp256k1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1

from atproto_crypto.algs.base_alg import AlgBase
from atproto_crypto.consts import SECP256K1_JWT_ALG


class Secp256k1(AlgBase):
NAME = SECP256K1_JWT_ALG

def __init__(self) -> None:
super().__init__(SECP256K1())
20 changes: 19 additions & 1 deletion packages/atproto_crypto/did.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SECP256K1_DID_PREFIX,
SECP256K1_JWT_ALG,
)
from atproto_crypto.exceptions import IncorrectMultikeyPrefixError, UnsupportedKeyTypeError
from atproto_crypto.exceptions import IncorrectDidKeyPrefixError, IncorrectMultikeyPrefixError, UnsupportedKeyTypeError
from atproto_crypto.multibase import bytes_to_multibase, multibase_to_bytes


Expand Down Expand Up @@ -105,6 +105,24 @@ def format_multikey(jwt_alg: str, key: bytes) -> str:
return bytes_to_multibase(BASE58_MULTIBASE_PREFIX, prefixed_bytes)


def parse_did_key(did_key: str) -> Multikey:
"""Parse DID key.
Args:
did_key: DID key.
Returns:
:obj:`Multikey`: Multikey.
Raises:
:obj:`IncorrectDidKeyPrefixError`: Incorrect prefix for DID key.
"""
if not did_key.startswith(DID_KEY_PREFIX):
raise IncorrectDidKeyPrefixError(f'Incorrect prefix for DID key {did_key}')

return parse_multikey(did_key[len(DID_KEY_PREFIX) :])


def format_did_key(jwt_alg: str, key: bytes) -> str:
"""Format DID key.
Expand Down
8 changes: 8 additions & 0 deletions packages/atproto_crypto/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,13 @@ class IncorrectMultikeyPrefixError(DidKeyError):
...


class IncorrectDidKeyPrefixError(DidKeyError):
...


class UnsupportedKeyTypeError(DidKeyError):
...


class UnsupportedSignatureAlgorithmError(AtProtocolError):
...
29 changes: 20 additions & 9 deletions packages/atproto_crypto/verify.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import typing as t
import warnings

from atproto_crypto.algs import ALGORITHM_TO_CLASS
from atproto_crypto.did import parse_did_key
from atproto_crypto.exceptions import UnsupportedSignatureAlgorithmError


def verify_signature(did_key: str, signing_input: t.Union[str, bytes], signature: t.Union[str, bytes]) -> bool:
# TODO(MarshalX): implement
warnings.warn(
'verify_signature is not implemented yet. Do not trust to this signing_input',
RuntimeWarning,
stacklevel=0,
)

return True
"""Verify signature.
Args:
did_key: DID key.
signing_input: Signing input (data).
signature: Signature.
Returns:
bool: True if signature is valid, False otherwise.
"""
parsed_did_key = parse_did_key(did_key)
if parsed_did_key.jwt_alg not in ALGORITHM_TO_CLASS:
raise UnsupportedSignatureAlgorithmError('Unsupported signature alg')

algorithm_class = ALGORITHM_TO_CLASS[parsed_did_key.jwt_alg]
return algorithm_class().verify_signature(parsed_did_key.key_bytes, signing_input, signature)
4 changes: 2 additions & 2 deletions packages/atproto_server/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def verify_jwt(

fresh_signing_key = get_signing_key_callback(payload.iss, True) # get signing key without a cache
if fresh_signing_key == signing_key:
raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one')
raise TokenInvalidSignatureError('Invalid signature even with fresh signing key it is equal to the old one)')

if _verify_signature(fresh_signing_key, signing_input, signature):
return payload
Expand Down Expand Up @@ -252,7 +252,7 @@ async def verify_jwt_async(

fresh_signing_key = await get_signing_key_callback(payload.iss, True) # get signing key without a cache
if fresh_signing_key == signing_key:
raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one')
raise TokenInvalidSignatureError('Invalid signature even with fresh signing key it is equal to the old one)')

if _verify_signature(fresh_signing_key, signing_input, signature):
return payload
Expand Down
36 changes: 26 additions & 10 deletions tests/test_atproto_server/auth/test_jwt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing as t

import pytest
from atproto_server.auth.jwt import get_jwt_payload, parse_jwt, validate_jwt_payload, verify_jwt, verify_jwt_async
from atproto_server.exceptions import TokenDecodeError, TokenExpiredSignatureError, TokenInvalidAudienceError
Expand All @@ -7,6 +9,10 @@
_TEST_JWT_INVALID_SIGN = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6cGxjOmt2d3ZjbjVpcWZvb29wbXl6dmI0cXpiYSIsImF1ZCI6ImRpZDp3ZWI6ZmVlZC5hdHByb3RvLmJsdWUiLCJleHAiOjIwMDAwMDAwMDB9.50SlT6vw26HsDXVDM4D2D53_Dvzd6bjp3TDc5EyDVD4ob9i3EEB7fmaKE0XR4egMS9Kf9eMdVqH5gJNCaIah4Q' # noqa: E501


if t.TYPE_CHECKING:
from _pytest.monkeypatch import MonkeyPatch


def test_parse_jwt_empty() -> None:
with pytest.raises(TokenDecodeError):
parse_jwt('')
Expand Down Expand Up @@ -41,16 +47,26 @@ def test_validate_jwt_payload_valid() -> None:
validate_jwt_payload(payload)


def test_verify_jwt() -> None:
def test_verify_jwt_valid_signature(monkeypatch: 'MonkeyPatch') -> None:
def get_signing_key(_: str, __: bool) -> str:
return 'did:key:zQ3shc6V2kvUxn7hNmPy9JMToKT7u2NH27SnKNxGL1GcBcS4j'

# allow expired token
monkeypatch.setattr('atproto_server.auth.jwt._validate_exp', lambda *_: True)

verify_jwt(_TEST_JWT_EXPIRED, get_signing_key)


def test_verify_jwt_aud_validation(monkeypatch: 'MonkeyPatch') -> None:
expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba'
expected_aud = 'did:web:feed.atproto.blue'

def get_signing_key(iss: str, force_refresh: bool) -> str:
def get_signing_key(iss: str, _: bool) -> str:
assert iss == expected_iss
return 'blabla'

if force_refresh:
return 'refreshedKey'
return 'key'
# allow invalid signature
monkeypatch.setattr('atproto_server.auth.jwt._verify_signature', lambda *_: True)

verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key)
verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud)
Expand All @@ -60,16 +76,16 @@ def get_signing_key(iss: str, force_refresh: bool) -> str:


@pytest.mark.asyncio
async def test_verify_jwt_async() -> None:
async def test_verify_jwt_aud_validation_async(monkeypatch: 'MonkeyPatch') -> None:
expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba'
expected_aud = 'did:web:feed.atproto.blue'

async def get_signing_key(iss: str, force_refresh: bool) -> str:
async def get_signing_key(iss: str, _: bool) -> str:
assert iss == expected_iss
return 'blabla'

if force_refresh:
return 'refreshedKey'
return 'key'
# allow invalid signature
monkeypatch.setattr('atproto_server.auth.jwt._verify_signature', lambda *_: True)

await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key)
await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud)
Expand Down

0 comments on commit 2412083

Please sign in to comment.