diff --git a/haystack/components/embedders/azure_document_embedder.py b/haystack/components/embedders/azure_document_embedder.py index e60c8781b6..c6f94525f3 100644 --- a/haystack/components/embedders/azure_document_embedder.py +++ b/haystack/components/embedders/azure_document_embedder.py @@ -51,6 +51,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) embedding_separator: str = "\n", timeout: Optional[float] = None, max_retries: Optional[int] = None, + default_headers: Optional[Dict[str, str]] = None, ): """ Creates an AzureOpenAIDocumentEmbedder component. @@ -95,6 +96,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) `OPENAI_TIMEOUT` environment variable, or 30 seconds. :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error. If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries. + :param default_headers: Default headers to use for the AzureOpenAI client. """ # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -119,6 +121,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) self.embedding_separator = embedding_separator self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + self.default_headers = default_headers or {} self._client = AzureOpenAI( api_version=api_version, @@ -129,6 +132,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -161,6 +165,7 @@ def to_dict(self) -> Dict[str, Any]: azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/haystack/components/embedders/azure_text_embedder.py b/haystack/components/embedders/azure_text_embedder.py index 961cd910ad..d08614efa0 100644 --- a/haystack/components/embedders/azure_text_embedder.py +++ b/haystack/components/embedders/azure_text_embedder.py @@ -46,6 +46,7 @@ def __init__( max_retries: Optional[int] = None, prefix: str = "", suffix: str = "", + default_headers: Optional[Dict[str, str]] = None, ): """ Creates an AzureOpenAITextEmbedder component. @@ -82,6 +83,7 @@ def __init__( A string to add at the beginning of each text. :param suffix: A string to add at the end of each text. + :param default_headers: Default headers to use for the AzureOpenAI client. """ # Why is this here? # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not @@ -105,6 +107,7 @@ def __init__( self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) self.prefix = prefix self.suffix = suffix + self.default_headers = default_headers or {} self._client = AzureOpenAI( api_version=api_version, @@ -115,6 +118,7 @@ def __init__( organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -143,6 +147,7 @@ def to_dict(self) -> Dict[str, Any]: azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/releasenotes/notes/azure-embeddings-default-headers-b9b9b3b054dd89d9.yaml b/releasenotes/notes/azure-embeddings-default-headers-b9b9b3b054dd89d9.yaml new file mode 100644 index 0000000000..d6bbe02d2c --- /dev/null +++ b/releasenotes/notes/azure-embeddings-default-headers-b9b9b3b054dd89d9.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Adds `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder` diff --git a/test/components/embedders/test_azure_document_embedder.py b/test/components/embedders/test_azure_document_embedder.py index 354f35a0fc..e0af24f063 100644 --- a/test/components/embedders/test_azure_document_embedder.py +++ b/test/components/embedders/test_azure_document_embedder.py @@ -45,6 +45,7 @@ def test_to_dict(self, monkeypatch): "embedding_separator": "\n", "max_retries": 5, "timeout": 30.0, + "default_headers": {}, }, } diff --git a/test/components/embedders/test_azure_text_embedder.py b/test/components/embedders/test_azure_text_embedder.py index 5f1f82e3d8..71dc9b4076 100644 --- a/test/components/embedders/test_azure_text_embedder.py +++ b/test/components/embedders/test_azure_text_embedder.py @@ -38,6 +38,7 @@ def test_to_dict(self, monkeypatch): "timeout": 30.0, "prefix": "", "suffix": "", + "default_headers": {}, }, }