Skip to content

Commit

Permalink
Apply Ruff autofixes and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Aug 26, 2024
1 parent ea140d3 commit 8d99a87
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 20 deletions.
9 changes: 6 additions & 3 deletions oidckit/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@


def decode_jws(
payload: bytes, key: dict, expected_algorithm: str, verify: bool = True
payload: bytes,
key: dict,
expected_algorithm: str,
verify: bool = True,
) -> dict:
jws = JWS.from_compact(payload)
if verify:
Expand All @@ -18,7 +21,7 @@ def decode_jws(

if alg != expected_algorithm:
raise OIDCError(
f"Algorithm mismatch: offered {alg} is not expected {expected_algorithm}"
f"Algorithm mismatch: offered {alg} is not expected {expected_algorithm}",
)

jwk = JWK.from_json(key)
Expand All @@ -40,7 +43,7 @@ def get_key_from_keyset_json(keyset_json: dict, token: bytes) -> dict:
jwk_alg = jwk.get("alg")
if jwk_alg and jwk_alg != expected_alg:
raise OIDCError(
f"kid {header.kid} has alg {jwk_alg}, was expecting {header.alg}"
f"kid {header.kid} has alg {jwk_alg}, was expecting {header.alg}",
)
return jwk
raise OIDCError(f"Keyset has no matching key for kid {expected_kid}.")
Expand Down
4 changes: 3 additions & 1 deletion oidckit/excs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@ def raise_from_status(cls, response: requests.Response):
response.raise_for_status()
except requests.HTTPError as he:
raise cls(
he.response.text, request=he.request, response=he.response
he.response.text,
request=he.request,
response=he.response,
) from he
9 changes: 6 additions & 3 deletions oidckit/objects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from oidckit.provider import OIDCProvider
Expand Down Expand Up @@ -47,7 +47,8 @@ def get_user_info(self):
def decode_id_token(self) -> dict:
if not self._decoded_id_token:
self._decoded_id_token = self.provider.decode_token(
self.id_token, nonce=self.auth_state.nonce
self.id_token,
nonce=self.auth_state.nonce,
)
return self._decoded_id_token

Expand All @@ -64,7 +65,9 @@ def decode_access_token(self, verify: bool = True) -> dict:
"""
if not self._decoded_access_token:
self._decoded_access_token = self.provider.decode_token(
self.access_token, nonce=self.auth_state.nonce, verify=verify
self.access_token,
nonce=self.auth_state.nonce,
verify=verify,
)
return self._decoded_access_token

Expand Down
25 changes: 18 additions & 7 deletions oidckit/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import requests

from oidckit.excs import OIDCError, RemoteError
from oidckit.crypto import decode_jws, get_key_from_keyset_json
from oidckit.objects import AuthenticationState, AuthenticationResult
from oidckit.excs import OIDCError, RemoteError
from oidckit.objects import AuthenticationResult, AuthenticationState


class OIDCProviderConfiguration:
Expand Down Expand Up @@ -40,7 +40,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.session.close()

def build_authentication_request_params(
self, *, redirect_uri: str, state: str, request=None
self,
*,
redirect_uri: str,
state: str,
request=None,
) -> dict:
return {
"client_id": self.config.rp_client_id,
Expand All @@ -51,7 +55,10 @@ def build_authentication_request_params(
}

def build_token_request_payload(
self, *, code: str, auth_state: AuthenticationState
self,
*,
code: str,
auth_state: AuthenticationState,
):
return {
"client_id": self.config.rp_client_id,
Expand All @@ -74,7 +81,8 @@ def retrieve_token_key(self, token) -> dict:
RemoteError.raise_from_status(response)
self._jwks_data = response.json()
return get_key_from_keyset_json(
keyset_json=self._jwks_data, token=token
keyset_json=self._jwks_data,
token=token,
)
raise NotImplementedError("No idea how to get token key – subclass, please")

Expand All @@ -87,7 +95,10 @@ def get_payload_data(self, token: bytes, key: dict, verify: bool = True):
)

def decode_token(
self, token: str, nonce: Optional[str] = None, verify: bool = True
self,
token: str,
nonce: Optional[str] = None,
verify: bool = True,
) -> dict:
token = str(token).encode("utf-8")
key = self.retrieve_token_key(token)
Expand All @@ -97,7 +108,7 @@ def decode_token(
token_nonce = payload.get("nonce")
if nonce != token_nonce:
raise OIDCError(
f"Token nonce mismatch – expected {nonce}, got {token_nonce}"
f"Token nonce mismatch – expected {nonce}, got {token_nonce}",
)
return payload

Expand Down
19 changes: 13 additions & 6 deletions oidckit/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from oidckit.crypto import get_random_string
from oidckit.excs import OIDCError
from oidckit.objects import (
AuthenticationState,
AuthenticationResult,
AuthenticationRequest,
AuthenticationResult,
AuthenticationState,
)
from oidckit.provider import OIDCProvider

Expand All @@ -24,7 +24,9 @@ def build_authentication_request(
state = get_random_string(state_size)

params = provider.build_authentication_request_params(
request=request, redirect_uri=redirect_uri, state=state
request=request,
redirect_uri=redirect_uri,
state=state,
)
if nonce_size > 0:
nonce = get_random_string(nonce_size)
Expand All @@ -35,7 +37,9 @@ def build_authentication_request(
return AuthenticationRequest(
redirect_url=f"{provider.config.op_authorization_endpoint}?{(urlencode(params))}",
auth_state=AuthenticationState(
nonce=nonce, state=state, redirect_uri=redirect_uri
nonce=nonce,
state=state,
redirect_uri=redirect_uri,
),
)

Expand All @@ -58,11 +62,14 @@ def process_callback_data(
raise OIDCError("Unexpected state code.")

token_payload = provider.build_token_request_payload(
code=code, auth_state=auth_state
code=code,
auth_state=auth_state,
)
token = provider.retrieve_token(token_payload)
auth_result = AuthenticationResult(
provider=provider, auth_state=auth_state, token=token
provider=provider,
auth_state=auth_state,
token=token,
)
if verify:
auth_result.decode_id_token()
Expand Down

0 comments on commit 8d99a87

Please sign in to comment.