diff --git a/.gitignore b/.gitignore index 2585d4d36..f602e6700 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ __pycache__ # jetbrains files .idea *.iml +.cli diff --git a/config.yml b/config.yml index e645c08c8..578b90814 100644 --- a/config.yml +++ b/config.yml @@ -171,14 +171,6 @@ #service.log_level: INFO # # -## Whether telemetry is enabled -#service.telemetry.enabled: true -# -# -## The interval (in seconds) to run telemetry job -#service.telemetry.interval: 3600 -# -# ## ------------------------------- Extraction Service ---------------------------------- # ## Local extraction service-related configurations. diff --git a/connectors/cli/.gitkeep b/connectors/cli/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/connectors/cli/auth.py b/connectors/cli/auth.py new file mode 100644 index 000000000..9c80a6b00 --- /dev/null +++ b/connectors/cli/auth.py @@ -0,0 +1,40 @@ +import asyncio +import os + +import yaml +from elasticsearch import ApiError + +from connectors.es import ESClient + +CONFIG_FILE_PATH = ".cli/config.yml" + + +class Auth: + def __init__(self, host, username, password): + self.elastic_config = {"host": host, "username": username, "password": password} + self.es_client = ESClient(self.elastic_config) + + def authenticate(self): + if asyncio.run(self.__ping_es_client()): + self.__save_config() + return True + else: + return False + + def is_config_present(self): + return os.path.isfile(CONFIG_FILE_PATH) + + async def __ping_es_client(self): + try: + return await self.es_client.ping() + except ApiError: + return False + finally: + await self.es_client.close() + + def __save_config(self): + yaml_content = yaml.dump({"elasticsearch": self.elastic_config}) + os.makedirs(os.path.dirname(CONFIG_FILE_PATH), exist_ok=True) + + with open(CONFIG_FILE_PATH, "w") as f: + f.write(yaml_content) diff --git a/connectors/cli/connector.py b/connectors/cli/connector.py new file mode 100644 index 000000000..e7b7d779e --- /dev/null +++ b/connectors/cli/connector.py @@ -0,0 +1,183 @@ +import asyncio +from collections import OrderedDict + +from connectors.es.client import ESClient +from connectors.es.settings import DEFAULT_LANGUAGE, Mappings, Settings +from connectors.protocol import ( + CONCRETE_CONNECTORS_INDEX, + CONCRETE_JOBS_INDEX, + ConnectorIndex, +) +from connectors.source import get_source_klass +from connectors.utils import iso_utc + + +class IndexAlreadyExists(Exception): + pass + + +class Connector: + def __init__(self, config): + self.config = config + + # initialize ES client + self.es_client = ESClient(self.config) + + self.connector_index = ConnectorIndex(self.config) + + async def list_connectors(self): + # TODO move this on top + try: + await self.es_client.ensure_exists( + indices=[CONCRETE_CONNECTORS_INDEX, CONCRETE_JOBS_INDEX] + ) + + return [ + connector async for connector in self.connector_index.all_connectors() + ] + + # TODO catch exceptions + finally: + await self.connector_index.close() + await self.es_client.close() + + def service_type_configuration(self, source_class): + source_klass = get_source_klass(source_class) + configuration = source_klass.get_default_configuration() + + return OrderedDict(sorted(configuration.items(), key=lambda x: x[1]["order"])) + + def create( + self, index_name, service_type, configuration, language=DEFAULT_LANGUAGE + ): + return asyncio.run( + self.__create(index_name, service_type, configuration, language) + ) + + async def __create( + self, index_name, service_type, configuration, 
language=DEFAULT_LANGUAGE + ): + try: + return await asyncio.gather( + self.__create_search_index(index_name, language), + self.__create_connector( + index_name, service_type, configuration, language + ), + ) + except Exception as e: + raise e + finally: + await self.es_client.close() + + async def __create_search_index(self, index_name, language): + mappings = Mappings.default_text_fields_mappings( + is_connectors_index=True, + ) + + settings = Settings(language_code=language, analysis_icu=False).to_hash() + + settings["auto_expand_replicas"] = "0-3" + settings["number_of_shards"] = 2 + + await self.es_client.client.indices.create( + index=index_name, mappings=mappings, settings=settings + ) + + async def __create_connector( + self, index_name, service_type, configuration, language + ): + try: + await self.es_client.ensure_exists( + indices=[CONCRETE_CONNECTORS_INDEX, CONCRETE_JOBS_INDEX] + ) + timestamp = iso_utc() + + doc = { + "api_key_id": "", + "configuration": configuration, + "index_name": index_name, + "service_type": service_type, + "status": "configured", # TODO use a predefined constant + "is_native": True, # TODO make it optional + "language": language, + "last_access_control_sync_error": None, + "last_access_control_sync_scheduled_at": None, + "last_access_control_sync_status": None, + "last_sync_status": None, + "last_sync_error": None, + "last_sync_scheduled_at": None, + "last_synced": None, + "last_seen": None, + "created_at": timestamp, + "updated_at": timestamp, + "filtering": self.default_filtering(timestamp), + "scheduling": self.default_scheduling(), + "custom_scheduling": {}, + "pipeline": { + "extract_binary_content": True, + "name": "ent-search-generic-ingestion", + "reduce_whitespace": True, + "run_ml_inference": True, + }, + "last_indexed_document_count": 0, + "last_deleted_document_count": 0, + } + + connector = await self.connector_index.index(doc) + return connector["_id"] + finally: + await self.connector_index.close() + + def default_scheduling(self): + return { + "access_control": {"enabled": False, "interval": "0 0 0 * * ?"}, + "full": {"enabled": False, "interval": "0 0 0 * * ?"}, + "incremental": {"enabled": False, "interval": "0 0 0 * * ?"}, + } + + def default_filtering(self, timestamp): + return [ + { + "active": { + "advanced_snippet": { + "created_at": timestamp, + "updated_at": timestamp, + "value": {}, + }, + "rules": [ + { + "created_at": timestamp, + "field": "_", + "id": "DEFAULT", + "order": 0, + "policy": "include", + "rule": "regex", + "updated_at": timestamp, + "value": ".*", + } + ], + "validation": {"errors": [], "state": "valid"}, + }, + "domain": "DEFAULT", + "draft": { + "advanced_snippet": { + "created_at": timestamp, + "updated_at": timestamp, + "value": {}, + }, + "rules": [ + { + "created_at": timestamp, + "field": "_", + "id": "DEFAULT", + "order": 0, + "policy": "include", + "rule": "regex", + "updated_at": timestamp, + "value": ".*", + } + ], + "validation": {"errors": [], "state": "valid"}, + }, + } + ] diff --git a/connectors/cli/index.py b/connectors/cli/index.py new file mode 100644 index 000000000..319ee3026 --- /dev/null +++ b/connectors/cli/index.py @@ -0,0 +1,45 @@ +import asyncio + +from elasticsearch import ApiError + +from connectors.es import ESClient + + +class Index: + def __init__(self, config): + self.elastic_config = config + self.es_client = ESClient(self.elastic_config) + + def list_indices(self): + return asyncio.run(self.__list_indices())["indices"] + + def clean(self, index_name): + return 
asyncio.run(self.__clean_index(index_name)) + + def delete(self, index_name): + return asyncio.run(self.__delete_index(index_name)) + + async def __list_indices(self): + try: + return await self.es_client.list_indices() + except ApiError as e: + raise e + finally: + await self.es_client.close() + + async def __clean_index(self, index_name): + try: + return await self.es_client.clean_index(index_name) + except ApiError: + return False + finally: + await self.es_client.close() + + async def __delete_index(self, index_name): + try: + await self.es_client.delete_indices([index_name]) + return True + except ApiError: + return False + finally: + await self.es_client.close() diff --git a/connectors/cli/job.py b/connectors/cli/job.py new file mode 100644 index 000000000..3e1cc8f56 --- /dev/null +++ b/connectors/cli/job.py @@ -0,0 +1,95 @@ +import asyncio + +from elasticsearch import ApiError + +from connectors.es.client import ESClient +from connectors.protocol import ( + CONCRETE_CONNECTORS_INDEX, + CONCRETE_JOBS_INDEX, + ConnectorIndex, + JobStatus, + JobTriggerMethod, + JobType, + Sort, + SyncJobIndex, +) + + +class Job: + def __init__(self, config): + self.config = config + self.es_client = ESClient(self.config) + self.sync_job_index = SyncJobIndex(self.config) + self.connector_index = ConnectorIndex(self.config) + + def list_jobs(self, connector_id=None, index_name=None, job_id=None): + return asyncio.run(self.__async_list_jobs(connector_id, index_name, job_id)) + + def cancel(self, connector_id=None, index_name=None, job_id=None): + return asyncio.run(self.__async_cancel_jobs(connector_id, index_name, job_id)) + + def start(self, connector_id, job_type): + return asyncio.run(self.__async_start(connector_id, job_type)) + + async def __async_start(self, connector_id, job_type): + try: + connector = await self.connector_index.fetch_by_id(connector_id) + await self.sync_job_index.create( + connector=connector, + trigger_method=JobTriggerMethod.ON_DEMAND, + job_type=JobType(job_type), + ) + + return True + finally: + await self.sync_job_index.close() + await self.connector_index.close() + await self.es_client.close() + + async def __async_list_jobs(self, connector_id, index_name, job_id): + try: + await self.es_client.ensure_exists( + indices=[CONCRETE_CONNECTORS_INDEX, CONCRETE_JOBS_INDEX] + ) + jobs = self.sync_job_index.get_all_docs( + query=self.__job_list_query(connector_id, index_name, job_id), + sort=self.__job_list_sort(), + ) + + return [job async for job in jobs] + + # TODO catch exceptions + finally: + await self.sync_job_index.close() + await self.es_client.close() + + async def __async_cancel_jobs(self, connector_id, index_name, job_id): + try: + jobs = await self.__async_list_jobs(connector_id, index_name, job_id) + + for job in jobs: + await job._terminate(JobStatus.CANCELING) + + return True + except ApiError: + return False + finally: + await self.sync_job_index.close() + await self.es_client.close() + + def __job_list_query(self, connector_id, index_name, job_id): + if job_id: + return {"bool": {"must": [{"term": {"_id": job_id}}]}} + + if index_name: + return { + "bool": {"filter": [{"term": {"connector.index_name": index_name}}]} + } + + if connector_id: + return {"bool": {"must": [{"term": {"connector.id": connector_id}}]}} + + return None + + def __job_list_sort(self): + return [{"created_at": Sort.ASC.value}] diff --git a/connectors/config.py b/connectors/config.py index 66dba54cc..30cfa4b4d 100644 --- a/connectors/config.py +++ b/connectors/config.py @@ -79,10 +79,6 @@ def 
_default_config(): "max_concurrent_access_control_syncs": 1, "job_cleanup_interval": 300, "log_level": "INFO", - "telemetry": { - "enabled": True, - "interval": 3600, - }, }, "sources": { "azure_blob_storage": "connectors.sources.azure_blob_storage:AzureBlobStorageDataSource", diff --git a/connectors/connectors_cli.py b/connectors/connectors_cli.py new file mode 100644 index 000000000..b582d6f8d --- /dev/null +++ b/connectors/connectors_cli.py @@ -0,0 +1,411 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +""" +Command Line Interface. + +This is the main entry point of the framework. When the project is installed as +a Python package, an `elastic-ingest` executable is added in the PATH and +executes the `main` function of this module, which starts the service. +""" +import asyncio +import os + +import click +import yaml +from tabulate import tabulate + +from connectors import __version__ # NOQA +from connectors.cli.auth import CONFIG_FILE_PATH, Auth +from connectors.cli.connector import Connector +from connectors.cli.index import Index +from connectors.cli.job import Job +from connectors.config import _default_config +from connectors.es.settings import Settings + +__all__ = ["main"] + + +def load_config(ctx, config): + if config: + return yaml.safe_load(config) + elif os.path.isfile(CONFIG_FILE_PATH): + with open(CONFIG_FILE_PATH, "r") as f: + return yaml.safe_load(f.read()) + elif ctx.invoked_subcommand == "login": + pass + else: + msg = f"{CONFIG_FILE_PATH} is not found" + raise FileNotFoundError(msg) + + +# Main group +@click.group(invoke_without_command=True) +@click.version_option(__version__, "-v", "--version", message="%(version)s") +@click.option("-c", "--config", type=click.File("rb")) +@click.pass_context +def cli(ctx, config): + # print help page if no subcommands provided + if ctx.invoked_subcommand is None: + click.echo(ctx.get_help()) + return + + ctx.ensure_object(dict) + ctx.obj["config"] = load_config(ctx, config) + + +@click.command(help="Authenticate Connectors CLI with an Elasticsearch instance") +@click.option("--host", prompt="Elastic host") +@click.option("--username", prompt="Username") +@click.option("--password", prompt="Password", hide_input=True) +def login(host, username, password): + auth = Auth(host, username, password) + if auth.is_config_present(): + click.confirm( + click.style( + "Config is already present. Are you sure you want to override it?", + fg="yellow", + ), + abort=True, + ) + if auth.authenticate(): + click.echo(click.style("Authentication successful", fg="green")) + else: + click.echo("") + click.echo( + click.style( + "Authentication failed. 
Please check your credentials.", fg="red" + ), + err=True, + ) + return + + +cli.add_command(login) + + +# Connector group +@click.group(invoke_without_command=False, help="Connectors management") +@click.pass_context +def connector(ctx): + pass + + +@click.command(name="list", help="List all existing connectors") +@click.pass_obj +def list_connectors(obj): + connector = Connector(config=obj["config"]["elasticsearch"]) + coro = connector.list_connectors() + + try: + connectors = asyncio.run(coro) + click.echo("") + if len(connectors) == 0: + click.echo("No connectors found") + return + + click.echo(f"Showing {len(connectors)} connectors \n") + + table_rows = [] + for connector in connectors: + formatted_connector = [ + click.style(connector.id, blink=True, fg="green"), + click.style(connector.index_name, blink=True, fg="white"), + click.style(connector.service_type, blink=True, fg="white"), + click.style(connector.status.value, fg="white"), + click.style(connector.last_sync_status.value, fg="white"), + ] + table_rows.append(formatted_connector) + + click.echo( + tabulate( + table_rows, + headers=[ + "ID", + "Index name", + "Service type", + "Status", + "Last sync job status", + ], + ) + ) + except asyncio.CancelledError as e: + click.echo(e) + + +language_keys = [*Settings().language_data.keys()] + + +# Support blank values for languge +def validate_language(ctx, param, value): + if value not in language_keys: + return None + + return value + + +@click.command(help="Creates a new connector and a search index") +@click.option( + "--index_name", + prompt=f"{click.style('?', blink=True, fg='green')} Search index name (search-)", +) +@click.option( + "--service_type", + prompt=f"{click.style('?', blink=True, fg='green')} Service type", + type=click.Choice(list(_default_config()["sources"].keys()), case_sensitive=False), +) +@click.option( + "--index_language", + prompt=f"{click.style('?', blink=True, fg='green')} Index language (leave empty for universal) {language_keys}", + default="", + callback=validate_language, +) +@click.pass_obj +def create(obj, index_name, service_type, index_language): + index_name = f"search-{index_name}" + connector = Connector(obj["config"]["elasticsearch"]) + configuration = connector.service_type_configuration( + source_class=_default_config()["sources"][service_type] + ) + + def prompt(): + return click.prompt( + f"{click.style('?', blink=True, fg='green')} {item['label']}", + default=item.get("value", None), + hide_input=True if item.get("sensitive") is True else False, + ) + + # first fill in the fields that do not depend on other fields + for key, item in configuration.items(): + if "depends_on" in item: + continue + + configuration[key]["value"] = prompt() + + for key, item in configuration.items(): + if "depends_on" not in item: + continue + + if all( + configuration[field_item["field"]]["value"] == field_item["value"] + for field_item in item["depends_on"] + ): + configuration[key]["value"] = prompt() + + result = connector.create(index_name, service_type, configuration, index_language) + click.echo( + "Connector (ID: " + + click.style(result[1], fg="green") + + ", service_type: " + + click.style(service_type, fg="green") + + ") has been created!" 
+ ) + + +connector.add_command(create) +connector.add_command(list_connectors) + +cli.add_command(connector) + + +# Index group +@click.group(invoke_without_command=True, help="Search indices management") +@click.pass_obj +def index(obj): + pass + + +@click.command(name="list", help="Show all indices") +@click.pass_obj +def list_indices(obj): + index = Index(config=obj["config"]["elasticsearch"]) + indices = index.list_indices() + + click.echo("") + + if len(indices) == 0: + click.echo("No indices found") + return + + click.echo(f"Showing {len(indices)} indices \n") + table_rows = [] + for index in indices: + formatted_index = [ + click.style(index, blink=True, fg="white"), + click.style(indices[index]["primaries"]["docs"]["count"]), + ] + table_rows.append(formatted_index) + + click.echo(tabulate(table_rows, headers=["Index name", "Number of documents"])) + + +index.add_command(list_indices) + + +@click.command(help="Remove all documents from the index") +@click.pass_obj +@click.argument("index", nargs=1) +def clean(obj, index): + index_cli = Index(config=obj["config"]["elasticsearch"]) + click.confirm( + click.style("Are you sure you want to clean " + index + "?", fg="yellow"), + abort=True, + ) + if index_cli.clean(index): + click.echo(click.style("The index has been cleaned.", fg="green")) + else: + click.echo("") + click.echo( + click.style( + "Something went wrong. Please try again later or check your credentials", + fg="red", + ), + err=True, + ) + + +index.add_command(clean) + + +@click.command(help="Delete an index") +@click.pass_obj +@click.argument("index", nargs=1) +def delete(obj, index): + index_cli = Index(config=obj["config"]["elasticsearch"]) + click.confirm( + click.style("Are you sure you want to delete " + index + "?", fg="yellow"), + abort=True, + ) + if index_cli.delete(index): + click.echo(click.style("The index has been deleted.", fg="green")) + else: + click.echo("") + click.echo( + click.style( + "Something went wrong. Please try again later or check your credentials", + fg="red", + ), + err=True, + ) + + +index.add_command(delete) + +cli.add_command(index) + + +# Job group +@click.group(invoke_without_command=False, help="Sync jobs management") +@click.pass_obj +def job(obj): + pass + + +@click.command(help="Start a sync job.") +@click.pass_obj +@click.option("-i", help="Connector ID", required=True) +@click.option( + "-t", + help="Job type", + type=click.Choice(["full", "incremental", "access_control"], case_sensitive=False), + required=True, +) +def start(obj, i, t): + job_cli = Job(config=obj["config"]["elasticsearch"]) + click.echo("Starting a job...") + if job_cli.start(connector_id=i, job_type=t): + click.echo(click.style("The job has been started.", fg="green")) + else: + click.echo("") + click.echo( + click.style( + "Something went wrong. 
Please try again later or check your credentials", + fg="red", + ), + err=True, + ) + + +job.add_command(start) + + +@click.command(name="list", help="List of jobs sorted by date.") +@click.pass_obj +@click.argument("connector_id", nargs=1) +def list_jobs(obj, connector_id): + job_cli = Job(config=obj["config"]["elasticsearch"]) + jobs = job_cli.list_jobs(connector_id=connector_id) + + if len(jobs) == 0: + click.echo("No jobs found") + + click.echo(f"Showing {len(jobs)} jobs \n") + table_rows = [] + for job in jobs: + formatted_job = [ + click.style(job.id, blink=True, fg="green"), + click.style(job.connector_id, blink=True, fg="white"), + click.style(job.index_name, blink=True, fg="white"), + click.style(job.status.value, blink=True, fg="white"), + click.style(job.job_type.value, blink=True, fg="white"), + click.style(job.indexed_document_count, blink=True, fg="white"), + click.style(job.indexed_document_volume, blink=True, fg="white"), + click.style(job.deleted_document_count, blink=True, fg="white"), + ] + table_rows.append(formatted_job) + + click.echo( + tabulate( + table_rows, + headers=[ + "Job id", + "Connector id", + "Index name", + "Job status", + "Job type", + "Documents indexed", + "Volume documents indexed (MiB)", + "Documents deleted", + ], + ) + ) + + +job.add_command(list_jobs) + + +@click.command(help="Cancel a job") +@click.pass_obj +@click.argument("job_id") +def cancel(obj, job_id): + job_cli = Job(config=obj["config"]["elasticsearch"]) + click.confirm( + click.style("Are you sure you want to cancel jobs?", fg="yellow"), abort=True + ) + click.echo("Canceling jobs...") + if job_cli.cancel(job_id=job_id): + click.echo(click.style("The job has been cancelled", fg="green")) + else: + click.echo("") + click.echo( + click.style( + "Something went wrong. 
Please try again later or check your credentials", + fg="red", + ), + err=True, + ) + + +job.add_command(cancel) + +cli.add_command(job) + + +def main(args=None): + cli() + + +if __name__ == "__main__": + main() diff --git a/connectors/es/client.py b/connectors/es/client.py index 52b054669..3261fa8af 100644 --- a/connectors/es/client.py +++ b/connectors/es/client.py @@ -176,6 +176,14 @@ async def ensure_exists(self, indices=None): async def delete_indices(self, indices): await self.client.indices.delete(index=indices, ignore_unavailable=True) + async def clean_index(self, index_name): + return await self.client.delete_by_query( + index=index_name, body={"query": {"match_all": {}}}, ignore_unavailable=True + ) + + async def list_indices(self): + return await self.client.indices.stats(index="search-*") + def with_concurrency_control(retries=3): def wrapper(func): diff --git a/connectors/protocol/connectors.py b/connectors/protocol/connectors.py index e4074adba..85298fdfb 100644 --- a/connectors/protocol/connectors.py +++ b/connectors/protocol/connectors.py @@ -12,7 +12,6 @@ - SyncJob: represents a document in `.elastic-connectors-sync-jobs` """ -import re import socket from collections import UserDict from copy import deepcopy @@ -954,11 +953,7 @@ async def create(self, connector, trigger_method, job_type): index_name = connector.index_name if job_type == JobType.ACCESS_CONTROL: - index_name = re.sub( - r"^(?:search-)?(.*)$", - rf"{ACCESS_CONTROL_INDEX_PREFIX}\g<1>", - index_name, - ) + index_name = f"{ACCESS_CONTROL_INDEX_PREFIX}{index_name}" job_def = { "connector": { diff --git a/connectors/cli.py b/connectors/service_cli.py similarity index 100% rename from connectors/cli.py rename to connectors/service_cli.py diff --git a/connectors/sources/mongo.py b/connectors/sources/mongo.py index ec0ffff18..3ffb4e598 100644 --- a/connectors/sources/mongo.py +++ b/connectors/sources/mongo.py @@ -7,7 +7,7 @@ from datetime import datetime import fastjsonschema -from bson import Decimal128, ObjectId +from bson import DBRef, Decimal128, ObjectId from fastjsonschema import JsonSchemaValueException from motor.motor_asyncio import AsyncIOMotorClient @@ -167,6 +167,8 @@ def _serialize(value): value = value.isoformat() elif isinstance(value, Decimal128): value = value.to_decimal() + elif isinstance(value, DBRef): + value = _serialize(value.as_doc().to_dict()) return value for key, value in doc.items(): diff --git a/connectors/sources/oracle.py b/connectors/sources/oracle.py index df8a77404..0c6413516 100644 --- a/connectors/sources/oracle.py +++ b/connectors/sources/oracle.py @@ -335,6 +335,7 @@ def get_default_configuration(cls): "ui_restrictions": ["advanced"], }, "oracle_protocol": { + "default_value": DEFAULT_PROTOCOL, "display": "dropdown", "label": "Oracle connection protocol", "options": [ @@ -344,6 +345,7 @@ def get_default_configuration(cls): "order": 9, "type": "str", "value": DEFAULT_PROTOCOL, + "ui_restrictions": ["advanced"], }, "oracle_home": { "default_value": DEFAULT_ORACLE_HOME, @@ -351,6 +353,8 @@ def get_default_configuration(cls): "order": 10, "required": False, "type": "str", + "value": DEFAULT_ORACLE_HOME, + "ui_restrictions": ["advanced"], }, "wallet_configuration_path": { "default_value": "", @@ -358,6 +362,7 @@ def get_default_configuration(cls): "order": 11, "required": False, "type": "str", + "ui_restrictions": ["advanced"], }, } diff --git a/connectors/sources/s3.py b/connectors/sources/s3.py index a17852f73..7cf1b6c5d 100644 --- a/connectors/sources/s3.py +++ 
b/connectors/sources/s3.py @@ -10,10 +10,16 @@ from functools import partial import aioboto3 +import fastjsonschema from aiobotocore.config import AioConfig from aiobotocore.utils import logger as aws_logger from botocore.exceptions import ClientError +from fastjsonschema import JsonSchemaValueException +from connectors.filtering.validation import ( + AdvancedRulesValidator, + SyncRuleValidationResult, +) from connectors.logger import logger, set_extra_logger from connectors.source import BaseDataSource from connectors.utils import hash_id @@ -105,7 +111,7 @@ async def get_bucket_list(self): buckets = self.configuration["buckets"] return buckets - async def get_bucket_objects(self, bucket): + async def get_bucket_objects(self, bucket, **kwargs): """Returns bucket list from list_buckets response Args: bucket (str): Name of bucket @@ -126,7 +132,14 @@ async def get_bucket_objects(self, bucket): bucket_obj = await s3.Bucket(bucket) await asyncio.sleep(0) - async for obj_summary in bucket_obj.objects.page_size(page_size): + if kwargs.get("prefix"): + objects = bucket_obj.objects.filter( + Prefix=kwargs["prefix"] + ).page_size(page_size) + else: + objects = bucket_obj.objects.page_size(page_size) + + async for obj_summary in objects: yield obj_summary, s3_client except Exception as exception: self._logger.warning( @@ -153,11 +166,49 @@ async def get_bucket_region(self, bucket_name): return region +class S3AdvancedRulesValidator(AdvancedRulesValidator): + RULES_OBJECT_SCHEMA_DEFINITION = { + "type": "object", + "properties": { + "bucket": {"type": "string", "minLength": 1}, + "prefix": {"type": "string"}, + "extension": {"type": "array"}, + }, + "required": ["bucket"], + "additionalProperties": False, + } + + SCHEMA_DEFINITION = {"type": "array", "items": RULES_OBJECT_SCHEMA_DEFINITION} + + SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) + + def __init__(self, source): + self.source = source + + async def validate(self, advanced_rules): + if len(advanced_rules) == 0: + return SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ) + try: + S3AdvancedRulesValidator.SCHEMA(advanced_rules) + return SyncRuleValidationResult.valid_result( + rule_id=SyncRuleValidationResult.ADVANCED_RULES + ) + except JsonSchemaValueException as e: + return SyncRuleValidationResult( + rule_id=SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=e.message, + ) + + class S3DataSource(BaseDataSource): """Amazon S3""" name = "Amazon S3" service_type = "s3" + advanced_rules_enabled = True def __init__(self, configuration): """Set up the connection to the Amazon S3. 
@@ -171,6 +222,9 @@ def __init__(self, configuration): def _set_internal_logger(self): self.s3_client.set_logger(self._logger) + def advanced_rules_validators(self): + return [S3AdvancedRulesValidator(self)] + async def ping(self): """Verify the connection with AWS""" try: @@ -203,6 +257,26 @@ async def format_document(self, bucket_name, bucket_object): } return document + async def advanced_sync(self, rule): + async def process_object(obj_summary, s3_client): + document = await self.format_document( + bucket_name=bucket, bucket_object=obj_summary + ) + return document, partial( + self.get_content, doc=document, s3_client=s3_client + ) + + bucket = rule["bucket"] + prefix = rule.get("prefix", "") + async for obj_summary, s3_client in self.s3_client.get_bucket_objects( + bucket=bucket, prefix=prefix + ): + if not rule.get("extension"): + yield await process_object(obj_summary, s3_client) + + elif self.get_file_extension(obj_summary.key) in rule.get("extension", []): + yield await process_object(obj_summary, s3_client) + async def get_docs(self, filtering=None): """Get documents from Amazon S3 @@ -212,19 +286,25 @@ async def get_docs(self, filtering=None): Yields: dictionary: Document from Amazon S3. """ - bucket_list = await self.s3_client.get_bucket_list() - for bucket in bucket_list: - async for obj_summary, s3_client in self.s3_client.get_bucket_objects( - bucket - ): - document = await self.format_document( - bucket_name=bucket, bucket_object=obj_summary - ) - yield document, partial( - self.get_content, - doc=document, - s3_client=s3_client, - ) + if filtering and filtering.has_advanced_rules(): + for rule in filtering.get_advanced_rules(): + async for document, attachment in self.advanced_sync(rule=rule): + yield document, attachment + + else: + bucket_list = await self.s3_client.get_bucket_list() + for bucket in bucket_list: + async for obj_summary, s3_client in self.s3_client.get_bucket_objects( + bucket=bucket + ): + document = await self.format_document( + bucket_name=bucket, bucket_object=obj_summary + ) + yield document, partial( + self.get_content, + doc=document, + s3_client=s3_client, + ) async def get_content(self, doc, s3_client, timestamp=None, doit=None): if not (doit): @@ -272,6 +352,7 @@ def get_default_configuration(cls): "display": "textarea", "label": "AWS Buckets", "order": 1, + "tooltip": "AWS Buckets are ignored when Advanced Sync Rules are used.", "type": "list", }, "aws_access_key_id": { diff --git a/connectors/sources/sharepoint_online.py b/connectors/sources/sharepoint_online.py index c1a8124ae..56efdf5fe 100644 --- a/connectors/sources/sharepoint_online.py +++ b/connectors/sources/sharepoint_online.py @@ -799,6 +799,11 @@ async def drive_items_permissions_batch(self, drive_id, drive_item_ids): permissions_uri = f"/drives/{drive_id}/items/{item_id}/permissions" requests.append({"id": item_id, "method": "GET", "url": permissions_uri}) + if not requests: + self._logger.debug( + "Skipping fetching empty batch of drive item permissions" + ) + return try: batch_url = f"{GRAPH_API_URL}/$batch" batch_request = {"requests": requests} diff --git a/connectors/utils.py b/connectors/utils.py index 9f360a186..1e5e8bb67 100644 --- a/connectors/utils.py +++ b/connectors/utils.py @@ -374,7 +374,8 @@ def _callback(self, task, result_callback=None): self._task_over.set() if task.exception(): logger.error( - f"Exception found for task {task.get_name()}: {task.exception()}" + f"Exception found for task {task.get_name()}: {task.exception()}", + exc_info=True, ) if result_callback 
is not None: result_callback(task.result()) diff --git a/requirements/framework.txt b/requirements/framework.txt index 3b479c90f..651a3aebe 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -29,3 +29,6 @@ exchangelib==5.0.3 ldap3==2.9.1 lxml==4.9.3 pywinrm==0.4.3 +click==8.1.7 +colorama==0.4.6 +tabulate==0.9.0 diff --git a/setup.py b/setup.py index 0e1e1b432..4a2d9b5d6 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,8 @@ def read_reqs(req_file): install_requires=install_requires, entry_points=""" [console_scripts] - elastic-ingest = connectors.cli:main + elastic-ingest = connectors.service_cli:main fake-kibana = connectors.kibana:main + connectors = connectors.connectors_cli:main """, ) diff --git a/tests/sources/fixtures/oracle/connector.json b/tests/sources/fixtures/oracle/connector.json index fbf3b387d..70345fcd0 100644 --- a/tests/sources/fixtures/oracle/connector.json +++ b/tests/sources/fixtures/oracle/connector.json @@ -85,7 +85,8 @@ ], "order": 9, "type": "str", - "value": "TCP" + "value": "TCP", + "ui_restrictions": ["advanced"] }, "oracle_home": { "default_value": "", @@ -93,7 +94,8 @@ "order": 10, "required": false, "type": "str", - "value": "" + "value": "", + "ui_restrictions": ["advanced"] }, "wallet_configuration_path": { "default_value": "", @@ -101,7 +103,8 @@ "order": 11, "required": false, "type": "str", - "value": "" + "value": "", + "ui_restrictions": ["advanced"] } }, "filtering": [ diff --git a/tests/sources/test_mongo.py b/tests/sources/test_mongo.py index 3127439c1..f75890b5f 100644 --- a/tests/sources/test_mongo.py +++ b/tests/sources/test_mongo.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, Mock import pytest +from bson import DBRef, ObjectId from bson.decimal128 import Decimal128 from connectors.protocol import Filter @@ -321,3 +322,20 @@ async def test_validate_config_when_configuration_valid_then_does_not_raise(): collection=configured_collection_name, ) as source: await source.validate_config() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "raw, output", + [ + ({"ref": DBRef("foo", "bar")}, {"ref": {"$ref": "foo", "$id": "bar"}}), + ({"dec": Decimal128("1.25")}, {"dec": 1.25}), + ( + {"id": ObjectId("507f1f77bcf86cd799439011")}, + {"id": "507f1f77bcf86cd799439011"}, + ), + ], +) +async def test_serialize(raw, output): + async with create_mongo_source() as source: + assert source.serialize(raw) == output diff --git a/tests/sources/test_s3.py b/tests/sources/test_s3.py index 7560f0fea..eecd0b229 100644 --- a/tests/sources/test_s3.py +++ b/tests/sources/test_s3.py @@ -6,17 +6,21 @@ from contextlib import asynccontextmanager from datetime import datetime from unittest import mock -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, patch import aioboto3 import aiofiles import pytest from botocore.exceptions import ClientError, HTTPClientError +from connectors.filtering.validation import SyncRuleValidationResult +from connectors.protocol import Filter from connectors.source import ConfigurableFieldValueError -from connectors.sources.s3 import S3DataSource +from connectors.sources.s3 import S3AdvancedRulesValidator, S3DataSource from tests.sources.support import create_source +ADVANCED_SNIPPET = "advanced_snippet" + @asynccontextmanager async def create_s3_source(use_text_extraction_service=False): @@ -368,6 +372,43 @@ async def test_get_docs(mock_aws): num += 1 +@pytest.mark.parametrize( + "filtering", + [ + Filter( + { + ADVANCED_SNIPPET: { + "value": [ + 
{"bucket": "bucket1"}, + ] + } + } + ), + ], +) +@pytest.mark.asyncio +async def test_get_docs_with_advanced_rules(filtering): + async with create_s3_source() as source: + source.s3_client.get_bucket_location = mock.Mock( + return_value=await create_fake_coroutine("ap-south-1") + ) + with mock.patch( + "aioboto3.resources.collection.AIOResourceCollection", AIOResourceCollection + ), mock.patch("aiobotocore.client.AioBaseClient", S3Object), mock.patch( + "aiobotocore.utils.AioInstanceMetadataFetcher.retrieve_iam_role_credentials", + get_roles, + ): + num = 0 + async for (doc, _) in source.get_docs(filtering): + assert doc["_id"] in ( + "70743168e14c18632702ee6e3e9b73fc", + "9fbda540ca0a2441475aea7b8f37bdaf", + "c5a8c684e7bbdc471a20613a6d8074e1", + "e2819e8a4e921caaf0250548ffaddde4", + ) + num += 1 + + @pytest.mark.asyncio async def test_get_bucket_list(): """Test get_bucket_list method of S3Client""" @@ -441,3 +482,63 @@ async def test_close_with_client_session(): await source.close() with pytest.raises(HTTPClientError): await source.ping() + + +@pytest.mark.parametrize( + "advanced_rules, expected_validation_result", + [ + ( + # valid: empty array should be valid + [], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # valid: empty object should also be valid -> default value in Kibana + {}, + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # valid: one custom pattern + [{"bucket": "bucket1"}], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # valid: two custom patterns + [{"bucket": "bucket1"}, {"bucket": "bucket2"}], + SyncRuleValidationResult.valid_result( + SyncRuleValidationResult.ADVANCED_RULES + ), + ), + ( + # invalid: extension in string + [{"bucket": "bucket1", "extension": ".jpg"}], + SyncRuleValidationResult( + SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=ANY, + ), + ), + ( + # invalid: array of arrays -> wrong type + {"bucket": ["a/b/c", ""]}, + SyncRuleValidationResult( + SyncRuleValidationResult.ADVANCED_RULES, + is_valid=False, + validation_message=ANY, + ), + ), + ], +) +@pytest.mark.asyncio +async def test_advanced_rules_validation(advanced_rules, expected_validation_result): + async with create_source(S3DataSource) as source: + validation_result = await S3AdvancedRulesValidator(source).validate( + advanced_rules + ) + assert validation_result == expected_validation_result diff --git a/tests/sources/test_sharepoint_online.py b/tests/sources/test_sharepoint_online.py index 2106b6a58..fec5355a9 100644 --- a/tests/sources/test_sharepoint_online.py +++ b/tests/sources/test_sharepoint_online.py @@ -1475,6 +1475,17 @@ async def test_drive_items_permissions_batch_not_found(self, client, patch_post) assert len(responses) == 0 + @pytest.mark.asyncio + async def test_drive_items_permissions_batch_empty(self, client, patch_post): + drive_id = 1 + drive_item_ids = [] + + async for _response in client.drive_items_permissions_batch( + drive_id, drive_item_ids + ): + msg = "we shouldn't get here" + raise Exception(msg) + @pytest.mark.asyncio async def test_site_role_assignments(self, client, patch_scroll): site_role_assignments_url = ( diff --git a/tests/test_connectors_cli.py b/tests/test_connectors_cli.py new file mode 100644 index 000000000..1cd778367 --- /dev/null +++ b/tests/test_connectors_cli.py @@ -0,0 +1,408 @@ +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest 
+from click.testing import CliRunner +from elasticsearch import ApiError + +from connectors import __version__ # NOQA +from connectors.cli.auth import CONFIG_FILE_PATH +from connectors.connectors_cli import cli, login +from connectors.protocol.connectors import Connector as ConnectorObject +from connectors.protocol.connectors import JobStatus +from connectors.protocol.connectors import SyncJob as SyncJobObject +from tests.commons import AsyncIterator + + +@pytest.fixture(autouse=True) +def mock_cli_config(): + with patch("connectors.connectors_cli.load_config") as mock: + mock.return_value = {"elasticsearch": {"host": "http://localhost:9211/"}} + yield mock + + +@pytest.fixture(autouse=True) +def mock_connector_es_client(): + with patch("connectors.cli.connector.ESClient") as mock: + mock.return_value = AsyncMock() + yield mock + + +@pytest.fixture(autouse=True) +def mock_job_es_client(): + with patch("connectors.cli.job.ESClient") as mock: + mock.return_value = AsyncMock() + yield mock + + +def test_version(): + runner = CliRunner() + result = runner.invoke(cli, ["-v"]) + assert result.exit_code == 0 + assert result.output.strip() == __version__ + + +def test_help_page(): + runner = CliRunner() + result = runner.invoke(cli, ["--help"]) + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output + + +def test_help_page_when_no_arguments(): + runner = CliRunner() + result = runner.invoke(cli, []) + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output + + +@patch("connectors.cli.auth.Auth._Auth__ping_es_client", AsyncMock(return_value=False)) +def test_login_unsuccessful(tmp_path): + runner = CliRunner() + with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: + result = runner.invoke( + login, input="http://localhost:9200/\nwrong_username\nwrong_password\n" + ) + assert result.exit_code == 0 + assert "Authentication failed" in result.output + assert not os.path.isfile(os.path.join(temp_dir, CONFIG_FILE_PATH)) + + +@patch("connectors.cli.auth.Auth._Auth__ping_es_client", AsyncMock(return_value=True)) +def test_login_successful(tmp_path): + runner = CliRunner() + with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: + result = runner.invoke( + login, input="http://localhost:9200/\nwrong_username\nwrong_password\n" + ) + assert result.exit_code == 0 + assert "Authentication successful" in result.output + assert os.path.isfile(os.path.join(temp_dir, CONFIG_FILE_PATH)) + + +@patch("click.confirm") +def test_login_when_credentials_file_exists(mocked_confirm, tmp_path): + runner = CliRunner() + with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: + mocked_confirm.return_value = True + + # Create config file + os.makedirs(os.path.dirname(CONFIG_FILE_PATH)) + with open(os.path.join(temp_dir, CONFIG_FILE_PATH), "w") as f: + f.write("fake config file") + + result = runner.invoke( + login, input="http://localhost:9200/\ncorrect_username\ncorrect_password\n" + ) + assert result.exit_code == 0 + assert mocked_confirm.called_once() + + +def test_connector_help_page(): + runner = CliRunner() + result = runner.invoke(cli, ["connector", "--help"]) + assert result.exit_code == 0 + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output + + +@patch("connectors.cli.connector.Connector.list_connectors", AsyncMock(return_value=[])) +def test_connector_list_no_connectors(): + runner = CliRunner() + result = runner.invoke(cli, 
["connector", "list"]) + assert result.exit_code == 0 + assert "No connectors found" in result.output + + +def test_connector_list_one_connector(): + runner = CliRunner() + connector_index = MagicMock() + + doc = { + "_source": { + "index_name": "test_connector", + "service_type": "mongodb", + "last_sync_status": "error", + "status": "connected", + }, + "_id": "test_id", + } + connectors = [ConnectorObject(connector_index, doc)] + + with patch( + "connectors.protocol.ConnectorIndex.all_connectors", AsyncIterator(connectors) + ): + result = runner.invoke(cli, ["connector", "list"]) + + assert result.exit_code == 0 + assert "test_connector" in result.output + assert "test_id" in result.output + assert "mongodb" in result.output + assert "error" in result.output + assert "connected" in result.output + + +@patch("click.confirm") +def test_connector_create(patch_click_confirm): + runner = CliRunner() + + # configuration for the MongoDB connector + input_params = "\n".join( + [ + "test_connector", + "mongodb", + "en", + "http://localhost/", + "username", + "password", + "database", + "collection", + "False", + ] + ) + + with patch( + "connectors.protocol.connectors.ConnectorIndex.index", + AsyncMock(return_value={"_id": "new_connector_id"}), + ) as patched_create: + result = runner.invoke(cli, ["connector", "create"], input=input_params) + + patched_create.assert_called_once() + assert result.exit_code == 0 + + assert "has been created" in result.output + + +def test_index_help_page(): + runner = CliRunner() + result = runner.invoke(cli, ["index", "--help"]) + assert result.exit_code == 0 + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output + + +@patch("connectors.cli.index.Index.list_indices", MagicMock(return_value=[])) +def test_index_list_no_indexes(): + runner = CliRunner() + result = runner.invoke(cli, ["index", "list"]) + assert result.exit_code == 0 + assert "No indices found" in result.output + + +def test_index_list_one_index(): + runner = CliRunner() + indices = {"indices": {"test_index": {"primaries": {"docs": {"count": 10}}}}} + + with patch( + "connectors.es.client.ESClient.list_indices", AsyncMock(return_value=indices) + ): + result = runner.invoke(cli, ["index", "list"]) + + assert result.exit_code == 0 + assert "test_index" in result.output + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_index_clean(): + runner = CliRunner() + index_name = "test_index" + with patch( + "connectors.es.client.ESClient.clean_index", AsyncMock(return_value=True) + ) as mocked_method: + result = runner.invoke(cli, ["index", "clean", index_name]) + + assert "The index has been cleaned" in result.output + mocked_method.assert_called_once_with(index_name) + assert result.exit_code == 0 + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_index_clean_error(): + runner = CliRunner() + index_name = "test_index" + with patch( + "connectors.es.client.ESClient.clean_index", + side_effect=ApiError(500, meta="meta", body="error"), + ): + result = runner.invoke(cli, ["index", "clean", index_name]) + + assert "Something went wrong." 
in result.output + assert result.exit_code == 0 + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_index_delete(): + runner = CliRunner() + index_name = "test_index" + with patch( + "connectors.es.client.ESClient.delete_indices", AsyncMock(return_value=None) + ) as mocked_method: + result = runner.invoke(cli, ["index", "delete", index_name]) + + assert "The index has been deleted" in result.output + mocked_method.assert_called_once_with([index_name]) + assert result.exit_code == 0 + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_delete_index_error(): + runner = CliRunner() + index_name = "test_index" + with patch( + "connectors.es.client.ESClient.delete_indices", + side_effect=ApiError(500, meta="meta", body="error"), + ): + result = runner.invoke(cli, ["index", "delete", index_name]) + + assert "Something went wrong." in result.output + assert result.exit_code == 0 + + +def test_job_help_page(): + runner = CliRunner() + result = runner.invoke(cli, ["job", "--help"]) + assert result.exit_code == 0 + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output + + +def test_job_help_page_without_subcommands(): + runner = CliRunner() + result = runner.invoke(cli, ["job"]) + assert result.exit_code == 0 + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_job_cancel(): + runner = CliRunner() + job_id = "test_job_id" + + job_index = MagicMock() + + doc = { + "_source": { + "connector": { + "index_name": "test_connector", + "service_type": "mongodb", + "last_sync_status": "error", + "status": "connected", + }, + "status": "running", + "job_type": "full", + }, + "_id": job_id, + } + + job = SyncJobObject(job_index, doc) + + with patch( + "connectors.cli.job.Job._Job__async_list_jobs", AsyncMock(return_value=[job]) + ): + with patch.object(job, "_terminate") as mocked_method: + result = runner.invoke(cli, ["job", "cancel", job_id]) + + mocked_method.assert_called_once_with(JobStatus.CANCELING) + assert "The job has been cancelled" in result.output + assert result.exit_code == 0 + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_job_cancel_error(): + runner = CliRunner() + job_id = "test_job_id" + with patch( + "connectors.cli.job.Job._Job__async_list_jobs", + side_effect=ApiError(500, meta="meta", body="error"), + ): + result = runner.invoke(cli, ["job", "cancel", job_id]) + + assert "Something went wrong." 
in result.output + assert result.exit_code == 0 + + +def test_job_list_no_jobs(): + runner = CliRunner() + connector_id = "test_connector_id" + + with patch( + "connectors.cli.job.Job._Job__async_list_jobs", AsyncMock(return_value=[]) + ): + result = runner.invoke(cli, ["job", "list", connector_id]) + + assert "No jobs found" in result.output + assert result.exit_code == 0 + + +@patch("click.confirm", MagicMock(return_value=True)) +def test_job_list_one_job(): + runner = CliRunner() + job_id = "test_job_id" + connector_id = "test_connector_id" + index_name = "test_index_name" + status = "canceled" + deleted_document_count = 123 + indexed_document_count = 123123 + indexed_document_volume = 100500 + + job_index = MagicMock() + + doc = { + "_source": { + "connector": { + "id": connector_id, + "index_name": index_name, + "service_type": "mongodb", + "last_sync_status": "error", + "status": "connected", + }, + "status": status, + "deleted_document_count": deleted_document_count, + "indexed_document_count": indexed_document_count, + "indexed_document_volume": indexed_document_volume, + "job_type": "full", + }, + "_id": job_id, + } + + job = SyncJobObject(job_index, doc) + + with patch( + "connectors.protocol.connectors.SyncJobIndex.get_all_docs", AsyncIterator([job]) + ): + result = runner.invoke(cli, ["job", "list", connector_id]) + + assert job_id in result.output + assert connector_id in result.output + assert index_name in result.output + assert status in result.output + assert str(deleted_document_count) in result.output + assert str(indexed_document_count) in result.output + assert str(indexed_document_volume) in result.output + assert result.exit_code == 0 + + +@patch( + "connectors.protocol.connectors.ConnectorIndex.fetch_by_id", + AsyncMock(return_value=MagicMock()), +) +def test_job_start(): + runner = CliRunner() + connector_id = "test_connector_id" + + with patch( + "connectors.protocol.connectors.SyncJobIndex.create", + AsyncMock(return_value=True), + ) as patched_create: + result = runner.invoke(cli, ["job", "start", "-i", connector_id, "-t", "full"]) + + patched_create.assert_called_once() + assert "The job has been started" in result.output + assert result.exit_code == 0 diff --git a/tests/test_cli.py b/tests/test_service_cli.py similarity index 94% rename from tests/test_cli.py rename to tests/test_service_cli.py index bab9886fc..635f399bb 100644 --- a/tests/test_cli.py +++ b/tests/test_service_cli.py @@ -14,7 +14,7 @@ import pytest from connectors import __version__ -from connectors.cli import main, run +from connectors.service_cli import main, run HERE = os.path.dirname(__file__) FIXTURES_DIR = os.path.abspath(os.path.join(HERE, "fixtures")) @@ -90,8 +90,10 @@ def test_run_snowflake(mock_responses, set_env): assert "Cannot use the `list` action with other actions" in output -@patch("connectors.cli.set_logger") -@patch("connectors.cli.load_config", side_effect=Exception("something went wrong")) +@patch("connectors.service_cli.set_logger") +@patch( + "connectors.service_cli.load_config", side_effect=Exception("something went wrong") +) def test_main_with_invalid_configuration(load_config, set_logger): args = mock.MagicMock() args.log_level = logging.DEBUG # should be ignored! 
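
Editor's note: a minimal sketch (not part of the patch) showing how the new connectors.cli helper classes can be driven directly from Python, assuming a reachable Elasticsearch cluster. The host, credentials, and "<connector id>" below are placeholders; the classes, methods, and CONFIG_FILE_PATH all come from this change.

```python
# Hedged sketch only: drive the new CLI helper classes without the `connectors`
# console script. Host, credentials, and the connector ID are placeholders.
import yaml

from connectors.cli.auth import CONFIG_FILE_PATH, Auth
from connectors.cli.index import Index
from connectors.cli.job import Job

# Equivalent of `connectors login`: ping the cluster and persist .cli/config.yml.
if not Auth("http://localhost:9200", "elastic", "changeme").authenticate():
    raise SystemExit("authentication failed")

with open(CONFIG_FILE_PATH) as f:
    es_config = yaml.safe_load(f)["elasticsearch"]

# Equivalent of `connectors index list`: stats for the search-* indices.
print(Index(config=es_config).list_indices())

# Equivalent of `connectors job start -i <id> -t full` and `connectors job list <id>`.
job_cli = Job(config=es_config)
job_cli.start(connector_id="<connector id>", job_type="full")
print(job_cli.list_jobs(connector_id="<connector id>"))
```

These are the same helpers that the new `connectors` console script (registered in setup.py alongside the renamed elastic-ingest entry point) wires to its click commands.
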
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 631323aa9..708aa6067 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -836,6 +836,8 @@ def batch_size(value):
                 [1, 2, 3],
             ],
         ),
+        ([], batch_size(20), []),
+        ([[]], batch_size(20), [[[]]]),
     ],
 )
 def test_iterable_batches_generator(iterable, batch_size_, expected_batches):
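
Editor's note: a hedged sketch of the advanced sync rules contract that the new S3AdvancedRulesValidator enforces. The bucket names, prefixes, and extension values are illustrative, and the leading-dot extension format is an assumption; the validator, its schema fields, and the create_source test helper are taken from this patch.

```python
# Illustrative only: an advanced-rules payload for the S3 connector and how the
# new validator treats it. Bucket/prefix/extension values are made up, and the
# leading-dot extension format is an assumption rather than something this
# patch specifies.
import asyncio

from connectors.sources.s3 import S3AdvancedRulesValidator, S3DataSource
from tests.sources.support import create_source

RULES = [
    {"bucket": "bucket1"},                                  # sync the whole bucket
    {"bucket": "bucket2", "prefix": "logs/2023/"},          # only keys under a prefix
    {"bucket": "bucket3", "extension": [".csv", ".json"]},  # only matching extensions
]

async def check(rules):
    async with create_source(S3DataSource) as source:
        result = await S3AdvancedRulesValidator(source).validate(rules)
        # "bucket" is required and "extension" must be an array, so RULES passes,
        # while [{"prefix": "logs/"}] or an "extension" given as a plain string
        # would come back with is_valid=False.
        return result.is_valid

print(asyncio.run(check(RULES)))  # True
```

When such rules are active, get_docs() iterates them through advanced_sync() and, per the new configuration tooltip, the connector's own "AWS Buckets" setting is ignored.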