diff --git a/airbyte_cdk/sources/declarative/auth/token.py b/airbyte_cdk/sources/declarative/auth/token.py index 12fb899b9..caecf9d2c 100644 --- a/airbyte_cdk/sources/declarative/auth/token.py +++ b/airbyte_cdk/sources/declarative/auth/token.py @@ -5,7 +5,7 @@ import base64 import logging from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Union +from typing import Any, Mapping, MutableMapping, Union import requests from cachetools import TTLCache, cached @@ -45,11 +45,6 @@ class ApiKeyAuthenticator(DeclarativeAuthenticator): config: Config parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self._field_name = InterpolatedString.create( - self.request_option.field_name, parameters=parameters - ) - @property def auth_header(self) -> str: options = self._get_request_options(RequestOptionType.header) @@ -60,9 +55,9 @@ def token(self) -> str: return self.token_provider.get_token() def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, Any]: - options = {} + options: MutableMapping[str, Any] = {} if self.request_option.inject_into == option_type: - options[self._field_name.eval(self.config)] = self.token + self.request_option.inject_into_request(options, self.token, self.config) return options def get_request_params(self) -> Mapping[str, Any]: diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 687943360..cc3369682 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -2798,25 +2798,35 @@ definitions: enum: [RequestPath] RequestOption: title: Request Option - description: Specifies the key field and where in the request a component's value should be injected. + description: Specifies the key field or path and where in the request a component's value should be injected. type: object required: - type - - field_name - inject_into properties: type: type: string enum: [RequestOption] field_name: - title: Request Option - description: Configures which key should be used in the location that the descriptor is being injected into + title: Field Name + description: Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder. type: string examples: - segment_id interpolation_context: - config - parameters + field_path: + title: Field Path + description: Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries) + type: array + items: + type: string + examples: + - ["data", "viewer", "id"] + interpolation_context: + - config + - parameters inject_into: title: Inject Into description: Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated. diff --git a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py index d6d329aec..8ef1c89a4 100644 --- a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py @@ -365,14 +365,15 @@ def _get_request_options( options: MutableMapping[str, Any] = {} if not stream_slice: return options + if self.start_time_option and self.start_time_option.inject_into == option_type: - options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string - self._partition_field_start.eval(self.config) - ) + start_time_value = stream_slice.get(self._partition_field_start.eval(self.config)) + self.start_time_option.inject_into_request(options, start_time_value, self.config) + if self.end_time_option and self.end_time_option.inject_into == option_type: - options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore [union-attr] - self._partition_field_end.eval(self.config) - ) + end_time_value = stream_slice.get(self._partition_field_end.eval(self.config)) + self.end_time_option.inject_into_request(options, end_time_value, self.config) + return options def should_be_synced(self, record: Record) -> bool: diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index fa4a00d18..5e9ca1dc5 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -719,7 +719,7 @@ class HttpResponseFilter(BaseModel): class TypesMap(BaseModel): target_type: Union[str, List[str]] current_type: Union[str, List[str]] - condition: Optional[str] + condition: Optional[str] = None class SchemaTypeIdentifier(BaseModel): @@ -797,14 +797,11 @@ class DpathFlattenFields(BaseModel): field_path: List[str] = Field( ..., description="A path to field that needs to be flattened.", - examples=[ - ["data"], - ["data", "*", "field"], - ], + examples=[["data"], ["data", "*", "field"]], title="Field Path", ) delete_origin_value: Optional[bool] = Field( - False, + None, description="Whether to delete the origin value or keep it. Default is False.", title="Delete Origin Value", ) @@ -1173,11 +1170,17 @@ class InjectInto(Enum): class RequestOption(BaseModel): type: Literal["RequestOption"] - field_name: str = Field( - ..., - description="Configures which key should be used in the location that the descriptor is being injected into", + field_name: Optional[str] = Field( + None, + description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.", examples=["segment_id"], - title="Request Option", + title="Field Name", + ) + field_path: Optional[List[str]] = Field( + None, + description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)", + examples=[["data", "viewer", "id"]], + title="Field Path", ) inject_into: InjectInto = Field( ..., diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 12a7ea2cf..9cefdd0dc 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -709,8 +709,8 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[ } return names_to_types[value_type] - @staticmethod def create_api_key_authenticator( + self, model: ApiKeyAuthenticatorModel, config: Config, token_provider: Optional[TokenProvider] = None, @@ -732,10 +732,8 @@ def create_api_key_authenticator( ) request_option = ( - RequestOption( - inject_into=RequestOptionType(model.inject_into.inject_into.value), - field_name=model.inject_into.field_name, - parameters=model.parameters or {}, + self._create_component_from_model( + model.inject_into, config, parameters=model.parameters or {} ) if model.inject_into else RequestOption( @@ -744,6 +742,7 @@ def create_api_key_authenticator( parameters=model.parameters or {}, ) ) + return ApiKeyAuthenticator( token_provider=( token_provider @@ -825,7 +824,7 @@ def create_session_token_authenticator( token_provider=token_provider, ) else: - return ModelToComponentFactory.create_api_key_authenticator( + return self.create_api_key_authenticator( ApiKeyAuthenticatorModel( type="ApiKeyAuthenticator", api_token="", @@ -1272,19 +1271,15 @@ def create_datetime_based_cursor( ) end_time_option = ( - RequestOption( - inject_into=RequestOptionType(model.end_time_option.inject_into.value), - field_name=model.end_time_option.field_name, - parameters=model.parameters or {}, + self._create_component_from_model( + model.end_time_option, config, parameters=model.parameters or {} ) if model.end_time_option else None ) start_time_option = ( - RequestOption( - inject_into=RequestOptionType(model.start_time_option.inject_into.value), - field_name=model.start_time_option.field_name, - parameters=model.parameters or {}, + self._create_component_from_model( + model.start_time_option, config, parameters=model.parameters or {} ) if model.start_time_option else None @@ -1358,19 +1353,15 @@ def create_declarative_stream( cursor_model = model.incremental_sync end_time_option = ( - RequestOption( - inject_into=RequestOptionType(cursor_model.end_time_option.inject_into.value), - field_name=cursor_model.end_time_option.field_name, - parameters=cursor_model.parameters or {}, + self._create_component_from_model( + cursor_model.end_time_option, config, parameters=cursor_model.parameters or {} ) if cursor_model.end_time_option else None ) start_time_option = ( - RequestOption( - inject_into=RequestOptionType(cursor_model.start_time_option.inject_into.value), - field_name=cursor_model.start_time_option.field_name, - parameters=cursor_model.parameters or {}, + self._create_component_from_model( + cursor_model.start_time_option, config, parameters=cursor_model.parameters or {} ) if cursor_model.start_time_option else None @@ -1879,16 +1870,11 @@ def create_jwt_authenticator( additional_jwt_payload=model.additional_jwt_payload, ) - @staticmethod def create_list_partition_router( - model: ListPartitionRouterModel, config: Config, **kwargs: Any + self, model: ListPartitionRouterModel, config: Config, **kwargs: Any ) -> ListPartitionRouter: request_option = ( - RequestOption( - inject_into=RequestOptionType(model.request_option.inject_into.value), - field_name=model.request_option.field_name, - parameters=model.parameters or {}, - ) + self._create_component_from_model(model.request_option, config) if model.request_option else None ) @@ -2072,7 +2058,25 @@ def create_request_option( model: RequestOptionModel, config: Config, **kwargs: Any ) -> RequestOption: inject_into = RequestOptionType(model.inject_into.value) - return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={}) + field_path: Optional[List[Union[InterpolatedString, str]]] = ( + [ + InterpolatedString.create(segment, parameters=kwargs.get("parameters", {})) + for segment in model.field_path + ] + if model.field_path + else None + ) + field_name = ( + InterpolatedString.create(model.field_name, parameters=kwargs.get("parameters", {})) + if model.field_name + else None + ) + return RequestOption( + field_name=field_name, + field_path=field_path, + inject_into=inject_into, + parameters=kwargs.get("parameters", {}), + ) def create_record_selector( self, diff --git a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py index 29b700b04..6049cefe2 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py @@ -3,7 +3,7 @@ # from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter @@ -100,7 +100,9 @@ def _get_request_option( ): slice_value = stream_slice.get(self._cursor_field.eval(self.config)) if slice_value: - return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString + options: MutableMapping[str, Any] = {} + self.request_option.inject_into_request(options, slice_value, self.config) + return options else: return {} else: diff --git a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index 1c7bb6961..76cb29666 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -4,7 +4,7 @@ import copy import logging from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union import dpath @@ -118,7 +118,7 @@ def get_request_body_json( def _get_request_option( self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] ) -> Mapping[str, Any]: - params = {} + params: MutableMapping[str, Any] = {} if stream_slice: for parent_config in self.parent_stream_configs: if ( @@ -128,13 +128,7 @@ def _get_request_option( key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string value = stream_slice.get(key) if value: - params.update( - { - parent_config.request_option.field_name.eval( # type: ignore [union-attr] - config=self.config - ): value - } - ) + parent_config.request_option.inject_into_request(params, value, self.config) return params def stream_slices(self) -> Iterable[StreamSlice]: diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index 35d4b0f11..ad23f4d06 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -199,6 +199,9 @@ def _get_request_options( Raise a ValueError if there's a key collision Returned merged mapping otherwise """ + + is_body_json = requester_method.__name__ == "get_request_body_json" + return combine_mappings( [ requester_method( @@ -208,7 +211,8 @@ def _get_request_options( ), auth_options_method(), extra_options, - ] + ], + allow_same_value_merge=is_body_json, ) def _request_headers( diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index 59255c75b..6fb412cd9 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -187,7 +187,7 @@ def get_request_body_json( def _get_request_options( self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]] ) -> MutableMapping[str, Any]: - options = {} + options: MutableMapping[str, Any] = {} token = next_page_token.get("next_page_token") if next_page_token else None if ( @@ -196,15 +196,16 @@ def _get_request_options( and isinstance(self.page_token_option, RequestOption) and self.page_token_option.inject_into == option_type ): - options[self.page_token_option.field_name.eval(config=self.config)] = token # type: ignore # field_name is always cast to an interpolated string + self.page_token_option.inject_into_request(options, token, self.config) + if ( self.page_size_option and self.pagination_strategy.get_page_size() and self.page_size_option.inject_into == option_type ): - options[self.page_size_option.field_name.eval(config=self.config)] = ( # type: ignore [union-attr] - self.pagination_strategy.get_page_size() - ) # type: ignore # field_name is always cast to an interpolated string + page_size = self.pagination_strategy.get_page_size() + self.page_size_option.inject_into_request(options, page_size, self.config) + return options diff --git a/airbyte_cdk/sources/declarative/requesters/request_option.py b/airbyte_cdk/sources/declarative/requesters/request_option.py index d13d20566..e0946b53b 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_option.py +++ b/airbyte_cdk/sources/declarative/requesters/request_option.py @@ -4,9 +4,10 @@ from dataclasses import InitVar, dataclass from enum import Enum -from typing import Any, Mapping, Union +from typing import Any, List, Literal, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString +from airbyte_cdk.sources.types import Config class RequestOptionType(Enum): @@ -26,13 +27,91 @@ class RequestOption: Describes an option to set on a request Attributes: - field_name (str): Describes the name of the parameter to inject + field_name (str): Describes the name of the parameter to inject. Mutually exclusive with field_path. + field_path (list(str)): Describes the path to a nested field as a list of field names. + Only valid for body_json injection type, and mutually exclusive with field_name. inject_into (RequestOptionType): Describes where in the HTTP request to inject the parameter """ - field_name: Union[InterpolatedString, str] inject_into: RequestOptionType parameters: InitVar[Mapping[str, Any]] + field_name: Optional[Union[InterpolatedString, str]] = None + field_path: Optional[List[Union[InterpolatedString, str]]] = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: - self.field_name = InterpolatedString.create(self.field_name, parameters=parameters) + # Validate inputs. We should expect either field_name or field_path, but not both + if self.field_name is None and self.field_path is None: + raise ValueError("RequestOption requires either a field_name or field_path") + + if self.field_name is not None and self.field_path is not None: + raise ValueError( + "Only one of field_name or field_path can be provided to RequestOption" + ) + + # Nested field injection is only supported for body JSON injection + if self.field_path is not None and self.inject_into != RequestOptionType.body_json: + raise ValueError( + "Nested field injection is only supported for body JSON injection. Please use a top-level field_name for other injection types." + ) + + # Convert field_name and field_path into InterpolatedString objects if they are strings + if self.field_name is not None: + self.field_name = InterpolatedString.create(self.field_name, parameters=parameters) + elif self.field_path is not None: + self.field_path = [ + InterpolatedString.create(segment, parameters=parameters) + for segment in self.field_path + ] + + @property + def _is_field_path(self) -> bool: + """Returns whether this option is a field path (ie, a nested field)""" + return self.field_path is not None + + def inject_into_request( + self, + target: MutableMapping[str, Any], + value: Any, + config: Config, + ) -> None: + """ + Inject a request option value into a target request structure using either field_name or field_path. + For non-body-json injection, only top-level field names are supported. + For body-json injection, both field names and nested field paths are supported. + + Args: + target: The request structure to inject the value into + value: The value to inject + config: The config object to use for interpolation + """ + if self._is_field_path: + if self.inject_into != RequestOptionType.body_json: + raise ValueError( + "Nested field injection is only supported for body JSON injection. Please use a top-level field_name for other injection types." + ) + + assert self.field_path is not None # for type checker + current = target + # Convert path segments into strings, evaluating any interpolated segments + # Example: ["data", "{{ config[user_type] }}", "id"] -> ["data", "admin", "id"] + *path_parts, final_key = [ + str( + segment.eval(config=config) + if isinstance(segment, InterpolatedString) + else segment + ) + for segment in self.field_path + ] + + # Build a nested dictionary structure and set the final value at the deepest level + for part in path_parts: + current = current.setdefault(part, {}) + current[final_key] = value + else: + # For non-nested fields, evaluate the field name if it's an interpolated string + key = ( + self.field_name.eval(config=config) + if isinstance(self.field_name, InterpolatedString) + else self.field_name + ) + target[str(key)] = value diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py index 05e06db71..437ea7b7b 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py @@ -80,12 +80,13 @@ def _get_request_options( options: MutableMapping[str, Any] = {} if not stream_slice: return options + if self.start_time_option and self.start_time_option.inject_into == option_type: - options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string - self._partition_field_start.eval(self.config) - ) + start_time_value = stream_slice.get(self._partition_field_start.eval(self.config)) + self.start_time_option.inject_into_request(options, start_time_value, self.config) + if self.end_time_option and self.end_time_option.inject_into == option_type: - options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore [union-attr] - self._partition_field_end.eval(self.config) - ) + end_time_value = stream_slice.get(self._partition_field_end.eval(self.config)) + self.end_time_option.inject_into_request(options, end_time_value, self.config) + return options diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index d167a84bc..e42f0485a 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -128,6 +128,9 @@ def _get_request_options( Returned merged mapping otherwise """ # FIXME we should eventually remove the usage of stream_state as part of the interpolation + + is_body_json = paginator_method.__name__ == "get_request_body_json" + mappings = [ paginator_method( stream_state=stream_state, @@ -143,7 +146,7 @@ def _get_request_options( next_page_token=next_page_token, ) ) - return combine_mappings(mappings) + return combine_mappings(mappings, allow_same_value_merge=is_body_json) def _request_headers( self, diff --git a/airbyte_cdk/utils/mapping_helpers.py b/airbyte_cdk/utils/mapping_helpers.py index 469fb5e0a..c5682c288 100644 --- a/airbyte_cdk/utils/mapping_helpers.py +++ b/airbyte_cdk/utils/mapping_helpers.py @@ -3,43 +3,102 @@ # -from typing import Any, List, Mapping, Optional, Set, Union +import copy +from typing import Any, Dict, List, Mapping, Optional, Union + + +def _merge_mappings( + target: Dict[str, Any], + source: Mapping[str, Any], + path: Optional[List[str]] = None, + allow_same_value_merge: bool = False, +) -> None: + """ + Recursively merge two dictionaries, raising an error if there are any conflicts. + For body_json requests (allow_same_value_merge=True), a conflict occurs only when the same path has different values. + For other request types (allow_same_value_merge=False), any duplicate key is a conflict, regardless of value. + + Args: + target: The dictionary to merge into + source: The dictionary to merge from + path: The current path in the nested structure (for error messages) + allow_same_value_merge: Whether to allow merging the same value into the same key. Set to false by default, should only be true for body_json injections + """ + path = path or [] + for key, source_value in source.items(): + current_path = path + [str(key)] + + if key in target: + target_value = target[key] + if isinstance(target_value, dict) and isinstance(source_value, dict): + # Only body_json supports nested_structures + if not allow_same_value_merge: + raise ValueError(f"Duplicate keys found: {'.'.join(current_path)}") + # If both are dictionaries, recursively merge them + _merge_mappings(target_value, source_value, current_path, allow_same_value_merge) + + elif not allow_same_value_merge or target_value != source_value: + # If same key has different values, that's a conflict + raise ValueError(f"Duplicate keys found: {'.'.join(current_path)}") + else: + # No conflict, just copy the value (using deepcopy for nested structures) + target[key] = copy.deepcopy(source_value) def combine_mappings( mappings: List[Optional[Union[Mapping[str, Any], str]]], + allow_same_value_merge: bool = False, ) -> Union[Mapping[str, Any], str]: """ - Combine multiple mappings into a single mapping. If any of the mappings are a string, return - that string. Raise errors in the following cases: - * If there are duplicate keys across mappings - * If there are multiple string mappings - * If there are multiple mappings containing keys and one of them is a string + Combine multiple mappings into a single mapping. + + For body_json requests (allow_same_value_merge=True): + - Supports nested structures (e.g., {"data": {"user": {"id": 1}}}) + - Allows duplicate keys if their values match + - Raises error if same path has different values + + For other request types (allow_same_value_merge=False): + - Only supports flat structures + - Any duplicate key raises an error, regardless of value + + Args: + mappings: List of mappings to combine + allow_same_value_merge: Whether to allow duplicate keys with matching values. + Should only be True for body_json requests. + + Returns: + A single mapping combining all inputs, or a string if there is exactly one + string mapping and no other non-empty mappings. + + Raises: + ValueError: If there are: + - Multiple string mappings + - Both a string mapping and non-empty dictionary mappings + - Conflicting keys/paths based on allow_same_value_merge setting """ - all_keys: List[Set[str]] = [] - for part in mappings: - if part is None: - continue - keys = set(part.keys()) if not isinstance(part, str) else set() - all_keys.append(keys) - - string_options = sum(isinstance(mapping, str) for mapping in mappings) - # If more than one mapping is a string, raise a ValueError + if not mappings: + return {} + + # Count how many string options we have, ignoring None values + string_options = sum(isinstance(mapping, str) for mapping in mappings if mapping is not None) if string_options > 1: raise ValueError("Cannot combine multiple string options") - if string_options == 1 and sum(len(keys) for keys in all_keys) > 0: - raise ValueError("Cannot combine multiple options if one is a string") + # Filter out None values and empty mappings + non_empty_mappings = [ + m for m in mappings if m is not None and not (isinstance(m, Mapping) and not m) + ] - # If any mapping is a string, return it - for mapping in mappings: - if isinstance(mapping, str): - return mapping + # If there is only one string option and no other non-empty mappings, return it + if string_options == 1: + if len(non_empty_mappings) > 1: + raise ValueError("Cannot combine multiple options if one is a string") + return next(m for m in non_empty_mappings if isinstance(m, str)) - # If there are duplicate keys across mappings, raise a ValueError - intersection = set().union(*all_keys) - if len(intersection) < sum(len(keys) for keys in all_keys): - raise ValueError(f"Duplicate keys found: {intersection}") + # Start with an empty result and merge each mapping into it + result: Dict[str, Any] = {} + for mapping in non_empty_mappings: + if mapping and isinstance(mapping, Mapping): + _merge_mappings(result, mapping, allow_same_value_merge=allow_same_value_merge) - # Return the combined mappings - return {key: value for mapping in mappings if mapping for key, value in mapping.items()} # type: ignore # mapping can't be string here + return result diff --git a/unit_tests/sources/declarative/auth/test_token_auth.py b/unit_tests/sources/declarative/auth/test_token_auth.py index 2a23d4e19..4e90367c1 100644 --- a/unit_tests/sources/declarative/auth/test_token_auth.py +++ b/unit_tests/sources/declarative/auth/test_token_auth.py @@ -248,3 +248,38 @@ def test_api_key_authenticator_inject( parameters=parameters, ) assert {expected_field_name: expected_field_value} == getattr(token_auth, validation_fn)() + + +@pytest.mark.parametrize( + "field_path, token, expected_result", + [ + ( + ["data", "auth", "token"], + "test-token", + {"data": {"auth": {"token": "test-token"}}}, + ), + ( + ["api", "{{ config.api_version }}", "auth", "token"], + "test-token", + {"api": {"v2": {"auth": {"token": "test-token"}}}}, + ), + ], + ids=["Basic nested structure", "Nested with config interpolation"], +) +def test_api_key_authenticator_nested_token_injection(field_path, token, expected_result): + """Test that the ApiKeyAuthenticator can properly inject tokens into nested structures when using body_json""" + config = {"api_version": "v2"} + parameters = {"auth_type": "bearer"} + + token_provider = InterpolatedStringTokenProvider( + config=config, api_token=token, parameters=parameters + ) + token_auth = ApiKeyAuthenticator( + request_option=RequestOption( + inject_into=RequestOptionType.body_json, field_path=field_path, parameters=parameters + ), + token_provider=token_provider, + config=config, + parameters=parameters, + ) + assert token_auth.get_request_body_json() == expected_result diff --git a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py index 37ed7ebfe..3ddc5847f 100644 --- a/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py @@ -782,13 +782,14 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state @pytest.mark.parametrize( - "test_name, inject_into, field_name, expected_req_params, expected_headers, expected_body_json, expected_body_data", + "test_name, inject_into, field_name, field_path, expected_req_params, expected_headers, expected_body_json, expected_body_data", [ - ("test_start_time_inject_into_none", None, None, {}, {}, {}, {}), + ("test_start_time_inject_into_none", None, None, None, {}, {}, {}, {}), ( "test_start_time_passed_by_req_param", RequestOptionType.request_parameter, "start_time", + None, { "start_time": "2021-01-01T00:00:00.000000+0000", "endtime": "2021-01-04T00:00:00.000000+0000", @@ -801,6 +802,7 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state "test_start_time_inject_into_header", RequestOptionType.header, "start_time", + None, {}, { "start_time": "2021-01-01T00:00:00.000000+0000", @@ -810,9 +812,10 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state {}, ), ( - "test_start_time_inject_intoy_body_json", + "test_start_time_inject_into_body_json", RequestOptionType.body_json, "start_time", + None, {}, {}, { @@ -821,10 +824,30 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_state }, {}, ), + ( + "test_nested_field_injection_into_body_json", + RequestOptionType.body_json, + None, + ["data", "queries", "time_range", "start"], + {}, + {}, + { + "data": { + "queries": { + "time_range": { + "start": "2021-01-01T00:00:00.000000+0000", + "end": "2021-01-04T00:00:00.000000+0000", + } + } + } + }, + {}, + ), ( "test_start_time_inject_into_body_data", RequestOptionType.body_data, "start_time", + None, {}, {}, {}, @@ -839,18 +862,26 @@ def test_request_option( test_name, inject_into, field_name, + field_path, expected_req_params, expected_headers, expected_body_json, expected_body_data, ): start_request_option = ( - RequestOption(inject_into=inject_into, parameters={}, field_name=field_name) + RequestOption( + inject_into=inject_into, parameters={}, field_name=field_name, field_path=field_path + ) if inject_into else None ) end_request_option = ( - RequestOption(inject_into=inject_into, parameters={}, field_name="endtime") + RequestOption( + inject_into=inject_into, + parameters={}, + field_name="endtime" if field_name else None, + field_path=["data", "queries", "time_range", "end"] if field_path else None, + ) if inject_into else None ) diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index c50e9e6e9..cd6edbae0 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -593,8 +593,8 @@ def test_list_based_stream_slicer_with_values_defined_in_config(): cursor_field: repository request_option: type: RequestOption - inject_into: header - field_name: repository + inject_into: body_json + field_path: ["repository", "id"] """ parsed_manifest = YamlDeclarativeSource._parse(content) resolved_manifest = resolver.preprocess_manifest(parsed_manifest) @@ -610,8 +610,10 @@ def test_list_based_stream_slicer_with_values_defined_in_config(): assert isinstance(partition_router, ListPartitionRouter) assert partition_router.values == ["airbyte", "airbyte-cloud"] - assert partition_router.request_option.inject_into == RequestOptionType.header - assert partition_router.request_option.field_name.eval(config=input_config) == "repository" + assert partition_router.request_option.inject_into == RequestOptionType.body_json + for field in partition_router.request_option.field_path: + assert isinstance(field, InterpolatedString) + assert len(partition_router.request_option.field_path) == 2 def test_create_substream_partition_router(): @@ -714,7 +716,7 @@ def test_datetime_based_cursor(): end_time_option: type: RequestOption inject_into: body_json - field_name: "before_{{ parameters['cursor_field'] }}" + field_path: ["before_{{ parameters['cursor_field'] }}"] partition_field_start: star partition_field_end: en """ @@ -743,7 +745,9 @@ def test_datetime_based_cursor(): == "since_updated_at" ) assert stream_slicer.end_time_option.inject_into == RequestOptionType.body_json - assert stream_slicer.end_time_option.field_name.eval({}) == "before_created_at" + assert [field.eval({}) for field in stream_slicer.end_time_option.field_path] == [ + "before_created_at" + ] assert stream_slicer._partition_field_start.eval({}) == "star" assert stream_slicer._partition_field_end.eval({}) == "en" @@ -904,8 +908,8 @@ def test_resumable_full_refresh_stream(): type: DefaultPaginator page_size_option: type: RequestOption - inject_into: request_parameter - field_name: page_size + inject_into: body_json + field_path: ["variables", "page_size"] page_token_option: type: RequestPath pagination_strategy: @@ -1003,11 +1007,10 @@ def test_resumable_full_refresh_stream(): assert isinstance(stream.retriever.paginator, DefaultPaginator) assert isinstance(stream.retriever.paginator.decoder, PaginationDecoderDecorator) - assert stream.retriever.paginator.page_size_option.field_name.eval(input_config) == "page_size" - assert ( - stream.retriever.paginator.page_size_option.inject_into - == RequestOptionType.request_parameter - ) + for string in stream.retriever.paginator.page_size_option.field_path: + assert isinstance(string, InterpolatedString) + assert len(stream.retriever.paginator.page_size_option.field_path) == 2 + assert stream.retriever.paginator.page_size_option.inject_into == RequestOptionType.body_json assert isinstance(stream.retriever.paginator.page_token_option, RequestPath) assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com/v3/" assert stream.retriever.paginator.url_base.default == "https://api.sendgrid.com/v3/" @@ -2509,7 +2512,6 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e assert isinstance(stream, DeclarativeStream) assert isinstance(stream.retriever, SimpleRetriever) - print(stream.retriever.stream_slicer) assert isinstance(stream.retriever.stream_slicer, expected_type) if incremental and partition_router: diff --git a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py index cbe185a37..57b6d9d34 100644 --- a/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +++ b/unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py @@ -437,7 +437,9 @@ def test_paginator_with_page_option_no_page_size(): DefaultPaginator( page_size_option=MagicMock(), page_token_option=RequestOption( - "limit", RequestOptionType.request_parameter, parameters={} + field_name="limit", + inject_into=RequestOptionType.request_parameter, + parameters={}, ), pagination_strategy=pagination_strategy, config=MagicMock(), diff --git a/unit_tests/sources/declarative/requesters/request_options/test_request_options.py b/unit_tests/sources/declarative/requesters/request_options/test_request_options.py new file mode 100644 index 000000000..115ce688d --- /dev/null +++ b/unit_tests/sources/declarative/requesters/request_options/test_request_options.py @@ -0,0 +1,199 @@ +from typing import Any, Dict, List, Optional, Type + +import pytest + +from airbyte_cdk.sources.declarative.requesters.request_option import ( + RequestOption, + RequestOptionType, +) + + +@pytest.mark.parametrize( + "field_name, field_path, inject_into, error_type, error_message", + [ + ( + None, + None, + RequestOptionType.body_json, + ValueError, + "RequestOption requires either a field_name or field_path", + ), + ( + "field", + ["data", "field"], + RequestOptionType.body_json, + ValueError, + "Only one of field_name or field_path can be provided", + ), + ( + None, + ["data", "field"], + RequestOptionType.header, + ValueError, + "Nested field injection is only supported for body JSON injection.", + ), + ], +) +def test_request_option_validation( + field_name: Optional[str], + field_path: Any, + inject_into: RequestOptionType, + error_type: Type[Exception], + error_message: str, +): + """Test various validation cases for RequestOption""" + with pytest.raises(error_type, match=error_message): + RequestOption( + field_name=field_name, field_path=field_path, inject_into=inject_into, parameters={} + ) + + +@pytest.mark.parametrize( + "request_option_args, value, expected_result", + [ + # Basic field_name test + ( + { + "field_name": "test_{{ config['base_field'] }}", + "inject_into": RequestOptionType.body_json, + }, + "test_value", + {"test_value": "test_value"}, + ), + # Basic field_path test + ( + { + "field_path": ["data", "nested_{{ config['base_field'] }}", "field"], + "inject_into": RequestOptionType.body_json, + }, + "test_value", + {"data": {"nested_value": {"field": "test_value"}}}, + ), + # Deep nesting test + ( + { + "field_path": ["level1", "level2", "level3", "level4", "field"], + "inject_into": RequestOptionType.body_json, + }, + "deep_value", + {"level1": {"level2": {"level3": {"level4": {"field": "deep_value"}}}}}, + ), + ], +) +def test_inject_into_request_cases( + request_option_args: Dict[str, Any], value: Any, expected_result: Dict[str, Any] +): + """Test various injection cases""" + config = {"base_field": "value"} + target: Dict[str, Any] = {} + + request_option = RequestOption(**request_option_args, parameters={}) + request_option.inject_into_request(target, value, config) + assert target == expected_result + + +@pytest.mark.parametrize( + "config, parameters, field_path, expected_structure", + [ + ( + {"nested": "user"}, + {"type": "profile"}, + ["data", "{{ config['nested'] }}", "{{ parameters['type'] }}"], + {"data": {"user": {"profile": "test_value"}}}, + ), + ( + {"user_type": "admin", "section": "profile"}, + {"id": "12345"}, + [ + "data", + "{{ config['user_type'] }}", + "{{ parameters['id'] }}", + "{{ config['section'] }}", + "details", + ], + {"data": {"admin": {"12345": {"profile": {"details": "test_value"}}}}}, + ), + ], +) +def test_interpolation_cases( + config: Dict[str, Any], + parameters: Dict[str, Any], + field_path: List[str], + expected_structure: Dict[str, Any], +): + """Test various interpolation scenarios""" + request_option = RequestOption( + field_path=field_path, inject_into=RequestOptionType.body_json, parameters=parameters + ) + target: Dict[str, Any] = {} + request_option.inject_into_request(target, "test_value", config) + assert target == expected_structure + + +@pytest.mark.parametrize( + "value, expected_type", + [ + (42, int), + (3.14, float), + (True, bool), + (["a", "b", "c"], list), + ({"key": "value"}, dict), + (None, type(None)), + ], +) +def test_value_type_handling(value: Any, expected_type: Type): + """Test handling of different value types""" + config = {} + target: Dict[str, Any] = {} + request_option = RequestOption( + field_path=["data", "test"], inject_into=RequestOptionType.body_json, parameters={} + ) + request_option.inject_into_request(target, value, config) + assert isinstance(target["data"]["test"], expected_type) + assert target["data"]["test"] == value + + +@pytest.mark.parametrize( + "field_name, field_path, inject_into, expected__is_field_path", + [ + ("field", None, RequestOptionType.body_json, False), + (None, ["data", "field"], RequestOptionType.body_json, True), + ], +) +def test__is_field_path( + field_name: Optional[str], + field_path: Optional[List[str]], + inject_into: RequestOptionType, + expected__is_field_path: bool, +): + """Test the _is_field_path property""" + request_option = RequestOption( + field_name=field_name, field_path=field_path, inject_into=inject_into, parameters={} + ) + assert request_option._is_field_path == expected__is_field_path + + +def test_multiple_injections(): + """Test injecting multiple values into the same target dict""" + config = {"base": "test"} + target = {"existing": "value"} + + # First injection with field_name + option1 = RequestOption( + field_name="field1", inject_into=RequestOptionType.body_json, parameters={} + ) + option1.inject_into_request(target, "value1", config) + + # Second injection with nested path + option2 = RequestOption( + field_path=["data", "nested", "field2"], + inject_into=RequestOptionType.body_json, + parameters={}, + ) + option2.inject_into_request(target, "value2", config) + + assert target == { + "existing": "value", + "field1": "value1", + "data": {"nested": {"field2": "value2"}}, + } diff --git a/unit_tests/sources/declarative/requesters/test_http_requester.py b/unit_tests/sources/declarative/requesters/test_http_requester.py index 28ea0cb9b..c1f3cee4f 100644 --- a/unit_tests/sources/declarative/requesters/test_http_requester.py +++ b/unit_tests/sources/declarative/requesters/test_http_requester.py @@ -277,8 +277,8 @@ def test_basic_send_request(): None, '{"field": "value", "field2": "value", "authfield": "val"}', ), - (None, {"field": "value"}, None, {"field": "value"}, None, None, ValueError, None), - (None, {"field": "value"}, None, None, None, {"field": "value"}, ValueError, None), + (None, {"field": "value"}, None, {"field": "value"}, None, None, None, "field=value"), + (None, {"field": "value"}, None, None, None, {"field": "value"}, None, "field=value"), # raise on mixed data and json params ( {"field": "value"}, diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index 5878c758f..bd31f7b65 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -58,6 +58,11 @@ def test_simple_retriever_full(mock_http_stream): request_params = {"param": "value"} requester.get_request_params.return_value = request_params + requester.get_request_params.__name__ = "get_request_params" + requester.get_request_headers.__name__ = "get_request_headers" + requester.get_request_body_data.__name__ = "get_request_body_data" + requester.get_request_body_json.__name__ = "get_request_body_json" + paginator = MagicMock() paginator.get_initial_token.return_value = None next_page_token = {"cursor": "cursor_value"} @@ -65,6 +70,11 @@ def test_simple_retriever_full(mock_http_stream): paginator.next_page_token.return_value = next_page_token paginator.get_request_headers.return_value = {} + paginator.get_request_params.__name__ = "get_request_params" + paginator.get_request_headers.__name__ = "get_request_headers" + paginator.get_request_body_data.__name__ = "get_request_body_data" + paginator.get_request_body_json.__name__ = "get_request_body_json" + record_selector = MagicMock() record_selector.select_records.return_value = records @@ -442,11 +452,19 @@ def test_get_request_options_from_pagination( paginator.get_request_body_data.return_value = paginator_mapping paginator.get_request_body_json.return_value = paginator_mapping + paginator.get_request_params.__name__ = "get_request_params" + paginator.get_request_body_data.__name__ = "get_request_body_data" + paginator.get_request_body_json.__name__ = "get_request_body_json" + request_options_provider = MagicMock() request_options_provider.get_request_params.return_value = request_options_provider_mapping request_options_provider.get_request_body_data.return_value = request_options_provider_mapping request_options_provider.get_request_body_json.return_value = request_options_provider_mapping + request_options_provider.get_request_params.__name__ = "get_request_params" + request_options_provider.get_request_body_data.__name__ = "get_request_body_data" + request_options_provider.get_request_body_json.__name__ = "get_request_body_json" + record_selector = MagicMock() retriever = SimpleRetriever( name="stream_name", @@ -489,10 +507,12 @@ def test_get_request_headers(test_name, paginator_mapping, expected_mapping): # This test is separate from the other request options because request headers must be strings paginator = MagicMock() paginator.get_request_headers.return_value = paginator_mapping + paginator.get_request_headers.__name__ = "get_request_headers" requester = MagicMock(use_cache=False) stream_slicer = MagicMock() stream_slicer.get_request_headers.return_value = {"key": "value"} + stream_slicer.get_request_headers.__name__ = "get_request_headers" record_selector = MagicMock() retriever = SimpleRetriever( @@ -565,10 +585,12 @@ def test_ignore_stream_slicer_parameters_on_paginated_requests( # This test is separate from the other request options because request headers must be strings paginator = MagicMock() paginator.get_request_headers.return_value = paginator_mapping + paginator.get_request_headers.__name__ = "get_request_headers" requester = MagicMock(use_cache=False) stream_slicer = MagicMock() stream_slicer.get_request_headers.return_value = {"key_from_slicer": "value"} + stream_slicer.get_request_headers.__name__ = "get_request_headers" record_selector = MagicMock() retriever = SimpleRetriever( @@ -612,6 +634,7 @@ def test_request_body_data( ): paginator = MagicMock() paginator.get_request_body_data.return_value = paginator_body_data + paginator.get_request_body_data.__name__ = "get_request_body_data" requester = MagicMock(use_cache=False) request_option_provider = MagicMock() @@ -825,11 +848,25 @@ def test_emit_log_request_response_messages(mocker): "airbyte_cdk.sources.declarative.retrievers.simple_retriever.format_http_message" ) requester = MagicMock() + + # Add __name__ to mock methods + requester.get_request_params.__name__ = "get_request_params" + requester.get_request_headers.__name__ = "get_request_headers" + requester.get_request_body_data.__name__ = "get_request_body_data" + requester.get_request_body_json.__name__ = "get_request_body_json" + + # The paginator mock also needs __name__ attributes + paginator = MagicMock() + paginator.get_request_params.__name__ = "get_request_params" + paginator.get_request_headers.__name__ = "get_request_headers" + paginator.get_request_body_data.__name__ = "get_request_body_data" + paginator.get_request_body_json.__name__ = "get_request_body_json" + retriever = SimpleRetrieverTestReadDecorator( name="stream_name", primary_key=primary_key, requester=requester, - paginator=MagicMock(), + paginator=paginator, record_selector=record_selector, stream_slicer=SinglePartitionRouter(parameters={}), parameters={}, diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index b3c9ab4bb..667e32d60 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -1024,8 +1024,8 @@ def test_manifest_without_at_least_one_stream(self): "page_size": 10, "page_size_option": { "type": "RequestOption", - "inject_into": "request_parameter", - "field_name": "page_size", + "inject_into": "request_body", + "field_path": ["variables", "page_size"], }, "page_token_option": {"type": "RequestPath"}, "pagination_strategy": { diff --git a/unit_tests/utils/test_mapping_helpers.py b/unit_tests/utils/test_mapping_helpers.py index 272ce9b7a..124bf4565 100644 --- a/unit_tests/utils/test_mapping_helpers.py +++ b/unit_tests/utils/test_mapping_helpers.py @@ -1,55 +1,115 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. -# - import pytest from airbyte_cdk.utils.mapping_helpers import combine_mappings -def test_basic_merge(): - mappings = [{"a": 1}, {"b": 2}, {"c": 3}, {}] - result = combine_mappings(mappings) - assert result == {"a": 1, "b": 2, "c": 3} - - -def test_combine_with_string(): - mappings = [{"a": 1}, "option"] - with pytest.raises(ValueError, match="Cannot combine multiple options if one is a string"): - combine_mappings(mappings) - - -def test_overlapping_keys(): - mappings = [{"a": 1, "b": 2}, {"b": 3}] - with pytest.raises(ValueError, match="Duplicate keys found"): - combine_mappings(mappings) - - -def test_multiple_strings(): - mappings = ["option1", "option2"] - with pytest.raises(ValueError, match="Cannot combine multiple string options"): - combine_mappings(mappings) - - -def test_handle_none_values(): - mappings = [{"a": 1}, None, {"b": 2}] - result = combine_mappings(mappings) - assert result == {"a": 1, "b": 2} - - -def test_empty_mappings(): - mappings = [] - result = combine_mappings(mappings) - assert result == {} - - -def test_single_mapping(): - mappings = [{"a": 1}] - result = combine_mappings(mappings) - assert result == {"a": 1} - - -def test_combine_with_string_and_empty_mappings(): - mappings = ["option", {}] - result = combine_mappings(mappings) - assert result == "option" +@pytest.mark.parametrize( + "test_name, mappings, expected_result", + [ + ("empty_mappings", [], {}), + ("single_mapping", [{"a": 1}], {"a": 1}), + ("handle_none_values", [{"a": 1}, None, {"b": 2}], {"a": 1, "b": 2}), + ], +) +def test_basic_functionality(test_name, mappings, expected_result): + """Test basic mapping operations that work the same regardless of request type""" + assert combine_mappings(mappings) == expected_result + + +@pytest.mark.parametrize( + "test_name, mappings, expected_result, expected_error", + [ + ( + "combine_with_string", + [{"a": 1}, "option"], + None, + "Cannot combine multiple options if one is a string", + ), + ( + "multiple_strings", + ["option1", "option2"], + None, + "Cannot combine multiple string options", + ), + ("string_with_empty_mapping", ["option", {}], "option", None), + ], +) +def test_string_handling(test_name, mappings, expected_result, expected_error): + """Test string handling behavior which is independent of request type""" + if expected_error: + with pytest.raises(ValueError, match=expected_error): + combine_mappings(mappings) + else: + assert combine_mappings(mappings) == expected_result + + +@pytest.mark.parametrize( + "test_name, mappings, expected_error", + [ + ("duplicate_keys_same_value", [{"a": 1}, {"a": 1}], "Duplicate keys found"), + ("duplicate_keys_different_value", [{"a": 1}, {"a": 2}], "Duplicate keys found"), + ( + "nested_structure_not_allowed", + [{"a": {"b": 1}}, {"a": {"c": 2}}], + "Duplicate keys found", + ), + ("any_nesting_not_allowed", [{"a": {"b": 1}}, {"a": {"d": 2}}], "Duplicate keys found"), + ], +) +def test_non_body_json_requests(test_name, mappings, expected_error): + """Test strict validation for non-body-json requests (headers, params, body_data)""" + with pytest.raises(ValueError, match=expected_error): + combine_mappings(mappings, allow_same_value_merge=False) + + +@pytest.mark.parametrize( + "test_name, mappings, expected_result, expected_error", + [ + ( + "simple_nested_merge", + [{"a": {"b": 1}}, {"c": {"d": 2}}], + {"a": {"b": 1}, "c": {"d": 2}}, + None, + ), + ( + "deep_nested_merge", + [{"a": {"b": {"c": 1}}}, {"d": {"e": {"f": 2}}}], + {"a": {"b": {"c": 1}}, "d": {"e": {"f": 2}}}, + None, + ), + ( + "nested_merge_same_level", + [ + {"data": {"user": {"id": 1}, "status": "active"}}, + {"data": {"user": {"name": "test"}, "type": "admin"}}, + ], + { + "data": { + "user": {"id": 1, "name": "test"}, + "status": "active", + "type": "admin", + }, + }, + None, + ), + ( + "nested_conflict", + [{"a": {"b": 1}}, {"a": {"b": 2}}], + None, + "Duplicate keys found", + ), + ( + "type_conflict", + [{"a": 1}, {"a": {"b": 2}}], + None, + "Duplicate keys found", + ), + ], +) +def test_body_json_requests(test_name, mappings, expected_result, expected_error): + """Test nested structure support for body_json requests""" + if expected_error: + with pytest.raises(ValueError, match=expected_error): + combine_mappings(mappings, allow_same_value_merge=True) + else: + assert combine_mappings(mappings, allow_same_value_merge=True) == expected_result