Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customize RSA key import and JWKs endpoint response via hooks #407

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion docs/sections/settings.rst
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Settings doc should be easy to read and contain the less code possible.

Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,85 @@ A flag which toggles whether the scope is returned with successful response on i

Must be ``True`` to include ``scope`` into the successful response

Default is ``False``.
Default is ``False``.

OIDC_CLIENT_ALG_KEYS_HOOK
=========================

OPTIONAL. ``str``

A string with the location of your function hook.
Here you can customize the retrieval of the RSA keys for the ID token encoding and decoding.

.. note::
To ensure that the RSA keys provided by the ``/jwks`` endpoint is consistent with the
values retrieved from this hook function, define the ``OIDC_JWKS_RESPONSE_HOOK`` response appropriately.

The hook function receives following arguments:

* ``client``: Instance of the client.

The hook function should return a `List[jwkest.jwk.RSAKey]` of the available keys.

Default is::

def default_get_client_alg_keys(client):
"""
Takes a client and returns the set of keys associated with it.
Returns a list of keys.
"""
if client.jwt_alg == 'RS256':
keys = []
for rsakey in RSAKey.objects.all():
keys.append(jwk_RSAKey(key=importKey(rsakey.key), kid=rsakey.kid))
if not keys:
raise Exception('You must add at least one RSA Key.')
elif client.jwt_alg == 'HS256':
keys = [SYMKey(key=client.client_secret, alg=client.jwt_alg)]
else:
raise Exception('Unsupported key algorithm.')

return keys


OIDC_JWKS_RESPONSE_HOOK
=======================

OPTIONAL. ``str``

A string with the location of your function hook.
Here you can provide a customized list of JWKS that will be returned by the ``/jwks`` endpoint.

.. note::
To ensure that the RSA keys provided by the ``/jwks`` endpoint is consistent with the
values retrieved from this hook function, define the ``OIDC_CLIENT_ALG_KEYS_HOOK`` appropriately.

This hook function takes in no arguments, and should returns a list of dictionaries with the following keys:
* 'kty' with the value 'RSA'
* 'alg' with the value 'RS256'
* 'use' with the value 'sig'
* 'kid' with the kid for the key
* 'n' with the base64 representation of the modulus parameter
* 'e' with the base64 representation of the exponent parameter

Default is::

def default_get_jwks():
"""
Returns a list of dictionaries containing the JWKs for return by the ``jwks`` endpoint
"""

dic = dict(keys=[])

for rsakey in RSAKey.objects.all():
public_key = importKey(rsakey.key).publickey()
dic['keys'].append({
'kty': 'RSA',
'alg': 'RS256',
'use': 'sig',
'kid': rsakey.kid,
'n': long_to_base64(public_key.n),
'e': long_to_base64(public_key.e),
})

return dic
4 changes: 3 additions & 1 deletion example/app/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

urlpatterns = [
url(r'^$', TemplateView.as_view(template_name='home.html'), name='home'),
url(r'^accounts/login/$', auth_views.LoginView.as_view(template_name='login.html'), name='login'),
url(r'^accounts/login/$', auth_views.LoginView.as_view(
template_name='login.html'
), name='login'),
url(r'^accounts/logout/$', auth_views.LogoutView.as_view(next_page='/'), name='logout'),
url(r'^', include('oidc_provider.urls', namespace='oidc_provider')),
url(r'^admin/', admin.site.urls),
Expand Down
31 changes: 31 additions & 0 deletions oidc_provider/lib/utils/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jwkest.jwk import SYMKey
from jwkest.jws import JWS
from jwkest.jwt import JWT
from jwkest import long_to_base64

from oidc_provider.lib.utils.common import get_issuer, run_processing_hook
from oidc_provider.lib.claims import StandardScopeClaims
Expand Down Expand Up @@ -149,6 +150,16 @@ def create_code(user, client, scope, nonce, is_authentication,


def get_client_alg_keys(client):
"""
Hook to customize RSA Key retrieval.
:param client:
:return:
"""
client_alg_keys_hook = settings.get('OIDC_CLIENT_ALG_KEYS_HOOK')
return settings.import_from_str(client_alg_keys_hook)(client)


def default_get_client_alg_keys(client):
"""
Takes a client and returns the set of keys associated with it.
Returns a list of keys.
Expand All @@ -165,3 +176,23 @@ def get_client_alg_keys(client):
raise Exception('Unsupported key algorithm.')

return keys


def default_get_jwks():
"""
Returns a list of dictionaries containing the JWKs for return by the ``jwks`` endpoint
"""
dic = dict(keys=[])

for rsakey in RSAKey.objects.all():
public_key = importKey(rsakey.key).publickey()
dic['keys'].append({
'kty': 'RSA',
'alg': 'RS256',
'use': 'sig',
'kid': rsakey.kid,
'n': long_to_base64(public_key.n),
'e': long_to_base64(public_key.e),
})

return dic
18 changes: 18 additions & 0 deletions oidc_provider/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,24 @@ def OIDC_IDTOKEN_PROCESSING_HOOK(self):
"""
return 'oidc_provider.lib.utils.common.default_idtoken_processing_hook'

@property
def OIDC_CLIENT_ALG_KEYS_HOOK(self):
"""
OPTIONAL. A string with the location of your hook.
Used to specify which keys to return for a particular client and algorithm.
Returns jwkest.jwk.SYMKey
"""
return 'oidc_provider.lib.utils.token.default_get_client_alg_keys'

@property
def OIDC_JWKS_RESPONSE_HOOK(self):
"""
OPTIONAL. A string with the location of your hook.
Used to specify the list of JWKS for your app.
Returns a list of dictionaries that will form the response of the ``jwks`` endpoint.
"""
return 'oidc_provider.lib.utils.token.default_get_jwks'

@property
def OIDC_INTROSPECTION_PROCESSING_HOOK(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion oidc_provider/version.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Version changes are not allowed in PRs.

Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.0'
__version__ = '0.9.0'
18 changes: 2 additions & 16 deletions oidc_provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
except ImportError:
from urllib.parse import urlsplit, parse_qs, urlunsplit, urlencode

from Cryptodome.PublicKey import RSA
from django.contrib.auth.views import (
redirect_to_login,
LogoutView,
Expand All @@ -26,7 +25,6 @@
from django.views.decorators.clickjacking import xframe_options_exempt
from django.views.decorators.http import require_http_methods
from django.views.generic import View
from jwkest import long_to_base64

from oidc_provider.compat import get_attr_or_callable
from oidc_provider.lib.claims import StandardScopeClaims
Expand All @@ -50,7 +48,6 @@
from oidc_provider.lib.utils.token import client_id_from_id_token
from oidc_provider.models import (
Client,
RSAKey,
ResponseType)
from oidc_provider import settings
from oidc_provider import signals
Expand Down Expand Up @@ -291,19 +288,8 @@ def get(self, request, *args, **kwargs):

class JwksView(View):
def get(self, request, *args, **kwargs):
dic = dict(keys=[])

for rsakey in RSAKey.objects.all():
public_key = RSA.importKey(rsakey.key).publickey()
dic['keys'].append({
'kty': 'RSA',
'alg': 'RS256',
'use': 'sig',
'kid': rsakey.kid,
'n': long_to_base64(public_key.n),
'e': long_to_base64(public_key.e),
})

jwks_hook = settings.get('OIDC_JWKS_RESPONSE_HOOK')
dic = settings.import_from_str(jwks_hook)()
response = JsonResponse(dic)
response['Access-Control-Allow-Origin'] = '*'

Expand Down