Skip to content

Commit

Permalink
Implement the error handling (GH-28)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArtyomVancyan authored Oct 13, 2023
2 parents f3ae78f + b946f89 commit 3a82fae
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 23 deletions.
44 changes: 34 additions & 10 deletions docs/references/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ scopes required for the API endpoint.

```mermaid
flowchart TB
subgraph level2["request (Starlette's Request object)"]
direction TB
subgraph level1["auth (Starlette's extended Auth Credentials)"]
subgraph level2["request (Starlette's Request object)"]
direction TB
subgraph level0["provider (OAuth2 provider with client's credentials)"]
subgraph level1["auth (Starlette's extended Auth Credentials)"]
direction TB
token["access_token (Access token for the specified scopes)"]
subgraph level0["provider (OAuth2 provider with client's credentials)"]
direction TB
token["access_token (Access token for the specified scopes)"]
end
end
end
end
style level2 fill:#00948680,color:#f6f6f7,stroke:#3c3c43;
style level1 fill:#2b75a080,color:#f6f6f7,stroke:#3c3c43;
style level0 fill:#5c837480,color:#f6f6f7,stroke:#3c3c43;
style token fill:#44506980,color:#f6f6f7,stroke:#3c3c43;
style level2 fill: #00948680, color: #f6f6f7, stroke: #3c3c43;
style level1 fill: #2b75a080, color: #f6f6f7, stroke: #3c3c43;
style level0 fill: #5c837480, color: #f6f6f7, stroke: #3c3c43;
style token fill: #44506980, color: #f6f6f7, stroke: #3c3c43;
```

:::
Expand Down Expand Up @@ -129,6 +129,30 @@ approach is useful when there missing mandatory attributes in `request.user` for
database. You need to define a route for provisioning and provide it as `redirect_uri`, so
the [user context](/integration/integration#user-context) will be available for usage.

## Error handling

The exceptions that possibly can occur when using the library are reraised as `HTTPException` with the appropriate
status code and a message describing the actual error cause. So they can be handled in a natural way by following the
FastAPI [docs](https://fastapi.tiangolo.com/tutorial/handling-errors/) on handling errors and using the exceptions from
the `fastapi_oauth2.exceptions` module.

```python
from fastapi_oauth2.exceptions import OAuth2AuthenticationError

@app.exception_handler(OAuth2AuthenticationError)
async def error_handler(request: Request, exc: OAuth2AuthenticationError):
return RedirectResponse(url="/login", status_code=303)
```

The complete list of exceptions is the following.

- `OAuth2Error` - Base exception for all errors raised by the FastAPI OAuth2 library.
- `OAuth2AuthenticationError` - An exception is raised when the authentication fails.
- `OAuth2InvalidRequestError` - An exception is raised when the request is invalid.

The request is considered invalid when one of the mandatory parameters, such as `state` or `code` is missing or the
request fails. And the errors that occur during the OAuth steps are considered authentication errors.

<style>
.info, .details {
border: 0;
Expand Down
12 changes: 12 additions & 0 deletions examples/demonstration/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from fastapi import APIRouter
from fastapi import FastAPI
from fastapi import Request
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from starlette.responses import RedirectResponse

from config import oauth2_config
from database import Base
from database import engine
from database import get_db
from fastapi_oauth2.exceptions import OAuth2Error
from fastapi_oauth2.middleware import Auth
from fastapi_oauth2.middleware import OAuth2Middleware
from fastapi_oauth2.middleware import User
Expand Down Expand Up @@ -37,6 +40,15 @@ async def on_auth(auth: Auth, user: User):


app = FastAPI()


# https://fastapi.tiangolo.com/tutorial/handling-errors/
@app.exception_handler(OAuth2Error)
async def error_handler(request: Request, e: OAuth2Error):
print("An error occurred in OAuth2Middleware", e)
return RedirectResponse(url="/", status_code=303)


app.include_router(router_api)
app.include_router(router_ssr)
app.include_router(oauth2_router)
Expand Down
27 changes: 16 additions & 11 deletions src/fastapi_oauth2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,18 @@
from urllib.parse import urljoin

import httpx
from oauthlib.oauth2 import OAuth2Error
from oauthlib.oauth2 import WebApplicationClient
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
from social_core.backends.oauth import BaseOAuth2
from social_core.exceptions import AuthException
from social_core.strategy import BaseStrategy
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import RedirectResponse

from .claims import Claims
from .client import OAuth2Client


class OAuth2LoginError(HTTPException):
"""Raised when any login-related error occurs."""
from .exceptions import OAuth2AuthenticationError
from .exceptions import OAuth2InvalidRequestError


class OAuth2Strategy(BaseStrategy):
Expand Down Expand Up @@ -56,6 +54,7 @@ class OAuth2Core:
_oauth_client: Optional[WebApplicationClient] = None
_authorization_endpoint: str = None
_token_endpoint: str = None
_state: str = None

def __init__(self, client: OAuth2Client) -> None:
self.client_id = client.client_id
Expand Down Expand Up @@ -83,6 +82,8 @@ def authorization_url(self, request: Request) -> str:
oauth2_query_params = dict(state=state, scope=self.scope, redirect_uri=redirect_uri)
oauth2_query_params.update(request.query_params)

self._state = oauth2_query_params.get("state")

return str(self._oauth_client.prepare_request_uri(
self._authorization_endpoint,
**oauth2_query_params,
Expand All @@ -93,9 +94,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:

async def token_data(self, request: Request, **httpx_client_args) -> dict:
if not request.query_params.get("code"):
raise OAuth2LoginError(400, "'code' parameter was not found in callback request")
raise OAuth2InvalidRequestError(400, "'code' parameter was not found in callback request")
if not request.query_params.get("state"):
raise OAuth2LoginError(400, "'state' parameter was not found in callback request")
raise OAuth2InvalidRequestError(400, "'state' parameter was not found in callback request")
if request.query_params.get("state") != self._state:
raise OAuth2InvalidRequestError(400, "'state' parameter does not match")

redirect_uri = self.get_redirect_uri(request)
scheme = "http" if request.auth.http else "https"
Expand All @@ -112,12 +115,14 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
headers.update({"Accept": "application/json"})
auth = httpx.BasicAuth(self.client_id, self.client_secret)
async with httpx.AsyncClient(auth=auth, **httpx_client_args) as session:
response = await session.post(token_url, headers=headers, content=content)
try:
response = await session.post(token_url, headers=headers, content=content)
self._oauth_client.parse_request_body_response(json.dumps(response.json()))
return self.standardize(self.backend.user_data(self.access_token))
except (CustomOAuth2Error, Exception) as e:
raise OAuth2LoginError(400, str(e))
except (OAuth2Error, httpx.HTTPError) as e:
raise OAuth2InvalidRequestError(400, str(e))
except (AuthException, Exception) as e:
raise OAuth2AuthenticationError(401, str(e))

async def token_redirect(self, request: Request, **kwargs) -> RedirectResponse:
access_token = request.auth.jwt_create(await self.token_data(request, **kwargs))
Expand Down
13 changes: 13 additions & 0 deletions src/fastapi_oauth2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from starlette.exceptions import HTTPException


class OAuth2Error(HTTPException):
"""Base OAuth2 exception."""


class OAuth2AuthenticationError(OAuth2Error):
"""Raised when authentication fails."""


class OAuth2InvalidRequestError(OAuth2Error):
"""Raised when request is invalid."""
11 changes: 10 additions & 1 deletion src/fastapi_oauth2/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from typing import Union

from fastapi.security.utils import get_authorization_scheme_param
from jose.exceptions import JOSEError
from jose.jwt import decode as jwt_decode
from jose.jwt import encode as jwt_encode
from starlette.authentication import AuthCredentials
from starlette.authentication import AuthenticationBackend
from starlette.authentication import BaseUser
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp
from starlette.types import Receive
from starlette.types import Scope
Expand Down Expand Up @@ -139,7 +141,14 @@ def __init__(
config = OAuth2Config(**config)
elif not isinstance(config, OAuth2Config):
raise TypeError("config is not a valid type")
self.default_application_middleware = app
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.auth_middleware(scope, receive, send)
if scope["type"] == "http":
try:
return await self.auth_middleware(scope, receive, send)
except (JOSEError, Exception) as e:
middleware = PlainTextResponse(str(e), status_code=401)
return await middleware(scope, receive, send)
await self.default_application_middleware(scope, receive, send)
21 changes: 20 additions & 1 deletion tests/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse

import pytest
from httpx import AsyncClient
Expand All @@ -14,7 +16,11 @@ async def oauth2_workflow(get_app, idp=False, ssr=True, authorize_query="", toke
response = await client.get("/oauth2/test/authorize" + authorize_query) # Get authorization endpoint
authorization_endpoint = response.headers.get("location") if ssr else response.json().get("url")
response = await client.get(authorization_endpoint) # Authorize
response = await client.get(response.headers.get("location") + token_query) # Obtain token
token_url = response.headers.get("location")
query = {k: v[0] for k, v in parse_qs(urlparse(token_url).query).items()}
query.update({k: v[0] for k, v in parse_qs(token_query).items()})
token_url = "%s?%s" % (token_url.split("?")[0], urlencode(query))
response = await client.get(token_url) # Obtain token

response = await client.get("/user", headers=dict(
Authorization=jwt_encode(response.json(), "") # Set token
Expand Down Expand Up @@ -43,3 +49,16 @@ async def test_oauth2_pkce_workflow(get_app):
tq = "&" + urlencode(dict(code_verifier=code_verifier))
await oauth2_workflow(get_app, idp=True, authorize_query=aq, token_query=tq)
await oauth2_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True)


@pytest.mark.anyio
async def test_oauth2_csrf_workflow(get_app):
for aq, tq in [
("?state=test_state", "&state=test_state"),
("?state=test_state", "&state=test_wrong_state")
]:
try:
await oauth2_workflow(get_app, idp=True, authorize_query=aq, token_query=tq)
await oauth2_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True)
except AssertionError:
assert aq != tq

0 comments on commit 3a82fae

Please sign in to comment.