diff --git a/oidckit/crypto.py b/oidckit/crypto.py index 7855911..a033fa7 100644 --- a/oidckit/crypto.py +++ b/oidckit/crypto.py @@ -7,7 +7,10 @@ def decode_jws( - payload: bytes, key: dict, expected_algorithm: str, verify: bool = True + payload: bytes, + key: dict, + expected_algorithm: str, + verify: bool = True, ) -> dict: jws = JWS.from_compact(payload) if verify: @@ -18,7 +21,7 @@ def decode_jws( if alg != expected_algorithm: raise OIDCError( - f"Algorithm mismatch: offered {alg} is not expected {expected_algorithm}" + f"Algorithm mismatch: offered {alg} is not expected {expected_algorithm}", ) jwk = JWK.from_json(key) @@ -40,7 +43,7 @@ def get_key_from_keyset_json(keyset_json: dict, token: bytes) -> dict: jwk_alg = jwk.get("alg") if jwk_alg and jwk_alg != expected_alg: raise OIDCError( - f"kid {header.kid} has alg {jwk_alg}, was expecting {header.alg}" + f"kid {header.kid} has alg {jwk_alg}, was expecting {header.alg}", ) return jwk raise OIDCError(f"Keyset has no matching key for kid {expected_kid}.") diff --git a/oidckit/excs.py b/oidckit/excs.py index e6fc82f..588dd0c 100644 --- a/oidckit/excs.py +++ b/oidckit/excs.py @@ -16,5 +16,7 @@ def raise_from_status(cls, response: requests.Response): response.raise_for_status() except requests.HTTPError as he: raise cls( - he.response.text, request=he.request, response=he.response + he.response.text, + request=he.request, + response=he.response, ) from he diff --git a/oidckit/objects.py b/oidckit/objects.py index b447300..b51ac56 100644 --- a/oidckit/objects.py +++ b/oidckit/objects.py @@ -1,4 +1,4 @@ -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from oidckit.provider import OIDCProvider @@ -47,7 +47,8 @@ def get_user_info(self): def decode_id_token(self) -> dict: if not self._decoded_id_token: self._decoded_id_token = self.provider.decode_token( - self.id_token, nonce=self.auth_state.nonce + self.id_token, + nonce=self.auth_state.nonce, ) return self._decoded_id_token @@ -64,7 +65,9 @@ def decode_access_token(self, verify: bool = True) -> dict: """ if not self._decoded_access_token: self._decoded_access_token = self.provider.decode_token( - self.access_token, nonce=self.auth_state.nonce, verify=verify + self.access_token, + nonce=self.auth_state.nonce, + verify=verify, ) return self._decoded_access_token diff --git a/oidckit/provider.py b/oidckit/provider.py index d00dd99..864a2cb 100644 --- a/oidckit/provider.py +++ b/oidckit/provider.py @@ -3,9 +3,9 @@ import requests -from oidckit.excs import OIDCError, RemoteError from oidckit.crypto import decode_jws, get_key_from_keyset_json -from oidckit.objects import AuthenticationState, AuthenticationResult +from oidckit.excs import OIDCError, RemoteError +from oidckit.objects import AuthenticationResult, AuthenticationState class OIDCProviderConfiguration: @@ -40,7 +40,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.session.close() def build_authentication_request_params( - self, *, redirect_uri: str, state: str, request=None + self, + *, + redirect_uri: str, + state: str, + request=None, ) -> dict: return { "client_id": self.config.rp_client_id, @@ -51,7 +55,10 @@ def build_authentication_request_params( } def build_token_request_payload( - self, *, code: str, auth_state: AuthenticationState + self, + *, + code: str, + auth_state: AuthenticationState, ): return { "client_id": self.config.rp_client_id, @@ -74,7 +81,8 @@ def retrieve_token_key(self, token) -> dict: RemoteError.raise_from_status(response) self._jwks_data = response.json() return get_key_from_keyset_json( - keyset_json=self._jwks_data, token=token + keyset_json=self._jwks_data, + token=token, ) raise NotImplementedError("No idea how to get token key – subclass, please") @@ -87,7 +95,10 @@ def get_payload_data(self, token: bytes, key: dict, verify: bool = True): ) def decode_token( - self, token: str, nonce: Optional[str] = None, verify: bool = True + self, + token: str, + nonce: Optional[str] = None, + verify: bool = True, ) -> dict: token = str(token).encode("utf-8") key = self.retrieve_token_key(token) @@ -97,7 +108,7 @@ def decode_token( token_nonce = payload.get("nonce") if nonce != token_nonce: raise OIDCError( - f"Token nonce mismatch – expected {nonce}, got {token_nonce}" + f"Token nonce mismatch – expected {nonce}, got {token_nonce}", ) return payload diff --git a/oidckit/routines.py b/oidckit/routines.py index 1d4482e..d7cd60a 100644 --- a/oidckit/routines.py +++ b/oidckit/routines.py @@ -4,9 +4,9 @@ from oidckit.crypto import get_random_string from oidckit.excs import OIDCError from oidckit.objects import ( - AuthenticationState, - AuthenticationResult, AuthenticationRequest, + AuthenticationResult, + AuthenticationState, ) from oidckit.provider import OIDCProvider @@ -24,7 +24,9 @@ def build_authentication_request( state = get_random_string(state_size) params = provider.build_authentication_request_params( - request=request, redirect_uri=redirect_uri, state=state + request=request, + redirect_uri=redirect_uri, + state=state, ) if nonce_size > 0: nonce = get_random_string(nonce_size) @@ -35,7 +37,9 @@ def build_authentication_request( return AuthenticationRequest( redirect_url=f"{provider.config.op_authorization_endpoint}?{(urlencode(params))}", auth_state=AuthenticationState( - nonce=nonce, state=state, redirect_uri=redirect_uri + nonce=nonce, + state=state, + redirect_uri=redirect_uri, ), ) @@ -58,11 +62,14 @@ def process_callback_data( raise OIDCError("Unexpected state code.") token_payload = provider.build_token_request_payload( - code=code, auth_state=auth_state + code=code, + auth_state=auth_state, ) token = provider.retrieve_token(token_payload) auth_result = AuthenticationResult( - provider=provider, auth_state=auth_state, token=token + provider=provider, + auth_state=auth_state, + token=token, ) if verify: auth_result.decode_id_token()