diff --git a/.buildkite/nightly_steps.yml b/.buildkite/nightly_steps.yml index 9a0ed5546..1a9759035 100644 --- a/.buildkite/nightly_steps.yml +++ b/.buildkite/nightly_steps.yml @@ -57,6 +57,8 @@ steps: - label: "🔨 Microsoft SQL" command: - ".buildkite/run_nigthly.sh mssql" + env: + SKIP_AARCH64: "true" artifact_paths: - "perf8-report-*/**/*" - label: "🔨 Jira" diff --git a/Dockerfile.ftest.wolfi b/Dockerfile.ftest.wolfi index 5ba2190e2..78047c6d8 100644 --- a/Dockerfile.ftest.wolfi +++ b/Dockerfile.ftest.wolfi @@ -1,4 +1,4 @@ -FROM docker.elastic.co/wolfi/python:3.10-dev +FROM docker.elastic.co/wolfi/python:3.10-dev@sha256:609e3fdb2c2a219941a6ea1e81f935899d05a787af634e4d47a3a3ec9a6d3379 USER root COPY . /connectors WORKDIR /connectors diff --git a/connectors/VERSION b/connectors/VERSION index 7bb3f3047..70e4ab71a 100644 --- a/connectors/VERSION +++ b/connectors/VERSION @@ -1 +1 @@ -8.15.0.0 +8.15.5.1 diff --git a/connectors/es/sink.py b/connectors/es/sink.py index 43b863e19..5a8b3c22e 100644 --- a/connectors/es/sink.py +++ b/connectors/es/sink.py @@ -197,7 +197,9 @@ async def _bulk_api_call(): for item in res["items"]: for op, data in item.items(): if "error" in data: - self._logger.error(f"operation {op} failed, {data['error']}") + self._logger.error( + f"operation {op} failed for doc {data['_id']}, {data['error']}" + ) self._populate_stats(stats, res) @@ -598,8 +600,18 @@ async def get_docs(self, generator, skip_unchanged_documents=False): "doc": doc, } ) + # We try raising every loop to not miss a moment when + # too many errors happened when downloading + lazy_downloads.raise_any_exception() await asyncio.sleep(0) + + # Sit and wait until an error happens + await lazy_downloads.join(raise_on_error=True) + except Exception as ex: + self._logger.error(f"Extractor failed with an error: {ex}") + lazy_downloads.cancel() + raise finally: # wait for all downloads to be finished await lazy_downloads.join() @@ -883,9 +895,7 @@ def _extractor_task_running(self): async def cancel(self): if self._sink_task_running(): - self._logger.info( - f"Cancling the Sink task: {self._sink_task.name}" # pyright: ignore - ) + self._logger.info(f"Canceling the Sink task: {self._sink_task.get_name()}") self._sink_task.cancel() else: self._logger.debug( @@ -894,7 +904,7 @@ async def cancel(self): if self._extractor_task_running(): self._logger.info( - f"Canceling the Extractor task: {self._extractor_task.name}" # pyright: ignore + f"Canceling the Extractor task: {self._extractor_task.get_name()}" ) self._extractor_task.cancel() else: @@ -1042,13 +1052,15 @@ async def async_bulk( def sink_task_callback(self, task): if task.exception(): self._logger.error( - f"Encountered an error in the sync's {type(self._sink).__name__}: {task.get_name()}: {task.exception()}", + f"Encountered an error in the sync's {type(self._sink).__name__}: {task.get_name()}", + exc_info=task.exception(), ) self.error = task.exception() def extractor_task_callback(self, task): if task.exception(): self._logger.error( - f"Encountered an error in the sync's {type(self._extractor).__name__}: {task.get_name()}: {task.exception()}", + f"Encountered an error in the sync's {type(self._extractor).__name__}: {task.get_name()}", + exc_info=task.exception(), ) self.error = task.exception() diff --git a/connectors/source.py b/connectors/source.py index 997fae7df..24ed10424 100644 --- a/connectors/source.py +++ b/connectors/source.py @@ -10,7 +10,7 @@ import importlib import re from contextlib import asynccontextmanager -from datetime import date, datetime 
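# --- Illustrative sketch (editor's note, not part of the patch) ---
# The connectors/es/sink.py hunk above makes the Extractor surface download
# failures early: it calls `lazy_downloads.raise_any_exception()` on every loop
# iteration, then `await lazy_downloads.join(raise_on_error=True)`, and cancels
# the pool if anything blew up. The snippet below mimics that pattern with a
# deliberately simplified pool; `MiniPool`, `fake_download` and `extract` are
# hypothetical stand-ins, not the real ConcurrentTasks API.
import asyncio


class MiniPool:
    def __init__(self):
        self._tasks = []

    def put(self, coro):
        self._tasks.append(asyncio.create_task(coro))

    def raise_any_exception(self):
        # Re-raise the first error seen so the caller can abort early.
        for task in self._tasks:
            if task.done() and task.exception():
                raise task.exception()

    async def join(self, raise_on_error=False):
        return await asyncio.gather(
            *self._tasks, return_exceptions=not raise_on_error
        )

    def cancel(self):
        for task in self._tasks:
            task.cancel()


async def fake_download(i):
    await asyncio.sleep(0.01)
    if i == 3:
        raise RuntimeError(f"download {i} failed")


async def extract():
    pool = MiniPool()
    try:
        for i in range(10):
            pool.put(fake_download(i))
            # Check after every scheduled download instead of only at the end.
            pool.raise_any_exception()
            await asyncio.sleep(0)
        await pool.join(raise_on_error=True)
    except Exception:
        pool.cancel()
        raise
    finally:
        # Mirrors the original `finally: await lazy_downloads.join()`.
        await pool.join()


# asyncio.run(extract())  # propagates RuntimeError("download 3 failed")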
+from datetime import date, datetime, time from decimal import Decimal from enum import Enum from functools import cache @@ -670,7 +670,7 @@ def _serialize(value): elif isinstance(value, dict): for key, svalue in value.items(): value[key] = _serialize(svalue) - elif isinstance(value, (datetime, date)): + elif isinstance(value, (datetime, date, time)): value = value.isoformat() elif isinstance(value, Decimal128): value = value.to_decimal() diff --git a/connectors/sources/confluence.py b/connectors/sources/confluence.py index ea5b735c7..20b26f521 100644 --- a/connectors/sources/confluence.py +++ b/connectors/sources/confluence.py @@ -895,14 +895,20 @@ async def fetch_documents(self, api_query): doc = { "_id": str(document["id"]), "type": document["type"], - "_timestamp": document["history"]["lastUpdated"]["when"], + "_timestamp": nested_get_from_dict( + document, ["history", "lastUpdated", "when"] + ), "title": document.get("title"), "ancestors": ancestor_title, - "space": document["space"]["name"], - "body": document["body"]["storage"]["value"], + "space": nested_get_from_dict(document, ["space", "name"]), + "body": nested_get_from_dict(document, ["body", "storage", "value"]), "url": document_url, - "author": document["history"]["createdBy"][self.authorkey], - "createdDate": document["history"]["createdDate"], + "author": nested_get_from_dict( + document, ["history", "createdBy", self.authorkey] + ), + "createdDate": nested_get_from_dict( + document, ["history", "createdDate"] + ), } if self.confluence_client.index_labels: doc["labels"] = document["labels"] diff --git a/connectors/sources/github.py b/connectors/sources/github.py index 50bdc3d8b..7c49f6ad9 100644 --- a/connectors/sources/github.py +++ b/connectors/sources/github.py @@ -12,7 +12,7 @@ import aiohttp import fastjsonschema from aiohttp.client_exceptions import ClientResponseError -from gidgethub import RateLimitExceeded, sansio +from gidgethub import QueryError, RateLimitExceeded, sansio from gidgethub.abc import ( BadGraphQLRequest, GraphQLAuthorizationFailure, @@ -52,6 +52,7 @@ RETRIES = 3 RETRY_INTERVAL = 2 FORBIDDEN = 403 +UNAUTHORIZED = 401 NODE_SIZE = 100 REVIEWS_COUNT = 45 @@ -673,14 +674,12 @@ def __init__( def set_logger(self, logger_): self._logger = logger_ - def get_rate_limit_encountered(self, status_code, message): - return status_code == FORBIDDEN and "rate limit" in str(message).lower() + def get_rate_limit_encountered(self, status_code, rate_limit_remaining): + return status_code == FORBIDDEN and not int(rate_limit_remaining) async def _get_retry_after(self, resource_type): current_time = time.time() - response = await self._get_client.getitem( - "/rate_limit", oauth_token=self._access_token() - ) + response = await self.get_github_item("/rate_limit") reset = nested_get_from_dict( response, ["resources", resource_type, "reset"], default=current_time ) @@ -725,6 +724,8 @@ async def _update_installation_access_token(self): private_key=self.private_key, ) self._installation_access_token = access_token_response["token"] + except RateLimitExceeded: + await self._put_to_sleep("core") except Exception: self._logger.exception( f"Failed to get access token for installation {self._installation_id}.", @@ -789,11 +790,20 @@ async def graphql(self, query, variables=None): msg = "Your Github token is either expired or revoked. Please check again." 
raise UnauthorizedException(msg) from exception except BadGraphQLRequest as exception: - if self.get_rate_limit_encountered(exception.status_code, exception): - await self._put_to_sleep(resource_type="graphql") - elif exception.status == FORBIDDEN: + if exception.status_code == FORBIDDEN: msg = f"Provided GitHub token does not have the necessary permissions to perform the request for the URL: {url} and query: {query}." raise ForbiddenException(msg) from exception + else: + raise + except QueryError as exception: + for error in exception.response.get("errors"): + if ( + error.get("type").lower() == "rate_limited" + and "api rate limit exceeded" in error.get("message").lower() + ): + await self._put_to_sleep(resource_type="graphql") + msg = f"Error while executing query. Exception: {exception.response.get('errors')}" + raise Exception(msg) from exception except Exception: raise @@ -819,7 +829,7 @@ async def get_github_item(self, resource): url=resource, oauth_token=self._access_token() ) except ClientResponseError as exception: - if exception.status == 401: + if exception.status == UNAUTHORIZED: if self.auth_method == GITHUB_APP: self._logger.debug( f"The access token for installation #{self._installation_id} expired, Regenerating a new token." @@ -828,6 +838,10 @@ async def get_github_item(self, resource): raise msg = "Your Github token is either expired or revoked. Please check again." raise UnauthorizedException(msg) from exception + elif self.get_rate_limit_encountered( + exception.status, exception.headers.get("X-RateLimit-Remaining") + ): + await self._put_to_sleep(resource_type="core") elif exception.status == FORBIDDEN: msg = f"Provided GitHub token does not have the necessary permissions to perform the request for the URL: {resource}." raise ForbiddenException(msg) from exception @@ -861,20 +875,44 @@ async def paginated_api_call(self, query, variables, keys): def get_repo_details(self, repo_name): return repo_name.split("/") + @retryable( + retries=RETRIES, + interval=RETRY_INTERVAL, + strategy=RetryStrategy.EXPONENTIAL_BACKOFF, + skipped_exceptions=UnauthorizedException, + ) async def get_personal_access_token_scopes(self): - request_headers = sansio.create_headers( - self._get_client.requester, - accept=sansio.accept_format(), - oauth_token=self._access_token(), - ) - _, headers, _ = await self._get_client._request( - "HEAD", self.base_url, request_headers - ) - scopes = headers.get("X-OAuth-Scopes") - if not scopes or not scopes.strip(): - self._logger.warning(f"Couldn't find 'X-OAuth-Scopes' in headers {headers}") - return set() - return {scope.strip() for scope in scopes.split(",")} + try: + request_headers = sansio.create_headers( + self._get_client.requester, + accept=sansio.accept_format(), + oauth_token=self._access_token(), + ) + url = f"{self.base_url}/graphql" + _, headers, _ = await self._get_client._request( + "HEAD", url, request_headers + ) + scopes = headers.get("X-OAuth-Scopes") + if not scopes or not scopes.strip(): + self._logger.warning( + f"Couldn't find 'X-OAuth-Scopes' in headers {headers}" + ) + return set() + return {scope.strip() for scope in scopes.split(",")} + except ClientResponseError as exception: + if exception.status == FORBIDDEN: + if self.get_rate_limit_encountered( + exception.status, exception.headers.get("X-RateLimit-Remaining") + ): + await self._put_to_sleep("graphql") + else: + msg = f"Provided GitHub token does not have the necessary permissions to perform the request for the URL: {self.base_url}." 
+ raise ForbiddenException(msg) from exception + elif exception.status == UNAUTHORIZED: + msg = "Your Github token is either expired or revoked. Please check again." + raise UnauthorizedException(msg) from exception + else: + raise @retryable( retries=RETRIES, @@ -1279,7 +1317,6 @@ async def _get_invalid_repos_for_github_app(self): self.configured_repos, ) ) - for full_repo_name in self.configured_repos: if full_repo_name in invalid_repos: continue @@ -1432,7 +1469,7 @@ async def _validate_personal_access_token_scopes(self): if self.configuration["auth_method"] != PERSONAL_ACCESS_TOKEN: return - scopes = await self.github_client.get_personal_access_token_scopes() + scopes = await self.github_client.get_personal_access_token_scopes() or set() required_scopes = {"repo", "user", "read:org"} for scope in ["write:org", "admin:org"]: diff --git a/connectors/sources/google.py b/connectors/sources/google.py index 286d5fe5e..c29db2d4a 100644 --- a/connectors/sources/google.py +++ b/connectors/sources/google.py @@ -5,10 +5,12 @@ # import json import os +import urllib.parse from enum import Enum from aiogoogle import Aiogoogle, AuthError, HTTPError from aiogoogle.auth.creds import ServiceAccountCreds +from aiogoogle.models import Request from aiogoogle.sessions.aiohttp_session import AiohttpSession from connectors.logger import logger @@ -192,6 +194,41 @@ async def _call_api(google_client, method_object, kwargs): return await anext(self._execute_api_call(resource, method, _call_api, kwargs)) + async def api_call_custom(self, api_service, url_path, params=None, method="GET"): + """ + Make a custom API call to any Google API endpoint by specifying only the service name, + along with the path and parameters. + + Args: + api_service (str): The Google service to use (e.g., "slides" or "drive"). + url_path (str): The specific path for the API call (e.g., "/presentations/.../thumbnail"). + params (dict): Optional dictionary of query parameters. + method (str): HTTP method for the request. Default is "GET". + + Returns: + dict: The response from the Google API. + + Raises: + HTTPError: If the API request fails. + """ + # Construct the full URL with the specified Google service + base_url = f"https://{api_service}.googleapis.com/v1" + full_url = f"{base_url}{url_path}" + + if params: + full_url += f"?{urllib.parse.urlencode(params)}" + + # Define the request + request = Request(method=method, url=full_url) + + # Perform the request + async with Aiogoogle( + service_account_creds=self.service_account_credentials + ) as google_client: + response = await google_client.as_service_account(request) + + return response + async def _execute_api_call(self, resource, method, call_api_func, kwargs): """Execute the API call with common try/except logic. Args: diff --git a/connectors/sources/google_drive.py b/connectors/sources/google_drive.py index fb108c8ef..7094e1c46 100644 --- a/connectors/sources/google_drive.py +++ b/connectors/sources/google_drive.py @@ -4,16 +4,26 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
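# --- Illustrative sketch (editor's note, not part of the patch) ---
# The new GoogleServiceAccountClient.api_call_custom() above builds
# "https://{api_service}.googleapis.com/v1{url_path}", appends the query
# parameters, and sends the request with aiogoogle as the service account.
# A call like the Slides thumbnail fetch used later in this patch would look
# roughly like this (the presentation and page IDs are placeholders):
async def fetch_slide_thumbnail(client, presentation_id, page_id):
    # `client` is a GoogleServiceAccountClient (e.g. the Slides extractor).
    response = await client.api_call_custom(
        api_service="slides",
        url_path=f"/presentations/{presentation_id}/pages/{page_id}/thumbnail",
        params={
            "thumbnailProperties.mimeType": "PNG",
            "thumbnailProperties.thumbnailSize": "MEDIUM",
        },
    )
    # The Slides API reports the rendered image location in "contentUrl".
    return response.get("contentUrl")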
# import asyncio +import base64 +import os +import time from functools import cached_property, partial +import requests from aiogoogle import HTTPError +from openai import AsyncAzureOpenAI from connectors.access_control import ( ACCESS_CONTROL, es_access_control_query, prefix_identity, ) -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.es.sink import OP_DELETE, OP_INDEX +from connectors.source import ( + CURSOR_SYNC_TIMESTAMP, + BaseDataSource, + ConfigurableFieldValueError, +) from connectors.sources.google import ( GoogleServiceAccountClient, UserFields, @@ -23,33 +33,489 @@ ) from connectors.utils import ( EMAIL_REGEX_PATTERN, + ConcurrentTasks, + RetryStrategy, + iso_zulu, + retryable, + sleeps_for_retryable, validate_email_address, ) GOOGLE_DRIVE_SERVICE_NAME = "Google Drive" GOOGLE_ADMIN_DIRECTORY_SERVICE_NAME = "Google Admin Directory" +CURSOR_GOOGLE_DRIVE_KEY = "google_drives" -RETRIES = 3 -RETRY_INTERVAL = 2 - -GOOGLE_API_MAX_CONCURRENCY = 25 # Max open connections to Google API +RETRIES = 5 +RETRY_INTERVAL = 5 +GOOGLE_API_MAX_CONCURRENCY = 1 # Max open connections to Google API DRIVE_API_TIMEOUT = 1 * 60 # 1 min FOLDER_MIME_TYPE = "application/vnd.google-apps.folder" -DRIVE_ITEMS_FIELDS = "id,createdTime,driveId,modifiedTime,name,size,mimeType,fileExtension,webViewLink,owners,parents" +DRIVE_ITEMS_FIELDS = "id,createdTime,driveId,modifiedTime,name,size,mimeType,fileExtension,webViewLink,owners,parents,trashed,trashedTime" DRIVE_ITEMS_FIELDS_WITH_PERMISSIONS = f"{DRIVE_ITEMS_FIELDS},permissions" +SLIDES_MIME_TYPE = "application/vnd.google-apps.presentation" +SLIDES_FIELDS = "slides(objectId,slideProperties(isSkipped,notesPage(pageElements(shape(text(textElements)))))),title" + +AUDIO_VIDEO_MIME_TYPES = { + "audio/mpeg", + "audio/wav", + "audio/x-wav", + "video/mp4", + "video/quicktime", +} + +MB_TO_BYTES = 1024 * 1024 + +GOOGLE_PRESENTATION_QUEUE = ConcurrentTasks(max_concurrency=5) +AZURE_OPENAI_WHISPER_QUEUE = ConcurrentTasks(max_concurrency=1) + # Export Google Workspace documents to TIKA compatible format, prefer 'text/plain' where possible to be # mindful of the content extraction service resources GOOGLE_MIME_TYPES_MAPPING = { - "application/vnd.google-apps.document": "text/plain", + "application/vnd.google-apps.document": "text/markdown", "application/vnd.google-apps.presentation": "text/plain", - "application/vnd.google-apps.spreadsheet": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.google-apps.spreadsheet": "text/csv", } +class GoogleSlidesExtractor(GoogleServiceAccountClient): + """Handles API interactions and content extraction for Google Slides.""" + + def __init__(self, json_credentials, azure_llm_client, logger, subject=None): + """ + Initialize GoogleSlidesExtractor. + + Args: + json_credentials (dict): Service account credentials. + azure_llm_client: LLM client for content extraction. + logger: Logger instance. + subject (str, optional): Subject to impersonate. 
+ """ + # Call the parent constructor + remove_universe_domain(json_credentials) + if subject: + json_credentials["subject"] = subject + + super().__init__( + json_credentials=json_credentials, + api="slides", + api_version="v1", + scopes=["https://www.googleapis.com/auth/presentations.readonly"], + api_timeout=DRIVE_API_TIMEOUT, + ) + + if subject: + self.subject = subject + + self.azure_llm_client = azure_llm_client + self.logger = logger + + # Parameters for retry logic and rate limiting + self.system_prompt = ( + "Summarize the core textual and conceptual content of the slide while:" + "\n- Including names of brands, products, or entities from logos" + "\n- Excluding design elements like colors, backgrounds, shapes" + "\n- Ignoring Elastic mentions unless contextually relevant" + "\n- Excluding confidentiality notices or access labels" + "\n- Describing themes and main points from text and images" + "\n- Using Mermaid syntax for diagrams (e.g., graph TD for flows)" + "\n- Using Markdown formatting" + ) + + @retryable( + retries=RETRIES, + interval=RETRY_INTERVAL, + strategy=RetryStrategy.EXPONENTIAL_BACKOFF, + ) + async def _get_presentation(self, presentation_id, fields): + """ + Retrieves the presentation metadata from Google Slides API + + Args: + presentation_id (str): The presentation ID. + fields (str): Fields to fetch. + + Returns: + dict: Presentation metadata + """ + return await self.api_call( + resource="presentations", + method="get", + presentationId=presentation_id, + fields=fields, + ) + + @retryable( + retries=RETRIES, + interval=RETRY_INTERVAL, + strategy=RetryStrategy.EXPONENTIAL_BACKOFF, + ) + async def _get_thumbnail_url( + self, presentation_id, page_id, mime_type="PNG", thumbnail_size="MEDIUM" + ): + """Retrieves the URL of a slide's thumbnail from the Google Slides API. + + Args: + presentation_id (str): The ID of the Google Slides presentation. + slide_id (str): The ID of the slide within the presentation. + + Returns: + str: The URL of the slide's thumbnail. + + Raises: + requests.exceptions.HTTPError: If there's an HTTP error during the API call. + Exception: If the API response does not contain a contentUrl. + """ + + self.logger.info( + f"[Presentation ID: {presentation_id}] Fetching thumbnail URL for slide {page_id}" + ) + response = await self.api_call_custom( + api_service="slides", + url_path=f"/presentations/{presentation_id}/pages/{page_id}/thumbnail", + params={ + "thumbnailProperties.mimeType": "PNG", + "thumbnailProperties.thumbnailSize": "MEDIUM", + }, + ) + self.logger.debug( + f"[Presentation ID: {presentation_id}] Google Slides API response: {response}" + ) + thumbnail_url = response.get("contentUrl") + if not thumbnail_url: + self.logger.error( + f"[Presentation ID: {presentation_id}] No contentUrl found in thumbnail API response: {response}" + ) + msg = "No thumbnail URL found." + raise Exception(msg) + self.logger.info( + f"[Presentation ID: {presentation_id}] Retrieved thumbnail URL: {thumbnail_url}" + ) + return thumbnail_url + + @retryable( + retries=RETRIES, + interval=RETRY_INTERVAL, + strategy=RetryStrategy.EXPONENTIAL_BACKOFF, + ) + async def _get_thumbnail_data(self, presentation_id, thumbnail_url): + """Downloads and encodes thumbnail data from a given URL. + + Args: + thumbnail_url (str): The URL of the thumbnail image. + + Returns: + tuple: A tuple containing the base64-encoded thumbnail data (str) and the content type (str). + Returns (None, None) if there's an error. 
+ + Raises: + requests.exceptions.HTTPError: If there's an HTTP error during the download. + Exception: For any other errors during the download or encoding process. + """ + self.logger.info( + f"[Presentation ID: {presentation_id}] Downloading thumbnail data from: {thumbnail_url}" + ) + image_response = requests.get(thumbnail_url, timeout=DRIVE_API_TIMEOUT) + image_response.raise_for_status() + + self.logger.debug( + f"[Presentation ID: {presentation_id}] Thumbnail download successful. Status code: {image_response.status_code}" + ) + + image_data = image_response.content + content_type = image_response.headers.get("content-type", "image/png") + encoded_image = base64.b64encode(image_data).decode("utf-8") + thumbnail_encoded_type = f"data:{content_type};base64,{encoded_image}" + self.logger.debug( + f"[Presentation ID: {presentation_id}] Thumbnail encoded successfully." + ) + return thumbnail_encoded_type + + async def process_presentation(self, presentation_id, fields): + """ + Extract content from a Google Slides presentation. + + Args: + presentation_id (str): The presentation ID. + fields (str): Fields to fetch. + + Returns: + dict: Contains title and body of the processed presentation. + """ + try: + self.logger.info(f"[Presentation ID: {presentation_id}] Fetching metadata") + presentation = await self._get_presentation(presentation_id, fields) + slides = presentation.get("slides", []) + self.logger.info( + f"[Presentation ID: {presentation_id}] Processing {len(slides)} slides" + ) + body = await self._process_slides(presentation_id, slides) + return {"title": presentation.get("title", "Untitled"), "body": body} + except Exception as e: + self.logger.error( + f"[Presentation ID: {presentation_id}] Failed to process: {e}" + ) + raise + + async def _process_slides(self, presentation_id, slides): + """ + Process all slides in a presentation synchronously, with all logic in a single function. + + Args: + presentation_id (str): The presentation ID. + slides (list): List of slide data. + + Returns: + str: Combined content of all slides in Markdown format. + """ + combined_content = [] + + for i, slide in enumerate(slides): + slide_number = i + 1 + + # Skip slides marked as skipped + if slide.get("slideProperties", {}).get("isSkipped", False): + self.logger.info( + f"[Presentation ID: {presentation_id}] Skipping slide {slide_number} (marked as skipped)." 
+ ) + combined_content.append(f"### Slide {slide_number}\n\nSkipped.") + continue + + # Extract speaker notes + speaker_notes = [] + try: + notes_page = slide.get("slideProperties", {}).get("notesPage", {}) + if notes_page and "pageElements" in notes_page: + for element in notes_page.get("pageElements", []): + if not isinstance(element, dict): + continue + shape = element.get("shape", {}) + if "text" in shape: + for text_element in shape.get("textElements", []): + if ( + isinstance(text_element, dict) + and "textRun" in text_element + ): + note_content = ( + text_element["textRun"] + .get("content", "") + .strip() + ) + if note_content: + speaker_notes.append(note_content) + except Exception as e: + self.logger.error( + f"[Presentation ID: {presentation_id}] Error extracting speaker notes for slide {slide_number}: {e}" + ) + + speaker_notes = "\n".join(speaker_notes) + + # Get thumbnail + try: + # Get thumbnail URL + page_id = slide.get("objectId") + thumbnail_url = await self._get_thumbnail_url( + presentation_id=presentation_id, + page_id=page_id, + mime_type="PNG", + thumbnail_size="MEDIUM", + ) + + # Download and encode thumbnail as base64 + thumbnail_encoded_type = await self._get_thumbnail_data( + presentation_id, thumbnail_url + ) + + except requests.exceptions.HTTPError as e: + self.logger.warning( + f"[Presentation ID: {presentation_id}] Error downloading thumbnail for slide {page_id}: {e}" + ) + except Exception as e: + self.logger.warning( + f"[Presentation ID: {presentation_id}] Error fetching or downloading thumbnail for slide {page_id}: {e}" + ) + + if not thumbnail_encoded_type: + self.logger.error( + f"[Presentation ID: {presentation_id}] Failed to retrieve thumbnail for slide {page_id}." + ) + + # Process slide content + content = None + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": thumbnail_encoded_type}, + } + ], + }, + ] + response = await self.azure_llm_client.chat.completions.create( + model="gpt-4o-mini", + messages=messages, + max_tokens=1500, + temperature=1, + ) + content = ( + response.choices[0].message.content if response.choices else "" + ) + if speaker_notes: + content += f"\n\n**Speaker Notes:**\n\n{speaker_notes}" + except Exception as e: + self.logger.error( + f"[Presentation ID: {presentation_id}] Error processing content for slide {slide_number}: {e}" + ) + content = f"Error processing slide content: {e}" + + # Add slide content to combined output + combined_content.append(f"### Slide {slide_number}\n\n{content}") + + # Combine all slide contents into a single string + return "\n\n---\n\n".join(combined_content) + + +class AzureOpenAIWhisperClient(GoogleServiceAccountClient): + """Handles transcription using Azure OpenAI Whisper.""" + + def __init__( + self, + json_credentials, + azure_openai_whisper_client, + logger, + rate_limit=3, + interval=60, + subject=None, + max_concurrency=1, + ): + """ + Initialize the transcription client. + + Args: + json_credentials (dict): Service account credentials. + azure_openai_whisper_api_key (str): API key for Azure OpenAI Whisper. + azure_openai_whisper_endpoint (str): Endpoint URL for Azure OpenAI Whisper. + azure_openai_whisper_version (str): API version for Azure OpenAI Whisper. + logger (logging.Logger): Logger instance. + rate_limit (int): Maximum transcriptions per interval. + interval (int): Time interval for rate limiting in seconds. 
+ max_concurrency (int): Maximum concurrent transcription calls + + This client uses the Google Drive API to download file content and Azure OpenAI Whisper to transcribe files. + It requires the following scopes: + - https://www.googleapis.com/auth/drive.readonly + - https://www.googleapis.com/auth/drive.file + """ + remove_universe_domain(json_credentials) + if subject: + json_credentials["subject"] = subject + + super().__init__( + json_credentials=json_credentials, + api="drive", + api_version="v3", + scopes=[ + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/drive.file", + ], + api_timeout=DRIVE_API_TIMEOUT, + ) + + if subject: + self.subject = subject + + self.azure_openai_whisper_client = azure_openai_whisper_client + self.logger = logger + self.rate_limit = rate_limit + self.interval = interval + self.transcriptions_done = 0 + self.last_reset_time = time.time() + self.allowed_mime_types = AUDIO_VIDEO_MIME_TYPES + self.MAX_RETRIES = RETRIES + self.BASE_DELAY = RETRY_INTERVAL + # self.semaphore = asyncio.Semaphore(max_concurrency) + + async def _check_rate_limit(self): + """Enforces the transcription rate limit.""" + current_time = time.time() + if current_time - self.last_reset_time >= self.interval: + self.transcriptions_done = 0 + self.last_reset_time = current_time + + if self.transcriptions_done >= self.rate_limit: + wait_time = self.interval - (current_time - self.last_reset_time) + self.logger.info(f"Rate limit reached. Waiting {wait_time:.2f} seconds.") + await sleeps_for_retryable.sleep(wait_time) + self.transcriptions_done = 0 + self.last_reset_time = time.time() + + @retryable( + retries=RETRIES, + interval=RETRY_INTERVAL, + strategy=RetryStrategy.EXPONENTIAL_BACKOFF, + ) + async def transcribe(self, temp_file, file): + """ + Transcribes an audio or video file. + + Args: + buffer (tempfile): file buffer + file (dict): Metadata of the file to be transcribed. + + Returns: + dict: Transcribed text or error information. 
+ """ + + try: + self.logger.info(f"Transcription: Starting file: {file['name']}") + + # Rename temp file to original file extension because without it Azure OpenAI Whisper will not accept the file + try: + new_temp_file = temp_file + os.path.splitext(file["name"])[1] + os.rename(temp_file, new_temp_file) + self.logger.debug( + f"Transcription: Renamed temp file to: {new_temp_file}" + ) + temp_file = new_temp_file + except Exception as e: + self.logger.error(f"Transcription: Error renaming temp file: {e}") + + # async with self.semaphore: + with open(temp_file, "rb") as audio_file: + response = ( + await self.azure_openai_whisper_client.audio.transcriptions.create( + file=audio_file, + model="whisper", + ) + ) + + self.transcriptions_done += 1 + + # Handle response based on Azure OpenAI's format + transcribed_text = ( + response.text if hasattr(response, "text") else str(response) + ) + + except Exception as e: + self.logger.error(f"Transcription failed for file {file['name']}: {e}") + transcribed_text = f"Error: {e}" + + return transcribed_text + + +class SyncCursorEmpty(Exception): + """Exception class to notify that incremental sync can't run because sync_cursor is empty.""" + + pass + + class GoogleDriveClient(GoogleServiceAccountClient): """A google drive client to handle api calls made to Google Drive API.""" @@ -142,11 +608,12 @@ async def get_all_folders(self): return folders - async def list_files(self, fetch_permissions=False): + async def list_files(self, fetch_permissions=False, last_sync_time=None): """Get files from Google Drive. Files can have any type. Args: include_permissions (bool): flag to select permissions in the request query + last_sync_time (str): time when last sync happened Yields: dict: Documents from Google Drive. @@ -157,12 +624,15 @@ async def list_files(self, fetch_permissions=False): if fetch_permissions else DRIVE_ITEMS_FIELDS ) - + if last_sync_time is None: + list_query = "trashed=false" + else: + list_query = f"trashed=true or modifiedTime > '{last_sync_time}' or createdTime > '{last_sync_time}'" async for file in self.api_call_paged( resource="files", method="list", corpora="allDrives", - q="trashed=false", + q=list_query, orderBy="modifiedTime desc", fields=f"files({files_fields}),incompleteSearch,nextPageToken", includeItemsFromAllDrives=True, @@ -171,7 +641,9 @@ async def list_files(self, fetch_permissions=False): ): yield file - async def list_files_from_my_drive(self, fetch_permissions=False): + async def list_files_from_my_drive( + self, fetch_permissions=False, last_sync_time=None + ): """Retrieves files from Google Drive, with an option to fetch permissions (DLS). This function optimizes the retrieval process based on the 'fetch_permissions' flag. @@ -182,15 +654,22 @@ async def list_files_from_my_drive(self, fetch_permissions=False): Args: include_permissions (bool): flag to select permissions in the request query + last_sync_time (str): time when last sync happened Yields: dict: Documents from Google Drive. 
""" - if fetch_permissions: + if fetch_permissions and last_sync_time: + files_fields = DRIVE_ITEMS_FIELDS_WITH_PERMISSIONS + list_query = f"(trashed=true or modifiedTime > '{last_sync_time}' or createdTime > '{last_sync_time}') and 'me' in writers" + elif fetch_permissions and not last_sync_time: files_fields = DRIVE_ITEMS_FIELDS_WITH_PERMISSIONS # Google Drive API required write access to fetch file's permissions list_query = "trashed=false and 'me' in writers" + elif not fetch_permissions and last_sync_time: + files_fields = DRIVE_ITEMS_FIELDS + list_query = f"trashed=true or modifiedTime > '{last_sync_time}' or createdTime > '{last_sync_time}'" else: files_fields = DRIVE_ITEMS_FIELDS list_query = "trashed=false" @@ -338,6 +817,9 @@ def __init__(self, configuration): """ super().__init__(configuration=configuration) + self.azure_llm_client = None + self.azure_openai_whisper_client = None + def _set_internal_logger(self): if self._domain_wide_delegation_sync_enabled() or self._dls_enabled(): self.google_admin_directory_client.set_logger(self._logger) @@ -428,8 +910,188 @@ def get_default_configuration(cls): "ui_restrictions": ["advanced"], "value": False, }, + "enable_slide_content_extraction": { + "display": "toggle", + "label": "Enable Google Presentation to Markdown extraction", + "order": 9, + "tooltip": "Enable GPT-4o model to extract markdown from Google Slides presentations", + "type": "bool", + "value": False, + }, + "azure_openai_gpt4o_api_key": { + "depends_on": [ + {"field": "enable_slide_content_extraction", "value": True} + ], + "display": "text", + "label": "Azure OpenAI GPT-4o API Key", + "order": 10, + "type": "str", + "sensitive": True, + }, + "azure_openai_gpt4o_version": { + "depends_on": [ + {"field": "enable_slide_content_extraction", "value": True} + ], + "display": "text", + "label": "Azure OpenAI GPT-4o Version", + "order": 11, + "type": "str", + }, + "azure_openai_gpt4o_endpoint": { + "depends_on": [ + {"field": "enable_slide_content_extraction", "value": True} + ], + "display": "text", + "label": "Azure OpenAI GPT-4o Endpoint", + "order": 12, + "type": "str", + }, + "enable_audio_video_transcription": { + "display": "toggle", + "label": "Enable Audio/Video Transcription", + "order": 13, + "tooltip": "Enable Azure Whisper API to transcribe audio and video files", + "type": "bool", + "value": False, + }, + "max_size_audio_video_transcription": { + "depends_on": [ + {"field": "enable_audio_video_transcription", "value": True} + ], + "display": "numeric", + "label": "Max File Size for Audio/Video Transcription (MB)", + "order": 14, + "tooltip": "Specify the maximum file size in MB for audio/video transcription. 
Default is 25MB.", + "type": "int", + "value": 25, + }, + "azure_openai_whisper_api_key": { + "depends_on": [ + {"field": "enable_audio_video_transcription", "value": True} + ], + "display": "text", + "label": "Azure OpenAI Whisper API Key", + "order": 15, + "type": "str", + "sensitive": True, + }, + "azure_openai_whisper_version": { + "depends_on": [ + {"field": "enable_audio_video_transcription", "value": True} + ], + "display": "text", + "label": "Azure OpenAI Whisper Version", + "order": 16, + "type": "str", + }, + "azure_openai_whisper_endpoint": { + "depends_on": [ + {"field": "enable_audio_video_transcription", "value": True} + ], + "display": "text", + "label": "Azure OpenAI Whisper Endpoint", + "order": 17, + "type": "str", + }, } + def google_slides_extractor(self, impersonate_email=None): + """ + Initialize and return an instance of the GoogleSlidesExtractor. + + This method sets up a Google Slides client using service account credentials. + If an impersonate_email is provided, the client will be set up for domain-wide + delegation, allowing it to impersonate the provided user account within + a Google Workspace domain. + + Args: + impersonate_email (str, optional): The email of the user account to impersonate. + Defaults to None, in which case no impersonation is set up. + + Returns: + GoogleSlidesExtractor: An initialized instance of the GoogleSlidesExtractor. + """ + if not self.configuration.get("enable_slide_content_extraction"): + return None + + service_account_credentials = self.configuration["service_account_credentials"] + + validate_service_account_json( + service_account_credentials, GOOGLE_DRIVE_SERVICE_NAME + ) + + json_credentials = load_service_account_json( + service_account_credentials, GOOGLE_DRIVE_SERVICE_NAME + ) + + # Create the LLM client if not created + if not self.azure_llm_client: + self.azure_llm_client = AsyncAzureOpenAI( + api_key=self.configuration["azure_openai_gpt4o_api_key"], + api_version=self.configuration["azure_openai_gpt4o_version"], + azure_endpoint=self.configuration["azure_openai_gpt4o_endpoint"], + ) + + slides_client = GoogleSlidesExtractor( + json_credentials=json_credentials, + azure_llm_client=self.azure_llm_client, + logger=self._logger, + subject=impersonate_email, + ) + + slides_client.set_logger(self._logger) + + return slides_client + + def azure_openai_whisper_transcribe(self, impersonate_email=None): + """ + Initialize and return an instance of the AzureOpenAIWhisperClient. + + This method sets up a Google Drive client using service account credentials. + If an impersonate_email is provided, the client will be set up for domain-wide + delegation, allowing it to impersonate the provided user account within + a Google Workspace domain. + + Args: + impersonate_email (str, optional): The email of the user account to impersonate. + Defaults to None, in which case no impersonation is set up. + + Returns: + AzureOpenAIWhisperClient: An initialized instance of the AzureOpenAIWhisperClient. 
+ """ + if not self.configuration.get("enable_audio_video_transcription"): + return None + + service_account_credentials = self.configuration["service_account_credentials"] + + validate_service_account_json( + service_account_credentials, GOOGLE_DRIVE_SERVICE_NAME + ) + + json_credentials = load_service_account_json( + service_account_credentials, GOOGLE_DRIVE_SERVICE_NAME + ) + + # Create the Whisper client if not created + if not self.azure_openai_whisper_client: + self.azure_openai_whisper_client = AsyncAzureOpenAI( + api_key=self.configuration["azure_openai_whisper_api_key"], + api_version=self.configuration["azure_openai_whisper_version"], + azure_endpoint=self.configuration["azure_openai_whisper_endpoint"], + ) + + whisper_client = AzureOpenAIWhisperClient( + json_credentials=json_credentials, + azure_openai_whisper_client=self.azure_openai_whisper_client, + logger=self._logger, + max_concurrency=self._max_concurrency(), + subject=impersonate_email, + ) + + whisper_client.set_logger(self._logger) + + return whisper_client + def google_drive_client(self, impersonate_email=None): """ Initialize and return an instance of the GoogleDriveClient. @@ -678,9 +1340,9 @@ async def prepare_single_access_control_document(self, user): user_email = user.get("primaryEmail") user_domain = _get_domain_from_email(user_email) user_groups = [] - async for groups_page in self.google_admin_directory_client.list_groups_for_user( - user_id - ): + async for ( + groups_page + ) in self.google_admin_directory_client.list_groups_for_user(user_id): for group in groups_page.get("groups", []): user_groups.append(group.get("email")) @@ -841,7 +1503,7 @@ async def get_google_workspace_content(self, client, file, timestamp=None): # We need to do sanity size after downloading the file because: # 1. We use files/export endpoint which converts large media-rich google slides/docs # into text/plain format. We usually we end up with tiny .txt files. - # 2. Google will ofter report the Google Workspace shared documents to have size 0 + # 2. Google will offer report the Google Workspace shared documents to have size 0 # as they don't count against user's storage quota. if not self.is_file_size_within_limit(file_size, file_name): return @@ -903,7 +1565,7 @@ async def get_generic_file_content(self, client, file, timestamp=None): return document async def get_content(self, client, file, timestamp=None, doit=None): - """Extracts the content from a file file. + """Extracts the content from a file. Args: file (dict): Formatted file document. @@ -915,20 +1577,144 @@ async def get_content(self, client, file, timestamp=None, doit=None): """ if not doit: + self._logger.info( + f"Skipping content extraction for file {file['name']} due to 'doit' flag being False." 
+ ) return + self._logger.info( + f"Starting content extraction for file {file['name']} ({file['mime_type']})" + ) file_mime_type = file["mime_type"] + file_size = int(file["size"]) + + # Handle Google Slides with GPT-4o-mini model to extract markdown from google slides thumbnails + if file_mime_type == SLIDES_MIME_TYPE and self.configuration.get( + "enable_slide_content_extraction" + ): + try: + self._logger.info( + f"Extracting content from Google Slides: {file['name']} (ID: {file['id']})" + ) + google_slides_extractor = self.google_slides_extractor() + + async def process_presentation_task(presentation_id, fields): + return await google_slides_extractor.process_presentation( + presentation_id=presentation_id, fields=fields + ) + + task = await GOOGLE_PRESENTATION_QUEUE.put( + partial(process_presentation_task, file["id"], SLIDES_FIELDS) + ) + presentation = await task - if file_mime_type in GOOGLE_MIME_TYPES_MAPPING: - # Get content from native google workspace files (docs, slides, sheets) - return await self.get_google_workspace_content( - client, file, timestamp=timestamp + self._logger.info( + f"Successfully extracted content for file {file['name']} (ID: {file['id']})" + ) + + return { + "_id": file["id"], + "_timestamp": file["_timestamp"], + "title": presentation.get("title", ""), + "body": presentation.get("body", "Failed to process"), + } + + except Exception as e: + self._logger.error( + f"Error extracting slide content for file {file['name']} (ID: {file['id']}): {str(e)}" + ) + # Fallback to standard Google Workspace content extraction + return await self.get_google_workspace_content(client, file, timestamp) + + # Handle audio/video transcription with Azure OpenAI Whisper + if file_mime_type in AUDIO_VIDEO_MIME_TYPES and self.configuration.get( + "enable_audio_video_transcription" + ): + # Default to 25MB if not set + max_size_bytes = ( + self.configuration.get("max_size_audio_video_transcription", 25) + * MB_TO_BYTES + ) + + # Validate file size before proceeding + if file_size > max_size_bytes: + self._logger.warning( + f"File {file['name']} exceeds the configured size limit of {max_size_bytes / MB_TO_BYTES:.2f} MB." + ) + return { + "_id": file["id"], + "_timestamp": file["_timestamp"], + "error": f"File size {file_size / MB_TO_BYTES:.2f} MB exceeds the limit of {max_size_bytes / MB_TO_BYTES:.2f} MB.", + } + + azure_openai_whisper_transcribe = self.azure_openai_whisper_transcribe() + try: + async with self.create_temp_file( + file.get("file_extension", ".tmp") + ) as async_buffer: + for attempt in range(azure_openai_whisper_transcribe.MAX_RETRIES): + try: + await client.api_call( + resource="files", + method="get", + fileId=file["id"], + alt="media", + pipe_to=async_buffer, + ) + + break + except Exception as e: + self._logger.warning( + f"Error downloading file {file['name']}, attempt {attempt+1}: {e}" + ) + wait_time = azure_openai_whisper_transcribe.BASE_DELAY * ( + 2**attempt + ) + await sleeps_for_retryable.sleep(wait_time) + else: + # if we reach here, all retries have failed + msg = f"Failed to download file {file['name']} after {azure_openai_whisper_transcribe.MAX_RETRIES} attempts." 
+ raise Exception( + msg + ) + + async def process_transcription_task(temp_file, file): + return await azure_openai_whisper_transcribe.transcribe( + temp_file, file + ) + + transcribed_text = await azure_whisper_queue.put( + partial(process_transcription_task, async_buffer.name, file) + ) + # transcribed_text = await task + + return { + "_id": file["id"], + "_timestamp": file["_timestamp"], + "body": transcribed_text, + } + + except Exception as e: + self._logger.error(f"Error transcribing file {file['name']}: {e}") + return { + "_id": file["id"], + "_timestamp": file["_timestamp"], + "error": f"Error transcribing file: {e}", + } + + # Handle other Google Workspace files + elif file_mime_type in GOOGLE_MIME_TYPES_MAPPING: + self._logger.info( + f"Extracting Google Workspace document content for file {file['name']} (ID: {file['id']})" ) + return await self.get_google_workspace_content(client, file, timestamp) + + # Handle generic files else: - # Get content from all other file types - return await self.get_generic_file_content( - client, file, timestamp=timestamp + self._logger.info( + f"Extracting generic file content for file {file['name']} (ID: {file['id']})" ) + return await self.get_generic_file_content(client, file, timestamp) async def _get_permissions_on_shared_drive(self, client, file_id): """Retrieves the permissions on a shared drive for the given file ID. @@ -988,7 +1774,7 @@ async def prepare_file(self, client, file, paths): file (dict): File metadata returned from the Drive. Returns: - dict: Formatted file metadata. + file_document, trashedTime (tuple): Formatted file metadata along with trashedTime for files deleted from shared drive """ file_id, file_name = file.get("id"), file.get("name") @@ -1003,6 +1789,7 @@ async def prepare_file(self, client, file, paths): "mime_type": file.get("mimeType"), "file_extension": file.get("fileExtension"), "url": file.get("webViewLink"), + "trashed": file.get("trashed"), } # record "file" or "folder" type @@ -1061,8 +1848,7 @@ async def prepare_file(self, client, file, paths): self._logger.error(exception_log_msg) file_document[ACCESS_CONTROL] = self._process_permissions(permissions) - - return file_document + return file_document, file.get("trashedTime") async def prepare_files(self, client, files_page, paths, seen_ids): """Generate file document. @@ -1079,7 +1865,12 @@ async def prepare_files(self, client, files_page, paths, seen_ids): new_files = [file for file in files if file.get("id") not in seen_ids] prepared_files = await self._process_items_concurrently( - new_files, lambda f: self.prepare_file(client=client, file=f, paths=paths) + new_files, + lambda f: self.prepare_file( + client=client, + file=f, + paths=paths, + ), ) for file in prepared_files: @@ -1089,7 +1880,7 @@ async def get_docs(self, filtering=None): """Executes the logic to fetch Google Drive objects in an async manner. Args: - filtering (optional): Advenced filtering rules. Defaults to None. + filtering (optional): Advanced filtering rules. Defaults to None. Yields: dict, partial: dict containing meta-data of the Google Drive objects, @@ -1101,6 +1892,8 @@ async def get_docs(self, filtering=None): # This is an optimization to process unique files only once. 
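# --- Illustrative sketch (editor's note, not part of the patch) ---
# Two small pieces of arithmetic from the transcription path above, pulled out
# for clarity: the size gate converts the configured limit from MB to bytes
# (MB_TO_BYTES = 1024 * 1024, default limit 25 MB), and the download retry loop
# backs off exponentially (BASE_DELAY * 2**attempt, with RETRIES=5 and
# RETRY_INTERVAL=5 as in the patch). `download` is a hypothetical stand-in for
# the Drive files.get(alt="media") call.
import asyncio

MB_TO_BYTES = 1024 * 1024


def within_transcription_limit(file_size_bytes, limit_mb=25):
    # 25 MB default -> 26_214_400 bytes
    return file_size_bytes <= limit_mb * MB_TO_BYTES


async def download_with_backoff(download, retries=5, base_delay=5):
    last_error = None
    for attempt in range(retries):
        try:
            return await download()
        except Exception as e:
            last_error = e
            # 5s, 10s, 20s, 40s, ... between attempts
            await asyncio.sleep(base_delay * (2 ** attempt))
    msg = f"Download still failing after {retries} attempts"
    raise RuntimeError(msg) from last_error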
seen_ids = set() + self.init_sync_cursor() + if self._domain_wide_delegation_sync_enabled(): # sync personal drives first async for user in self.google_admin_directory_client.users(): @@ -1110,7 +1903,7 @@ async def get_docs(self, filtering=None): async for files_page in google_drive_client.list_files_from_my_drive( fetch_permissions=self._dls_enabled() ): - async for file in self.prepare_files( + async for file, _ in self.prepare_files( client=google_drive_client, files_page=files_page, paths={}, @@ -1138,7 +1931,7 @@ async def get_docs(self, filtering=None): async for files_page in shared_drives_client.list_files( fetch_permissions=self._dls_enabled() ): - async for file in self.prepare_files( + async for file, _ in self.prepare_files( client=shared_drives_client, files_page=files_page, paths=resolved_paths, @@ -1156,10 +1949,156 @@ async def get_docs(self, filtering=None): async for files_page in google_drive_client.list_files( fetch_permissions=self._dls_enabled() ): - async for file in self.prepare_files( + async for file, _ in self.prepare_files( client=google_drive_client, files_page=files_page, paths=resolved_paths, seen_ids=seen_ids, ): yield file, partial(self.get_content, google_drive_client, file) + + async def get_docs_incrementally(self, sync_cursor, filtering=None): + """Executes the logic to fetch Google Drive objects incrementally in an async manner. + + Args: + sync_cursor (str): Last sync time. + filtering (optional): Advanced filtering rules. Defaults to None. + + Yields: + dict, partial: dict containing meta-data of the Google Drive objects, + partial download content function + """ + self._sync_cursor = sync_cursor + timestamp = iso_zulu() + self._logger.debug(f"Current Sync Time {timestamp}") + + if not self._sync_cursor: + msg = "Unable to start incremental sync. Please perform a full sync to re-enable incremental syncs." 
+ raise SyncCursorEmpty(msg) + + seen_ids = set() + + if self._domain_wide_delegation_sync_enabled(): + # sync personal drives first + async for user in self.google_admin_directory_client.users(): + email = user.get(UserFields.EMAIL.value) + self._logger.debug(f"Syncing personal drive content for: {email}") + google_drive_client = self.google_drive_client(impersonate_email=email) + async for files_page in google_drive_client.list_files_from_my_drive( + fetch_permissions=self._dls_enabled(), + last_sync_time=self.last_sync_time(), + ): + # personal drive files have no property called trashedTime(time when file was deleted) + async for file, _ in self.prepare_files( + client=google_drive_client, + files_page=files_page, + paths={}, + seen_ids=seen_ids, + ): + if file.get("trashed") is True: + yield ( + file, + partial(self.get_content, google_drive_client, file), + OP_DELETE, + ) + else: + yield ( + file, + partial(self.get_content, google_drive_client, file), + OP_INDEX, + ) + + email_for_shared_drives_sync = ( + self._google_google_workspace_email_for_shared_drives_sync() + ) + + shared_drives_client = self.google_drive_client( + impersonate_email=email_for_shared_drives_sync + ) + + # Build a path lookup, parentId -> parent path + resolved_paths = await self.resolve_paths( + google_drive_client=shared_drives_client + ) + + # sync shared drives + self._logger.debug( + f"Syncing shared drives using admin account: {email_for_shared_drives_sync}" + ) + async for files_page in shared_drives_client.list_files( + fetch_permissions=self._dls_enabled(), + last_sync_time=self.last_sync_time(), + ): + # trashedTime(time when file was deleted) is a property exclusive to files present in shared drive + async for file, trashedTime in self.prepare_files( + client=shared_drives_client, + files_page=files_page, + paths=resolved_paths, + seen_ids=seen_ids, + ): + if ( + trashedTime is None or trashedTime > self.last_sync_time() + ) and file.get("trashed") is True: + yield ( + file, + partial(self.get_content, shared_drives_client, file), + OP_DELETE, + ) + elif ( + trashedTime is not None and trashedTime < self.last_sync_time() + ) and file.get("trashed") is True: + continue + else: + yield ( + file, + partial(self.get_content, shared_drives_client, file), + OP_INDEX, + ) + + else: + # Build a path lookup, parentId -> parent path + resolved_paths = await self.resolve_paths() + + google_drive_client = self.google_drive_client() + + # sync anything shared with the service account + # shared drives can also be shared with service account + # making it possible to sync shared drives without domain wide delegation + async for files_page in google_drive_client.list_files( + fetch_permissions=self._dls_enabled(), + last_sync_time=self.last_sync_time(), + ): + async for file, trashedTime in self.prepare_files( + client=google_drive_client, + files_page=files_page, + paths=resolved_paths, + seen_ids=seen_ids, + ): + if ( + trashedTime is None or trashedTime > self.last_sync_time() + ) and file.get("trashed") is True: + yield ( + file, + partial(self.get_content, google_drive_client, file), + OP_DELETE, + ) + elif ( + trashedTime is not None and trashedTime < self.last_sync_time() + ) and file.get("trashed") is True: + continue + else: + yield ( + file, + partial(self.get_content, google_drive_client, file), + OP_INDEX, + ) + self.update_sync_timestamp_cursor(timestamp) + + def init_sync_cursor(self): + if not self._sync_cursor: + self._sync_cursor = { + CURSOR_GOOGLE_DRIVE_KEY: {}, + CURSOR_SYNC_TIMESTAMP: 
iso_zulu(), + } + + return self._sync_cursor diff --git a/connectors/sources/network_drive.py b/connectors/sources/network_drive.py index 965bb6e65..897cbd415 100644 --- a/connectors/sources/network_drive.py +++ b/connectors/sources/network_drive.py @@ -724,8 +724,7 @@ async def _user_access_control_doc(self, user, sid, groups_info=None): rid_groups = [] for group_sid in groups_info or []: - rid = group_sid.split("-")[-1] - rid_groups.append(_prefix_rid(rid)) + rid_groups.append(_prefix_rid(group_sid.split("-")[-1])) access_control = [rid_user, prefixed_username, *rid_groups] @@ -744,13 +743,14 @@ def read_user_info_csv(self): try: csv_reader = csv.reader(file, delimiter=";") for row in csv_reader: - user_info.append( - { - "name": row[0], - "user_sid": row[1], - "groups": row[2].split(",") if len(row[2]) > 0 else [], - } - ) + if row: + user_info.append( + { + "name": row[0], + "user_sid": row[1], + "groups": row[2].split(",") if len(row[2]) > 0 else [], + } + ) except csv.Error as e: self._logger.exception( f"Error while reading user mapping file at the location: {self.identity_mappings}. Error: {e}" @@ -880,9 +880,14 @@ async def get_docs(self, filtering=None): if filtering and filtering.has_advanced_rules(): advanced_rules = filtering.get_advanced_rules() async for document in self.fetch_filtered_directory(advanced_rules): - yield document, partial(self.get_content, document) if document[ - "type" - ] == "file" else None + yield ( + document, + ( + partial(self.get_content, document) + if document["type"] == "file" + else None + ), + ) else: groups_info = {} @@ -892,8 +897,13 @@ async def get_docs(self, filtering=None): async for document in self.traverse_diretory( path=rf"\\{self.server_ip}/{self.drive_path}" ): - yield await self._decorate_with_access_control( - document, document["path"], document["type"], groups_info - ), partial(self.get_content, document) if document[ - "type" - ] == "file" else None + yield ( + await self._decorate_with_access_control( + document, document["path"], document["type"], groups_info + ), + ( + partial(self.get_content, document) + if document["type"] == "file" + else None + ), + ) diff --git a/connectors/sources/outlook.py b/connectors/sources/outlook.py index dc3c612a6..f238772c6 100644 --- a/connectors/sources/outlook.py +++ b/connectors/sources/outlook.py @@ -44,6 +44,7 @@ html_to_text, iso_utc, retryable, + url_encode, ) RETRIES = 3 @@ -410,7 +411,8 @@ async def _fetch_token(self): ) async def get_users(self): access_token = await self._fetch_token() - url = f"https://graph.microsoft.com/v1.0/users?$top={TOP}" + filter_ = url_encode("accountEnabled eq true") + url = f"https://graph.microsoft.com/v1.0/users?$top={TOP}&$filter={filter_}" while True: try: async with self._get_session.get( diff --git a/connectors/sources/postgresql.py b/connectors/sources/postgresql.py index c4f229b65..62b8a250f 100644 --- a/connectors/sources/postgresql.py +++ b/connectors/sources/postgresql.py @@ -83,6 +83,7 @@ class PostgreSQLAdvancedRulesValidator(AdvancedRulesValidator): "properties": { "tables": {"type": "array", "minItems": 1}, "query": {"type": "string", "minLength": 1}, + "id_columns": {"type": "array", "minItems": 1}, }, "required": ["tables", "query"], "additionalProperties": False, @@ -548,7 +549,7 @@ async def fetch_documents_from_table(self, table): f"Something went wrong while fetching document for table '{table}'. 
Error: {exception}" ) - async def fetch_documents_from_query(self, tables, query): + async def fetch_documents_from_query(self, tables, query, id_columns): """Fetches all the data from the given query and format them in Elasticsearch documents Args: @@ -562,7 +563,9 @@ async def fetch_documents_from_query(self, tables, query): f"Fetching records for {tables} tables using custom query: {query}" ) try: - docs_generator = self._yield_docs_custom_query(tables=tables, query=query) + docs_generator = self._yield_docs_custom_query( + tables=tables, query=query, id_columns=id_columns + ) async for doc in docs_generator: yield doc except (InternalClientError, ProgrammingError) as exception: @@ -570,23 +573,29 @@ async def fetch_documents_from_query(self, tables, query): f"Something went wrong while fetching document for query '{query}' and tables {', '.join(tables)}. Error: {exception}" ) - async def _yield_docs_custom_query(self, tables, query): + async def _yield_docs_custom_query(self, tables, query, id_columns): primary_key_columns, _ = await self.get_primary_key(tables=tables) + + if id_columns: + primary_key_columns = id_columns + if not primary_key_columns: self._logger.warning( f"Skipping tables {', '.join(tables)} from database {self.database} since no primary key is associated with them. Assign primary key to the tables to index it in the next sync interval." ) return - last_update_times = list( - filter( - lambda update_time: update_time is not None, - [ + last_update_times = [] + for table in tables: + try: + last_update_time = ( await self.postgresql_client.get_table_last_update_time(table) - for table in tables - ], - ) - ) + ) + last_update_times.append(last_update_time) + except Exception: + self._logger.warning("Last update time is not found for Table: {table}") + last_update_times.append(iso_utc()) + last_update_time = ( max(last_update_times) if len(last_update_times) else iso_utc() ) @@ -598,7 +607,10 @@ async def _yield_docs_custom_query(self, tables, query): yield self.serialize( doc=self.row2doc( - row=row, doc_id=doc_id, table=tables, timestamp=last_update_time + row=row, + doc_id=doc_id, + table=tables, + timestamp=last_update_time or iso_utc(), ) ) @@ -669,7 +681,7 @@ async def yield_rows_for_query( yield row else: self._logger.warning( - f"Skipping query {query} for tables {', '.join(tables)} as primary key column name is not present in query." + f"Skipping query {query} for tables {', '.join(tables)} as primary key column or unique ID column name is not present in query." 
) async def get_docs(self, filtering=None): @@ -687,9 +699,14 @@ async def get_docs(self, filtering=None): for rule in advanced_rules: query = rule.get("query") tables = rule.get("tables") - + id_columns = rule.get("id_columns") + if id_columns: + id_columns = [ + f"{self.schema}_{'_'.join(sorted(tables))}_{column}" + for column in id_columns + ] async for row in self.fetch_documents_from_query( - tables=tables, query=query + tables=tables, query=query, id_columns=id_columns ): yield row, None diff --git a/connectors/utils.py b/connectors/utils.py index b7bc06d32..9d8e1986e 100644 --- a/connectors/utils.py +++ b/connectors/utils.py @@ -457,7 +457,7 @@ def _callback(self, task): ) elif task.exception(): logger.error( - f"Exception found for task {task.get_name()}: {task.exception()}", + f"Exception found for task {task.get_name()}", exc_info=task.exception() ) def _add_task(self, coroutine, name=None): diff --git a/docs/REFERENCE.md b/docs/REFERENCE.md index a8a0dbe03..af7c2429a 100644 --- a/docs/REFERENCE.md +++ b/docs/REFERENCE.md @@ -34,7 +34,7 @@ The columns provide specific information about each connector: | [Microsoft SQL Server](https://www.elastic.co/guide/en/enterprise-search/current/connectors-ms-sql.html) | **GA** | 8.8+ | 8.11+ | - | - | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/mssql.py) | | [MySQL](https://www.elastic.co/guide/en/enterprise-search/current/connectors-mysql.html) | **GA** | 8.5+ | 8.8+ | - | - | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/mysql.py) | | [Network drive](https://www.elastic.co/guide/en/enterprise-search/current/connectors-network-drive.html) | **GA** | 8.9+ | 8.10+ | 8.14+ | 8.13+ | 8.11+ | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/network_drive.py) | -| [Notion](https://www.elastic.co/guide/en/enterprise-search/current/connectors-notion.html) | **Beta** | 8.14+ | 8.14+ | - | - | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/notion.py) | +| [Notion](https://www.elastic.co/guide/en/enterprise-search/current/connectors-notion.html) | **GA** | 8.14+ | 8.14+ | - | - | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/notion.py) | | [OneDrive](https://www.elastic.co/guide/en/enterprise-search/current/connectors-onedrive.html) | **GA** | 8.11+ | 8.11+ | 8.11+ | 8.13+ | 8.11+ | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/onedrive.py) | | [Opentext Documentum](https://www.elastic.co/guide/en/enterprise-search/current/connectors-opentext.html) | **Example** | n/a | n/a | n/a | n/a | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/opentext_documentum.py) | | [Oracle](https://www.elastic.co/guide/en/enterprise-search/current/connectors-oracle.html) | **GA** | 8.12+ | - | - | - | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/oracle.py) | @@ -45,7 +45,7 @@ The columns provide specific information about each connector: | [Salesforce](https://www.elastic.co/guide/en/enterprise-search/current/connectors-salesforce.html) | **GA** | 8.12+ | 8.12+ | 8.11+ | 8.13+ | 8.13+ | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/salesforce.py) | | [ServiceNow](https://www.elastic.co/guide/en/enterprise-search/current/connectors-servicenow.html) | **GA** | 8.10+ | 8.10+ | 8.11+ | 8.13+ | 8.13+ | [View 
code](https://github.com/elastic/connectors/tree/main/connectors/sources/servicenow.py) | | [Sharepoint Online](https://www.elastic.co/guide/en/enterprise-search/current/connectors-sharepoint-online.html) | **GA** | 8.9+ | 8.9+ | 8.9+ | 8.9+ | 8.9+ | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/sharepoint_online.py) | -| [Sharepoint Server](https://www.elastic.co/guide/en/enterprise-search/current/connectors-sharepoint.html) | **Beta** | - | - | 8.11+ | 8.13+ | 8.14+ | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/sharepoint_server.py) | +| [Sharepoint Server](https://www.elastic.co/guide/en/enterprise-search/current/connectors-sharepoint.html) | **Beta** | 8.15+ | - | 8.11+ | 8.13+ | 8.14+ | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/sharepoint_server.py) | | [Slack](https://www.elastic.co/guide/en/enterprise-search/current/connectors-slack.html) | **Preview** | 8.14+ | - | - | - | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/slack.py) | | [Teams](https://www.elastic.co/guide/en/enterprise-search/current/connectors-teams.html) | **Preview** | 8.14+ | - | - | 8.13+ | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/teams.py) | | [Zoom](https://www.elastic.co/guide/en/enterprise-search/current/connectors-zoom.html) | **Preview** | 8.14+ | - | 8.11+ | 8.13+ | - | [View code](https://github.com/elastic/connectors/tree/main/connectors/sources/zoom.py) | diff --git a/requirements/aarch64.txt b/requirements/aarch64.txt index 57c326ea0..656a370e3 100644 --- a/requirements/aarch64.txt +++ b/requirements/aarch64.txt @@ -3,5 +3,5 @@ pymongo==4.6.3 motor==3.4.0 -smbprotocol==1.9.0 pymongo[srv]==4.6.3 +smbprotocol==1.10.1 diff --git a/requirements/framework.txt b/requirements/framework.txt index bb6aeb8fa..bbeee6a2f 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,4 +1,4 @@ -aiohttp==3.9.5 +aiohttp==3.10.3 aiofiles==23.2.1 aiomysql==0.1.1 elasticsearch[async]==8.14.0 @@ -26,7 +26,7 @@ dropbox==11.36.2 beautifulsoup4==4.12.2 gidgethub==5.2.1 wcmatch==8.4.1 -msal==1.23.0 +msal==1.30.0 exchangelib==5.4.0 ldap3==2.9.1 lxml==4.9.3 @@ -38,5 +38,7 @@ redis==5.0.1 simple-term-menu==1.6.4 graphql-core==3.2.3 notion-client==2.2.1 -certifi==2023.7.22 +certifi==2024.7.4 aioboto3==12.4.0 +pyasn1<0.6.1 +openai<1.58.0 diff --git a/requirements/x86_64.txt b/requirements/x86_64.txt index 72582dab6..65722904e 100644 --- a/requirements/x86_64.txt +++ b/requirements/x86_64.txt @@ -3,4 +3,4 @@ pymongo==4.6.3 motor==3.4.0 -smbprotocol==1.9.0 +smbprotocol==1.10.1 diff --git a/tests/sources/fixtures/github/fixture.py b/tests/sources/fixtures/github/fixture.py index ce61aad79..a17bac959 100644 --- a/tests/sources/fixtures/github/fixture.py +++ b/tests/sources/fixtures/github/fixture.py @@ -53,7 +53,7 @@ def __init__(self): self.get_commits ) self.app.route("/api/graphql", methods=["POST"])(self.mock_graphql_response) - self.app.route("/api", methods=["HEAD"])(self.get_scopes) + self.app.route("/api/graphql", methods=["HEAD"])(self.get_scopes) self.files = {} def encode_cursor(self, value): diff --git a/tests/sources/fixtures/mssql/fixture.py b/tests/sources/fixtures/mssql/fixture.py index 668fa671f..ba6783d35 100644 --- a/tests/sources/fixtures/mssql/fixture.py +++ b/tests/sources/fixtures/mssql/fixture.py @@ -4,10 +4,12 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
# # ruff: noqa: T201 +import base64 import os import random import pytds +from faker import Faker from tests.commons import WeightedFakeProvider @@ -24,6 +26,8 @@ weights=[0.65, 0.3, 0.05, 0] ) # SQL does not like huge blobs +faked = Faker() + BATCH_SIZE = 1000 DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() @@ -61,8 +65,82 @@ def inject_lines(table, cursor, lines): rows = [] batch_size = min(BATCH_SIZE, lines - inserted) for row_id in range(batch_size): - rows.append((fake_provider.fake.name(), row_id, fake_provider.get_text())) - sql_query = f"INSERT INTO customers_{table} (name, age, description) VALUES (%s, %s, %s)" + rows.append( + ( + fake_provider.fake.name(), # name + row_id, # age + fake_provider.get_text(), # description + faked.date_time(), # record_time + faked.pydecimal(left_digits=10, right_digits=2), # balance + faked.random_letter(), # initials + faked.boolean(), # active + faked.random_int(), # points + faked.pydecimal(left_digits=10, right_digits=2), # salary + faked.pydecimal( + left_digits=8, + right_digits=4, + min_value=-214748, + max_value=214748, + ), # bonus + faked.random_int(), # score + faked.pydecimal(left_digits=2, right_digits=1), # rating + faked.pydecimal(left_digits=2, right_digits=2), # discount + faked.date(), # birthdate + faked.date_time(), # appointment + faked.date_time(), # created_at + faked.date_time(), # updated_at + faked.date_time(), # last_login + faked.date_time(), # expiration + faked.random_element(elements=("A", "I")), # status + faked.text(max_nb_chars=100), # notes + faked.text(max_nb_chars=100), # additional_info + faked.uuid4(), # unique_key + faked.json(), # config + faked.random_int(min=1, max=10), # small_age + base64.b64encode( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10\x08\x06\x00\x00\x00\x1f\xf3\xff\xa6\x00\x00\x00" + ).decode( + "utf-8" + ), # profile_pic + ) + ) + + columns = [ + "name", + "age", + "description", + "record_time", + "balance", + "initials", + "active", + "points", + "salary", + "bonus", + "score", + "rating", + "discount", + "birthdate", + "appointment", + "created_at", + "updated_at", + "last_login", + "expiration", + "status", + "notes", + "additional_info", + "unique_key", + "config", + "small_age", + "profile_pic", + ] + + placeholders = ", ".join(["%s"] * len(columns)) + column_names = ", ".join(columns) + + sql_query = ( + f"INSERT INTO customers_{table} ({column_names}) VALUES ({placeholders})" + ) + cursor.executemany(sql_query, rows) inserted += batch_size print(f"Inserted batch #{batch} of {batch_size} documents.") @@ -89,7 +167,7 @@ async def load(): for table in range(NUM_TABLES): print(f"Adding data to table customers_{table}...") - sql_query = f"CREATE TABLE customers_{table} (id INT IDENTITY(1,1), name VARCHAR(255), age int, description TEXT, PRIMARY KEY (id))" + sql_query = f"CREATE TABLE customers_{table} (id INT IDENTITY(1,1), name VARCHAR(255), age SMALLINT, description TEXT, record_time TIME, balance DECIMAL(18, 2), initials CHAR(3), active BIT, points BIGINT, salary MONEY, bonus SMALLMONEY, score NUMERIC(10, 2), rating FLOAT, discount REAL, birthdate DATE, appointment TIME, created_at DATETIME2, updated_at DATETIMEOFFSET, last_login DATETIME, expiration SMALLDATETIME, status CHAR(1), notes VARCHAR(1000), additional_info NTEXT, unique_key UNIQUEIDENTIFIER, config XML, small_age TINYINT, profile_pic nvarchar(max), PRIMARY KEY (id))" cursor.execute(sql_query) inject_lines(table, cursor, RECORD_COUNT) database.commit() diff --git 
a/tests/sources/test_github.py b/tests/sources/test_github.py index 02e964d24..a1b4c2b9f 100644 --- a/tests/sources/test_github.py +++ b/tests/sources/test_github.py @@ -6,12 +6,13 @@ """Tests the Github source class methods""" from contextlib import asynccontextmanager from copy import deepcopy +from http import HTTPStatus from unittest.mock import ANY, AsyncMock, Mock, patch import aiohttp import pytest from aiohttp.client_exceptions import ClientResponseError -from gidgethub.abc import GraphQLAuthorizationFailure, QueryError +from gidgethub.abc import BadGraphQLRequest, GraphQLAuthorizationFailure, QueryError from connectors.access_control import DLS_QUERY from connectors.filtering.validation import SyncRuleValidationResult @@ -23,6 +24,7 @@ REPOSITORY_OBJECT, ForbiddenException, GitHubAdvancedRulesValidator, + GitHubClient, GitHubDataSource, UnauthorizedException, ) @@ -964,19 +966,82 @@ async def test_get_retry_after(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_graphql_with_errors(): +@pytest.mark.parametrize( + "exceptions, raises", + [ + ( + BadGraphQLRequest( + status_code=HTTPStatus.FORBIDDEN, response={"message": None} + ), + ForbiddenException, + ), + ( + BadGraphQLRequest( + status_code=HTTPStatus.CONFLICT, response={"message": None} + ), + BadGraphQLRequest, + ), + ], +) +async def test_graphql_with_BadGraphQLRequest(exceptions, raises): async with create_github_source() as source: - source.github_client._get_client.graphql = Mock( - side_effect=QueryError( - {"errors": [{"type": "QUERY", "message": "Invalid query"}]} - ) - ) - with pytest.raises(Exception): + source.github_client._get_client.graphql = Mock(side_effect=exceptions) + with pytest.raises(raises): await source.github_client.graphql( {"variable": {"owner": "demo_user"}, "query": "QUERY"} ) +@pytest.mark.asyncio +@patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) +@pytest.mark.parametrize( + "exceptions, raises, is_raised", + [ + ( + QueryError( + { + "errors": [ + { + "type": "RATE_LIMITED", + "message": "API rate limit exceeded for user ID: 123456", + } + ] + } + ), + Exception, + False, + ), + ( + QueryError( + { + "errors": [ + {"type": "SOME_QUERY_ERROR", "message": "Some query error."} + ] + } + ), + Exception, + True, + ), + ], +) +async def test_graphql_with_QueryError(exceptions, raises, is_raised): + async with create_github_source() as source: + source.github_client._get_client.graphql = Mock(side_effect=exceptions) + if is_raised: + with pytest.raises(raises): + await source.github_client.graphql( + {"variable": {"owner": "demo_user"}, "query": "QUERY"} + ) + else: + with patch.object( + GitHubClient, "_get_retry_after", AsyncMock(return_value=0) + ): + with pytest.raises(raises): + await source.github_client.graphql( + {"variable": {"owner": "demo_user"}, "query": "QUERY"} + ) + + @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) async def test_graphql_with_unauthorized(): @@ -1478,6 +1543,19 @@ async def test_fetch_files(): assert expected_response == document +@pytest.mark.asyncio +@pytest.mark.parametrize( + "exception", + [UnauthorizedException, ForbiddenException], +) +async def test_fetch_files_when_error_occurs(exception): + async with create_github_source() as source: + source.github_client.get_github_item = Mock(side_effect=exception()) + with pytest.raises(exception): + async for _ in source._fetch_files("demo_repo", "main"): + pass + + 
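As a reading aid for the parametrized GitHub tests above, here is a minimal, self-contained sketch of the error mapping they assert: 401 maps to UnauthorizedException, a non-rate-limited 403 maps to ForbiddenException, other statuses re-raise the original error, and a GraphQL QueryError is only retried when its type is RATE_LIMITED. The helper names classify_rest_error and classify_graphql_error are hypothetical and not part of the connector; the real client raises exceptions or sleeps instead of returning labels.

from http import HTTPStatus


def classify_rest_error(status_code):
    # Sketch of the REST mapping asserted by the tests above (hypothetical helper).
    if status_code == HTTPStatus.UNAUTHORIZED:  # 401 -> UnauthorizedException
        return "unauthorized"
    if status_code == HTTPStatus.FORBIDDEN:  # non-rate-limited 403 -> ForbiddenException
        return "forbidden"
    return "reraise"  # e.g. 404 keeps the original ClientResponseError


def classify_graphql_error(query_error_payload):
    # Sketch of the QueryError handling: rate limits are retried after a sleep,
    # every other error type is re-raised.
    error_type = query_error_payload["errors"][0].get("type")
    return "retry_after_sleep" if error_type == "RATE_LIMITED" else "reraise"


assert classify_rest_error(401) == "unauthorized"
assert classify_rest_error(403) == "forbidden"
assert classify_rest_error(404) == "reraise"
assert classify_graphql_error({"errors": [{"type": "RATE_LIMITED", "message": "..."}]}) == "retry_after_sleep"
assert classify_graphql_error({"errors": [{"type": "SOME_QUERY_ERROR", "message": "..."}]}) == "reraise"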
@pytest.mark.asyncio async def test_get_docs(): expected_response = [ @@ -1918,6 +1996,53 @@ async def test_get_personal_access_token_scopes(scopes, expected_scopes): assert actual_scopes == expected_scopes +@pytest.mark.asyncio +@patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) +@pytest.mark.parametrize( + "exception, raises", + [ + ( + ClientResponseError( + status=401, + request_info=aiohttp.RequestInfo( + real_url="", method=None, headers=None, url="" + ), + history=None, + headers={"X-RateLimit-Remaining": 5000}, + ), + UnauthorizedException, + ), + ( + ClientResponseError( + status=403, + request_info=aiohttp.RequestInfo( + real_url="", method=None, headers=None, url="" + ), + history=None, + headers={"X-RateLimit-Remaining": 2500}, + ), + ForbiddenException, + ), + ( + ClientResponseError( + status=404, + request_info=aiohttp.RequestInfo( + real_url="", method=None, headers=None, url="" + ), + history=None, + headers={"X-RateLimit-Remaining": 2500}, + ), + ClientResponseError, + ), + ], +) +async def test_get_personal_access_token_scopes_when_error_occurs(exception, raises): + async with create_github_source() as source: + source.github_client._get_client._request = AsyncMock(side_effect=exception) + with pytest.raises(raises): + await source.github_client.get_personal_access_token_scopes() + + @pytest.mark.asyncio async def test_github_client_get_installations(): async with create_github_source(auth_method=GITHUB_APP) as source: @@ -2074,3 +2199,62 @@ async def test_get_owners(auth_method, repo_type, expected_owners): ): actual_owners = [owner async for owner in source._get_owners()] assert actual_owners == expected_owners + + +@pytest.mark.asyncio +@patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) +async def test_update_installation_access_token_when_error_occurs(): + async with create_github_source() as source: + source.github_client.get_installation_access_token = AsyncMock( + side_effect=Exception() + ) + with pytest.raises(Exception): + await source.github_client._update_installation_access_token() + + +@pytest.mark.asyncio +@patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) +@pytest.mark.parametrize( + "exceptions, raises", + [ + ( + ClientResponseError( + status=403, + request_info=aiohttp.RequestInfo( + real_url="", method=None, headers=None, url="" + ), + headers={"X-RateLimit-Remaining": 4000}, + history=None, + ), + ForbiddenException, + ), + ( + ClientResponseError( + status=401, + request_info=aiohttp.RequestInfo( + real_url="", method=None, headers=None, url="" + ), + headers={"X-RateLimit-Remaining": 4000}, + history=None, + ), + UnauthorizedException, + ), + ( + ClientResponseError( + status=404, + request_info=aiohttp.RequestInfo( + real_url="", method=None, headers=None, url="" + ), + headers={"X-RateLimit-Remaining": 4000}, + history=None, + ), + ClientResponseError, + ), + (Exception(), Exception), + ], +) +async def test_get_github_item_when_error_occurs(exceptions, raises): + async with create_github_source() as source: + source.github_client._get_client.getitem = Mock(side_effect=exceptions) + with pytest.raises(raises): + await source.github_client.get_github_item("/core") diff --git a/tests/sources/test_network_drive.py b/tests/sources/test_network_drive.py index 6912bbde2..208a13bf0 100644 --- a/tests/sources/test_network_drive.py +++ b/tests/sources/test_network_drive.py @@ -855,6 +855,28 @@ async def test_get_access_control_dls_enabled(): assert 
expected_user_access_control == user_access_control +@pytest.mark.asyncio +async def test_get_access_control_without_duplicate_ids(): + async with create_source(NASDataSource) as source: + source._dls_enabled = MagicMock(return_value=True) + source.drive_type = LINUX + source.identity_mappings = "/a/b" + + source.read_user_info_csv = MagicMock( + return_value=[ + {"name": "user1", "user_sid": "S-1", "groups": ["S-11", "S-22"]}, + {"name": "user2", "user_sid": "S-2", "groups": ["S-22"]}, + {"name": "user3", "user_sid": "S-3", "groups": ["S-11"]}, + ] + ) + + seen_users = set() + async for access_control in source.get_access_control(): + seen_users.add(access_control["_id"]) + + assert len(seen_users) == 3 + + @mock.patch.object( NASDataSource, "traverse_diretory", diff --git a/tests/sources/test_outlook.py b/tests/sources/test_outlook.py index 43922c9de..2eb45c580 100644 --- a/tests/sources/test_outlook.py +++ b/tests/sources/test_outlook.py @@ -419,7 +419,10 @@ def side_effect_function(url, headers): Args: url, ssl: Params required for get call """ - if url == "https://graph.microsoft.com/v1.0/users?$top=999": + if ( + url + == "https://graph.microsoft.com/v1.0/users?$top=999&$filter=accountEnabled%20eq%20true" + ): return get_json_mock( mock_response={ "@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token", diff --git a/tests/sources/test_postgresql.py b/tests/sources/test_postgresql.py index 0e61a6947..3fe1136c4 100644 --- a/tests/sources/test_postgresql.py +++ b/tests/sources/test_postgresql.py @@ -9,6 +9,7 @@ from unittest.mock import ANY, Mock, patch import pytest +from freezegun import freeze_time from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio.engine import AsyncEngine @@ -20,6 +21,7 @@ PostgreSQLDataSource, PostgreSQLQueries, ) +from connectors.utils import iso_utc from tests.sources.support import create_source ADVANCED_SNIPPET = "advanced_snippet" @@ -30,6 +32,12 @@ TABLE = "emp_table" CUSTOMER_TABLE = "customer" +TIME = iso_utc() + +ID_ONE = "id1" +ID_TWO = "id2" +ID_THREE = "id3" + @asynccontextmanager async def create_postgresql_source(): @@ -288,6 +296,129 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res assert validation_result == expected_validation_result +@pytest.mark.parametrize( + "advanced_rules, id_in_source, expected_validation_result", + [ + ( + # valid: empty array should be valid + [], + [ID_ONE], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # valid: empty object should also be valid -> default value in Kibana + {}, + [ID_ONE], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # valid: valid queries + [ + { + "tables": ["emp_table"], + "query": "select * from emp_table", + "id_columns": [ID_ONE], + } + ], + [ID_ONE], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # valid: id_columns field missing + [ + { + "tables": ["emp_table"], + "query": "select * from emp_table", + "id_columns": [ID_ONE], + } + ], + [ID_ONE], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # invalid: tables not present in database + [ + { + "tables": ["table_name"], + "query": "select * from table_name", + "id_columns": [ID_ONE], + } + ], + [ID_ONE], + SyncRuleValidationResult( + SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=ANY, + ), + ), + ( + # 
invalid: tables key missing + [{"query": "select * from table_name", "id_columns": [ID_ONE]}], + [ID_ONE], + SyncRuleValidationResult( + SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=ANY, + ), + ), + ( + # invalid: invalid key + [ + { + "tables": "table_name", + "query": "select * from table_name", + "id_columns": [ID_ONE], + } + ], + [], + SyncRuleValidationResult( + SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=ANY, + ), + ), + ( + # invalid: tables can be empty + [ + { + "tables": [], + "query": "select * from table_name", + "id_columns": [ID_ONE], + } + ], + [ID_ONE], + SyncRuleValidationResult( + SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=ANY, + ), + ), + ], +) +@pytest.mark.asyncio +async def test_advanced_rules_validation_when_id_in_source_available( + advanced_rules, id_in_source, expected_validation_result +): + async with create_source( + PostgreSQLDataSource, database="xe", tables="*", schema="public", port=5432 + ) as source: + with patch.object(AsyncEngine, "connect", return_value=ConnectionAsync()): + validation_result = await PostgreSQLAdvancedRulesValidator(source).validate( + advanced_rules + ) + + assert validation_result == expected_validation_result + + +@freeze_time(TIME) @pytest.mark.asyncio async def test_get_docs(): # Setup @@ -324,6 +455,7 @@ async def test_get_docs(): assert actual_response == expected_response +@freeze_time(TIME) @pytest.mark.parametrize( "filtering, expected_response", [ diff --git a/tests/test_sink.py b/tests/test_sink.py index 3d4d882ea..e86f7e9b4 100644 --- a/tests/test_sink.py +++ b/tests/test_sink.py @@ -60,6 +60,7 @@ DOC_TWO = {"_id": 2, "_timestamp": TIMESTAMP} DOC_THREE = {"_id": 3, "_timestamp": TIMESTAMP} +DOC_FOUR = {"_id": 4, "_timestamp": TIMESTAMP} BULK_ACTION_ERROR = "some error" @@ -433,6 +434,14 @@ async def lazy_download(**kwargs): return lazy_download +def crashing_lazy_download_fake(): + async def lazy_download(**kwargs): + msg = "Could not download" + raise Exception(msg) + + return lazy_download + + def queue_called_with_operations(queue, operations): expected_calls = [call(operation) for operation in operations] actual_calls = queue.put.call_args_list @@ -1219,7 +1228,7 @@ async def test_batch_bulk_with_errors(patch_logger): } client.client.bulk = AsyncMock(return_value=mock_result) await sink._batch_bulk([], {OP_INDEX: {"1": 20}, OP_UPDATE: {}, OP_DELETE: {}}) - patch_logger.assert_present(f"operation index failed, {error}") + patch_logger.assert_present(f"operation index failed for doc 1, {error}") @patch("connectors.es.sink.CANCELATION_TIMEOUT", -1) @@ -1282,6 +1291,34 @@ async def test_extractor_put_doc(): queue.put.assert_awaited_once_with(doc) +@pytest.mark.asyncio +@mock.patch( + "connectors.es.management_client.ESManagementClient.yield_existing_documents_metadata" +) +@mock.patch("connectors.utils.ConcurrentTasks.cancel") +async def test_extractor_get_docs_when_downloads_fail( + yield_existing_documents_metadata, concurrent_tasks_cancel +): + queue = await queue_mock() + + yield_existing_documents_metadata.return_value = AsyncIterator([]) + + docs_from_source = [ + (DOC_ONE, crashing_lazy_download_fake(), "index"), + (DOC_TWO, crashing_lazy_download_fake(), "index"), + (DOC_THREE, crashing_lazy_download_fake(), "index"), + (DOC_FOUR, crashing_lazy_download_fake(), "index"), + ] + # deep copying docs is needed as get_docs mutates the document ids which has side effects on other test + # instances + doc_generator 
= AsyncIterator([deepcopy(doc) for doc in docs_from_source]) + + extractor = await setup_extractor(queue, content_extraction_enabled=True) + + await extractor.run(doc_generator, JobType.FULL) + concurrent_tasks_cancel.assert_called_once() + + @pytest.mark.asyncio async def test_force_canceled_extractor_put_doc(): doc = {"id": 123} @@ -1419,11 +1456,11 @@ async def test_cancel_sync(extractor_task_done, sink_task_done, force_cancel): es._sink = Mock() es._sink.force_cancel = Mock() - es._extractor_task = Mock() + es._extractor_task = mock.create_autospec(asyncio.Task) es._extractor_task.cancel = Mock() es._extractor_task.done = Mock(side_effect=extractor_task_done) - es._sink_task = Mock() + es._sink_task = mock.create_autospec(asyncio.Task) es._sink_task.cancel = Mock() es._sink_task.done = Mock(side_effect=sink_task_done)
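On the PostgreSQL changes earlier in this diff: advanced sync rules now accept an optional id_columns array, and get_docs() namespaces those columns with the schema and the sorted table names before calling fetch_documents_from_query(). A minimal sketch of the shape of such a rule and of the prefixing, using values that appear in the new tests (schema public, table emp_table, column id1):

# Illustrative advanced rule; "tables" and "query" stay required, "id_columns" is optional.
rule = {
    "tables": ["emp_table"],
    "query": "select * from emp_table",
    "id_columns": ["id1"],
}

# get_docs() prefixes each configured column before handing it over,
# mirroring the list comprehension added in the diff above.
schema = "public"
prefixed_id_columns = [
    f"{schema}_{'_'.join(sorted(rule['tables']))}_{column}"
    for column in rule["id_columns"]
]
assert prefixed_id_columns == ["public_emp_table_id1"]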
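And on the network drive changes: read_user_info_csv() now skips empty rows in the semicolon-delimited identity-mappings file. A self-contained sketch of the expected row layout (name;user_sid;comma-separated group SIDs), reusing the SIDs from the new access-control test; the literal file content here is illustrative only.

import csv
import io

# Hypothetical identity-mappings content; the blank line exercises the new "if row:" guard.
identity_csv = "user1;S-1;S-11,S-22\nuser2;S-2;S-22\n\nuser3;S-3;S-11\n"

user_info = []
for row in csv.reader(io.StringIO(identity_csv), delimiter=";"):
    if row:  # empty rows are skipped, matching the change in read_user_info_csv()
        user_info.append(
            {
                "name": row[0],
                "user_sid": row[1],
                "groups": row[2].split(",") if len(row[2]) > 0 else [],
            }
        )

assert [user["name"] for user in user_info] == ["user1", "user2", "user3"]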