Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable handling of nested fields when injecting request_option in request body_json #201

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions airbyte_cdk/sources/declarative/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import base64
import logging
from dataclasses import InitVar, dataclass
from typing import Any, Mapping, Union
from typing import Any, Mapping, MutableMapping, Union

import requests
from cachetools import TTLCache, cached
Expand Down Expand Up @@ -45,11 +45,6 @@ class ApiKeyAuthenticator(DeclarativeAuthenticator):
config: Config
parameters: InitVar[Mapping[str, Any]]

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._field_name = InterpolatedString.create(
self.request_option.field_name, parameters=parameters
)

@property
def auth_header(self) -> str:
options = self._get_request_options(RequestOptionType.header)
Expand All @@ -60,9 +55,9 @@ def token(self) -> str:
return self.token_provider.get_token()

def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, Any]:
options = {}
options: MutableMapping[str, Any] = {}
if self.request_option.inject_into == option_type:
options[self._field_name.eval(self.config)] = self.token
self.request_option.inject_into_request(options, self.token, self.config)
return options

def get_request_params(self) -> Mapping[str, Any]:
Expand Down
16 changes: 13 additions & 3 deletions airbyte_cdk/sources/declarative/declarative_component_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2616,25 +2616,35 @@ definitions:
enum: [RequestPath]
RequestOption:
title: Request Option
description: Specifies the key field and where in the request a component's value should be injected.
description: Specifies the key field or path and where in the request a component's value should be injected.
type: object
required:
- type
- field_name
- inject_into
properties:
type:
type: string
enum: [RequestOption]
field_name:
title: Request Option
title: Field Name
description: Configures which key should be used in the location that the descriptor is being injected into
type: string
examples:
- segment_id
interpolation_context:
- config
- parameters
field_path:
title: Field Path
description: Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)
type: array
items:
type: string
examples:
- ["data", "viewer", "id"]
interpolation_context:
- config
- parameters
inject_into:
title: Inject Into
description: Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,15 @@ def _get_request_options(
options: MutableMapping[str, Any] = {}
if not stream_slice:
return options

if self.start_time_option and self.start_time_option.inject_into == option_type:
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string
self._partition_field_start.eval(self.config)
)
start_time_value = stream_slice.get(self._partition_field_start.eval(self.config))
self.start_time_option.inject_into_request(options, start_time_value, self.config)

if self.end_time_option and self.end_time_option.inject_into == option_type:
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore [union-attr]
self._partition_field_end.eval(self.config)
)
end_time_value = stream_slice.get(self._partition_field_end.eval(self.config))
self.end_time_option.inject_into_request(options, end_time_value, self.config)

return options

def should_be_synced(self, record: Record) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1057,11 +1057,17 @@ class InjectInto(Enum):

class RequestOption(BaseModel):
type: Literal["RequestOption"]
field_name: str = Field(
...,
field_name: Optional[str] = Field(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we slowly deprecate this field in favor of field_path and therefore add this information in the description/title?

None,
description="Configures which key should be used in the location that the descriptor is being injected into",
examples=["segment_id"],
title="Request Option",
title="Field Name",
)
field_path: Optional[List[str]] = Field(
None,
description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)",
examples=[["data", "viewer", "id"]],
title="Field Path",
)
inject_into: InjectInto = Field(
...,
Expand Down
97 changes: 93 additions & 4 deletions airbyte_cdk/sources/declarative/requesters/request_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from dataclasses import InitVar, dataclass
from enum import Enum
from typing import Any, Mapping, Union
from typing import Any, List, Literal, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.types import Config


class RequestOptionType(Enum):
Expand All @@ -26,13 +27,101 @@ class RequestOption:
Describes an option to set on a request

Attributes:
field_name (str): Describes the name of the parameter to inject
field_name (str): Describes the name of the parameter to inject. Mutually exclusive with field_path.
field_path (list(str)): Describes the path to a nested field as a list of field names. Mutually exclusive with field_name.
inject_into (RequestOptionType): Describes where in the HTTP request to inject the parameter
"""

field_name: Union[InterpolatedString, str]
inject_into: RequestOptionType
parameters: InitVar[Mapping[str, Any]]
field_name: Optional[Union[InterpolatedString, str]] = None
field_path: Optional[List[Union[InterpolatedString, str]]] = None

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.field_name = InterpolatedString.create(self.field_name, parameters=parameters)
# Validate inputs. We should expect either field_name or field_path, but not both
if self.field_name is None and self.field_path is None:
raise ValueError("RequestOption requires either a field_name or field_path")

if self.field_name is not None and self.field_path is not None:
raise ValueError(
"Only one of field_name or field_path can be provided to RequestOption"
)

if self.field_name is not None and not isinstance(
self.field_name, (str, InterpolatedString)
):
raise TypeError(f"field_name expects a string, but got {type(self.field_name)}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is overly cautious and wouldn't validate that. In Python, any parameter can be of any type but we don't validate them everywhere. We have mypy for that which should hopefully catch most of these. I think at some point, there are readability concerns that we need to balance. Was there a specific gap we wanted to address?

The some comment applies to the self.field_path condition below.


if self.field_path is not None:
if not isinstance(self.field_path, list):
raise TypeError(f"field_path expects a list, but got {type(self.field_path)}")
for value in self.field_path:
if not isinstance(value, (str, InterpolatedString)):
raise TypeError(f"field_path values must be strings, got {type(value)}")

if self.field_path is not None and self.inject_into != RequestOptionType.body_json:
raise ValueError(
"Nested field injection is only supported for body JSON injection. Please use a top-level field_name for other injection types."
)

# Convert field_name and field_path into InterpolatedString objects if they are strings
if self.field_name is not None:
self.field_name = InterpolatedString.create(self.field_name, parameters=parameters)
if self.field_path is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since only field_name or field_path will be defined, we can probably use elif as elif self.field_path is not None: instead of if self.field_path is not None: to be more explicit

self.field_path = [
InterpolatedString.create(segment, parameters=parameters)
for segment in self.field_path
]

@property
def _is_field_path(self) -> bool:
"""Returns whether this option is a field path (ie, a nested field)"""
return self.field_path is not None

def inject_into_request(
self,
target: MutableMapping[str, Any],
value: Any,
config: Config,
) -> None:
"""
Inject a request option value into a target request structure using either field_name or field_path.
For non-body-json injection, only top-level field names are supported.
For body-json injection, both field names and nested field paths are supported.

Args:
target: The request structure to inject the value into
value: The value to inject
config: The config object to use for interpolation
"""
if self._is_field_path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we simplify the logic by having in post_init:

        if self.field_name is not None:
            self.field_path = [InterpolatedString.create(self.field_name, parameters=parameters)]

This way, we would only have one logic to maintain and it would be the field_path one.

if self.inject_into != RequestOptionType.body_json:
raise ValueError(
"Nested field injection is only supported for body JSON injection. Please use a top-level field_name for other injection types."
)

assert self.field_path is not None # for type checker
current = target
# Convert path segments into strings, evaluating any interpolated segments
# Example: ["data", "{{ config[user_type] }}", "id"] -> ["data", "admin", "id"]
*path_parts, final_key = [
str(
segment.eval(config=config)
if isinstance(segment, InterpolatedString)
else segment
)
for segment in self.field_path
]

# Build a nested dictionary structure and set the final value at the deepest level
for part in path_parts:
current = current.setdefault(part, {})
current[final_key] = value
else:
# For non-nested fields, evaluate the field name if it's an interpolated string
key = (
self.field_name.eval(config=config)
if isinstance(self.field_name, InterpolatedString)
else self.field_name
)
target[str(key)] = value
107 changes: 84 additions & 23 deletions airbyte_cdk/utils/mapping_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,104 @@
#


from typing import Any, List, Mapping, Optional, Set, Union
from typing import Any, Dict, List, Mapping, Optional, Union


def _has_nested_conflict(path1: List[str], value1: Any, path2: List[str], value2: Any) -> bool:
"""
Check if two paths conflict with each other.
e.g. ["a", "b"] conflicts with ["a", "b"] if values differ
e.g. ["a"] conflicts with ["a", "b"] (can't have both a value and a nested structure)
"""
# If one path is a prefix of the other, they conflict
shorter, longer = sorted([path1, path2], key=len)
if longer[: len(shorter)] == shorter:
return True

# If paths are the same but values differ, they conflict
if path1 == path2 and value1 != value2:
return True

return False


def _flatten_mapping(
mapping: Mapping[str, Any], prefix: Optional[List[str]] = None
) -> List[tuple[List[str], Any]]:
"""
Convert a nested mapping into a list of (path, value) pairs to make conflict detection easier.
e.g. {"a": {"b": 1}} -> [(["a", "b"], 1)]
"""
prefix = prefix or []
result = []

for key, value in mapping.items():
current_path = prefix + [key]
if isinstance(value, Mapping):
result.extend(_flatten_mapping(value, current_path))
else:
result.append((current_path, value))

return result


def combine_mappings(
mappings: List[Optional[Union[Mapping[str, Any], str]]],
) -> Union[Mapping[str, Any], str]:
"""
Combine multiple mappings into a single mapping. If any of the mappings are a string, return
that string. Raise errors in the following cases:
* If there are duplicate keys across mappings
Combine multiple mappings into a single mapping, supporting nested structures.
If any of the mappings are a string, return that string. Raise errors in the following cases:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would precise that it is "If there only one mapping that is a string, it will return this mapping regardless of the other mappings".

That being said, I don't understand this behavior though. Maybe I'm wrong but before, I think it would fail if there was a string and a mapping, right?

* If there are conflicting paths across mappings (including nested conflicts)
* If there are multiple string mappings
* If there are multiple mappings containing keys and one of them is a string
"""
all_keys: List[Set[str]] = []
for part in mappings:
if part is None:
continue
keys = set(part.keys()) if not isinstance(part, str) else set()
all_keys.append(keys)

string_options = sum(isinstance(mapping, str) for mapping in mappings)
# If more than one mapping is a string, raise a ValueError
# Count how many string options we have, ignoring None values
string_options = sum(isinstance(mapping, str) for mapping in mappings if mapping is not None)
if string_options > 1:
raise ValueError("Cannot combine multiple string options")

if string_options == 1 and sum(len(keys) for keys in all_keys) > 0:
raise ValueError("Cannot combine multiple options if one is a string")
# Filter out None values and empty mappings
non_empty_mappings = [
m for m in mappings if m is not None and not (isinstance(m, Mapping) and not m)
]

# If there is only one string option, return it
if string_options == 1:
if len(non_empty_mappings) > 1:
raise ValueError("Cannot combine multiple options if one is a string")
return next(m for m in non_empty_mappings if isinstance(m, str))

# If any mapping is a string, return it
# Convert all mappings to flat (path, value) pairs for conflict detection
all_paths: List[List[tuple[List[str], Any]]] = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like the following logic is a bit complex. Would there be a way to simplify this? Right now, I see that we:

  • Extract all paths
  • Compare all paths to find conflicts
  • Create a new mapping that we can return

Would it be simpler to do just the last one and validate for conflicts while we do that? Something like:

def combine_mappings(
    mappings: List[Optional[Union[Mapping[str, Any], str]]],
) -> Union[Mapping[str, Any], str]:
    """
    Combine multiple mappings into a single mapping, supporting nested structures.
    If any of the mappings are a string, return that string. Raise errors in the following cases:
    * If there are conflicting paths across mappings (including nested conflicts)
    * If there are multiple string mappings
    * If there are multiple mappings containing keys and one of them is a string
    """
    if not mappings:
        return {}

    # Count how many string options we have, ignoring None values
    string_options = sum(isinstance(mapping, str) for mapping in mappings if mapping is not None)
    if string_options > 1:
        raise ValueError("Cannot combine multiple string options")

    # Filter out None values and empty mappings
    non_empty_mappings = [
        m for m in mappings if m is not None and not (isinstance(m, Mapping) and not m)
    ]

    # If there is only one string option, return it
    if string_options == 1:
        if len(non_empty_mappings) > 1:
            raise ValueError("Cannot combine multiple options if one is a string")
        return next(m for m in non_empty_mappings if isinstance(m, str))

    # Convert all mappings to flat (path, value) pairs for conflict detection
    for other in mappings[1:]:
        if other:
            merge(mappings[0], other)

    return mappings[0]


def merge(a: dict, b: dict, path=[]):
    """
    Blindly and shamelessly taken from https://stackoverflow.com/a/7205107
    """
    for key in b:
        if key in a:
            if isinstance(a[key], dict) and isinstance(b[key], dict):
                merge(a[key], b[key], path + [str(key)])
            elif a[key] != b[key]:
                raise ValueError('Duplicate keys')
        else:
            a[key] = b[key]
    return a

If we are afraid of modifying the mappings in memory, we can create a deepcopy of it as well.

for mapping in mappings:
if isinstance(mapping, str):
return mapping
if mapping is None or not isinstance(mapping, Mapping):
continue
all_paths.append(_flatten_mapping(mapping))

# Check each path against all other paths for conflicts
# Conflicts occur when the same path has different values or when one path is a prefix of another
# e.g. {"a": 1} and {"a": {"b": 2}} conflict because "a" can't be both a value and a nested structure
for i, paths1 in enumerate(all_paths):
for path1, value1 in paths1:
for paths2 in all_paths[i + 1 :]:
for path2, value2 in paths2:
if _has_nested_conflict(path1, value1, path2, value2):
raise ValueError(
f"Duplicate keys or nested path conflict found: {'.'.join(path1)} conflicts with {'.'.join(path2)}"
)

# If there are duplicate keys across mappings, raise a ValueError
intersection = set().union(*all_keys)
if len(intersection) < sum(len(keys) for keys in all_keys):
raise ValueError(f"Duplicate keys found: {intersection}")
# If no conflicts were found, merge all mappings
result: Dict[str, Any] = {}
for mapping in mappings:
if mapping is None or not isinstance(mapping, Mapping):
continue
for path, value in _flatten_mapping(mapping):
current = result
*prefix, last = path
# Create nested dictionaries for each prefix segment
for key in prefix:
current = current.setdefault(key, {})
current[last] = value

# Return the combined mappings
return {key: value for mapping in mappings if mapping for key, value in mapping.items()} # type: ignore # mapping can't be string here
return result
35 changes: 35 additions & 0 deletions unit_tests/sources/declarative/auth/test_token_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,38 @@ def test_api_key_authenticator_inject(
parameters=parameters,
)
assert {expected_field_name: expected_field_value} == getattr(token_auth, validation_fn)()


@pytest.mark.parametrize(
"field_path, token, expected_result",
[
(
["data", "auth", "token"],
"test-token",
{"data": {"auth": {"token": "test-token"}}},
),
(
["api", "{{ config.api_version }}", "auth", "token"],
"test-token",
{"api": {"v2": {"auth": {"token": "test-token"}}}},
),
],
ids=["Basic nested structure", "Nested with config interpolation"],
)
def test_api_key_authenticator_nested_token_injection(field_path, token, expected_result):
"""Test that the ApiKeyAuthenticator can properly inject tokens into nested structures when using body_json"""
config = {"api_version": "v2"}
parameters = {"auth_type": "bearer"}

token_provider = InterpolatedStringTokenProvider(
config=config, api_token=token, parameters=parameters
)
token_auth = ApiKeyAuthenticator(
request_option=RequestOption(
inject_into=RequestOptionType.body_json, field_path=field_path, parameters=parameters
),
token_provider=token_provider,
config=config,
parameters=parameters,
)
assert token_auth.get_request_body_json() == expected_result
Loading
Loading