Skip to content

Commit

Permalink
feat(low-code): pass refresh headers to oauth (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
darynaishchenko authored Jan 16, 2025
1 parent 40a9f1e commit 2185bd9
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 4 deletions.
8 changes: 8 additions & 0 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
token_expiry_date_format str: format of the datetime; provide it if expires_in is returned in datetime instead of seconds
token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration
refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request
refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request
grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
"""
Expand All @@ -61,6 +62,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
expires_in_name: Union[InterpolatedString, str] = "expires_in"
refresh_token_name: Union[InterpolatedString, str] = "refresh_token"
refresh_request_body: Optional[Mapping[str, Any]] = None
refresh_request_headers: Optional[Mapping[str, Any]] = None
grant_type_name: Union[InterpolatedString, str] = "grant_type"
grant_type: Union[InterpolatedString, str] = "refresh_token"
message_repository: MessageRepository = NoopMessageRepository()
Expand Down Expand Up @@ -101,6 +103,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._refresh_request_body = InterpolatedMapping(
self.refresh_request_body or {}, parameters=parameters
)
self._refresh_request_headers = InterpolatedMapping(
self.refresh_request_headers or {}, parameters=parameters
)
self._token_expiry_date: pendulum.DateTime = (
pendulum.parse(
InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(
Expand Down Expand Up @@ -178,6 +183,9 @@ def get_grant_type(self) -> str:
def get_refresh_request_body(self) -> Mapping[str, Any]:
return self._refresh_request_body.eval(self.config)

def get_refresh_request_headers(self) -> Mapping[str, Any]:
return self._refresh_request_headers.eval(self.config)

def get_token_expiry_date(self) -> pendulum.DateTime:
return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,14 @@ definitions:
- applicationId: "{{ config['application_id'] }}"
applicationSecret: "{{ config['application_secret'] }}"
token: "{{ config['token'] }}"
refresh_request_headers:
title: Refresh Request Headers
description: Headers of the request sent to get a new access token.
type: object
additionalProperties: true
examples:
- Authorization: "<AUTH_TOKEN>"
Content-Type: "application/x-www-form-urlencoded"
scopes:
title: Scopes
description: List of scopes that should be granted to the access token.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Request Body",
)
refresh_request_headers: Optional[Dict[str, Any]] = Field(
None,
description="Headers of the request sent to get a new access token.",
examples=[
{
"Authorization": "<AUTH_TOKEN>",
"Content-Type": "application/x-www-form-urlencoded",
}
],
title="Refresh Request Headers",
)
scopes: Optional[List[str]] = Field(
None,
description="List of scopes that should be granted to the access token.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1919,6 +1919,9 @@ def create_oauth_authenticator(
refresh_request_body=InterpolatedMapping(
model.refresh_request_body or {}, parameters=model.parameters or {}
).eval(config),
refresh_request_headers=InterpolatedMapping(
model.refresh_request_headers or {}, parameters=model.parameters or {}
).eval(config),
scopes=model.scopes,
token_expiry_date_format=model.token_expiry_date_format,
message_repository=self._message_repository,
Expand All @@ -1938,6 +1941,7 @@ def create_oauth_authenticator(
grant_type_name=model.grant_type_name or "grant_type",
grant_type=model.grant_type or "refresh_token",
refresh_request_body=model.refresh_request_body,
refresh_request_headers=model.refresh_request_headers,
refresh_token_name=model.refresh_token_name or "refresh_token",
refresh_token=model.refresh_token,
scopes=model.scopes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
"""
Returns the request headers to set on the refresh request
"""
headers = self.get_refresh_request_headers()
return headers if headers else None

def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
) -> bool:
Expand Down Expand Up @@ -128,6 +136,7 @@ def _get_refresh_access_token_response(self) -> Any:
method="POST",
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
data=self.build_refresh_request_body(),
headers=self.build_refresh_request_headers(),
)
if response.ok:
response_json = response.json()
Expand Down Expand Up @@ -254,6 +263,10 @@ def get_expires_in_name(self) -> str:
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
"""Returns the request headers to set on the refresh request"""

@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_request_body: Mapping[str, Any] | None = None,
refresh_request_headers: Mapping[str, Any] | None = None,
grant_type_name: str = "grant_type",
grant_type: str = "refresh_token",
token_expiry_is_time_of_expiration: bool = False,
Expand All @@ -57,6 +58,7 @@ def __init__(
self._access_token_name = access_token_name
self._expires_in_name = expires_in_name
self._refresh_request_body = refresh_request_body
self._refresh_request_headers = refresh_request_headers
self._grant_type_name = grant_type_name
self._grant_type = grant_type

Expand Down Expand Up @@ -101,6 +103,9 @@ def get_expires_in_name(self) -> str:
def get_refresh_request_body(self) -> Mapping[str, Any]:
return self._refresh_request_body # type: ignore [return-value]

def get_refresh_request_headers(self) -> Mapping[str, Any]:
return self._refresh_request_headers # type: ignore [return-value]

def get_grant_type_name(self) -> str:
return self._grant_type_name

Expand Down Expand Up @@ -149,6 +154,7 @@ def __init__(
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] | None = None,
refresh_request_headers: Mapping[str, Any] | None = None,
grant_type_name: str = "grant_type",
grant_type: str = "refresh_token",
client_id_name: str = "client_id",
Expand All @@ -174,6 +180,7 @@ def __init__(
expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
refresh_request_headers (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request headers. Defaults to None.
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object.
client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object.
Expand Down Expand Up @@ -220,6 +227,7 @@ def __init__(
access_token_name=access_token_name,
expires_in_name=expires_in_name,
refresh_request_body=refresh_request_body,
refresh_request_headers=refresh_request_headers,
grant_type_name=self._grant_type_name,
grant_type=grant_type,
token_expiry_date_format=token_expiry_date_format,
Expand Down
72 changes: 70 additions & 2 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,42 @@ def test_refresh_request_body(self):
}
assert body == expected

def test_refresh_request_headers(self):
"""
Request headers should match given configuration.
"""
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ parameters['refresh_token'] }}",
config=config,
token_expiry_date="{{ config['token_expiry_date'] }}",
refresh_request_headers={
"Authorization": "Basic {{ [config['client_id'], config['client_secret']] | join(':') | base64encode }}",
"Content-Type": "application/x-www-form-urlencoded",
},
parameters=parameters,
)
headers = oauth.build_refresh_request_headers()
expected = {
"Authorization": "Basic c29tZV9jbGllbnRfaWQ6c29tZV9jbGllbnRfc2VjcmV0",
"Content-Type": "application/x-www-form-urlencoded",
}
assert headers == expected

oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ parameters['refresh_token'] }}",
config=config,
token_expiry_date="{{ config['token_expiry_date'] }}",
parameters=parameters,
)
headers = oauth.build_refresh_request_headers()
assert headers is None

def test_refresh_with_encode_config_params(self):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand Down Expand Up @@ -191,6 +227,36 @@ def test_refresh_access_token(self, mocker):
filtered = filter_secrets("access_token")
assert filtered == "****"

def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
"Content-Type": "application/x-www-form-urlencoded",
}
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ config['refresh_token'] }}",
config=config,
scopes=["scope1", "scope2"],
token_expiry_date="{{ config['token_expiry_date'] }}",
refresh_request_headers=expected_headers,
parameters={},
)

resp.status_code = 200
mocker.patch.object(
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
)
mocked_request = mocker.patch.object(
requests, "request", side_effect=mock_request, autospec=True
)
token = oauth.refresh_access_token()

assert ("access_token", 1000) == token

assert mocked_request.call_args.kwargs["headers"] == expected_headers

def test_refresh_access_token_missing_access_token(self, mocker):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand Down Expand Up @@ -371,7 +437,9 @@ def test_error_handling(self, mocker):
assert e.value.errno == 400


def mock_request(method, url, data):
def mock_request(method, url, data, headers):
if url == "refresh_end":
return resp
raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")
raise Exception(
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,38 @@ def test_refresh_request_body(self):
}
assert body == expected

def test_refresh_request_headers(self):
"""
Request headers should match given configuration.
"""
oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
client_id="some_client_id",
client_secret="some_client_secret",
refresh_token="some_refresh_token",
token_expiry_date=pendulum.now().add(days=3),
refresh_request_headers={
"Authorization": "Bearer some_refresh_token",
"Content-Type": "application/x-www-form-urlencoded",
},
)
headers = oauth.build_refresh_request_headers()
expected = {
"Authorization": "Bearer some_refresh_token",
"Content-Type": "application/x-www-form-urlencoded",
}
assert headers == expected

oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
client_id="some_client_id",
client_secret="some_client_secret",
refresh_token="some_refresh_token",
token_expiry_date=pendulum.now().add(days=3),
)
headers = oauth.build_refresh_request_headers()
assert headers is None

def test_refresh_request_body_with_keys_override(self):
"""
Request body should match given configuration.
Expand Down Expand Up @@ -245,6 +277,35 @@ def test_refresh_access_token(self, mocker):
assert isinstance(expires_in, str)
assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in)

def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
"Content-Type": "application/x-www-form-urlencoded",
}
oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
client_id="some_client_id",
client_secret="some_client_secret",
refresh_token="some_refresh_token",
scopes=["scope1", "scope2"],
token_expiry_date=pendulum.now().add(days=3),
refresh_request_headers=expected_headers,
)

resp.status_code = 200
mocker.patch.object(
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
)
mocked_request = mocker.patch.object(
requests, "request", side_effect=mock_request, autospec=True
)
token, expires_in = oauth.refresh_access_token()

assert isinstance(expires_in, int)
assert ("access_token", 1000) == (token, expires_in)

assert mocked_request.call_args.kwargs["headers"] == expected_headers

@pytest.mark.parametrize(
"expires_in_response, token_expiry_date_format, expected_token_expiry_date",
[
Expand Down Expand Up @@ -557,7 +618,9 @@ def test_refresh_access_token(self, mocker, connector_config):
)


def mock_request(method, url, data):
def mock_request(method, url, data, headers):
if url == "refresh_end":
return resp
raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")
raise Exception(
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
)

0 comments on commit 2185bd9

Please sign in to comment.