Skip to content

Commit

Permalink
[S3] Advanced Sync rules (#1921)
Browse files Browse the repository at this point in the history
Co-authored-by: Chenhui Wang <[email protected]>
  • Loading branch information
akanshi-elastic and wangch079 authored Nov 29, 2023
1 parent d29bca5 commit df52a50
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 17 deletions.
111 changes: 96 additions & 15 deletions connectors/sources/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@
from functools import partial

import aioboto3
import fastjsonschema
from aiobotocore.config import AioConfig
from aiobotocore.utils import logger as aws_logger
from botocore.exceptions import ClientError
from fastjsonschema import JsonSchemaValueException

from connectors.filtering.validation import (
AdvancedRulesValidator,
SyncRuleValidationResult,
)
from connectors.logger import logger, set_extra_logger
from connectors.source import BaseDataSource
from connectors.utils import hash_id
Expand Down Expand Up @@ -105,7 +111,7 @@ async def get_bucket_list(self):
buckets = self.configuration["buckets"]
return buckets

async def get_bucket_objects(self, bucket):
async def get_bucket_objects(self, bucket, **kwargs):
"""Returns bucket list from list_buckets response
Args:
bucket (str): Name of bucket
Expand All @@ -126,7 +132,14 @@ async def get_bucket_objects(self, bucket):
bucket_obj = await s3.Bucket(bucket)
await asyncio.sleep(0)

async for obj_summary in bucket_obj.objects.page_size(page_size):
if kwargs.get("prefix"):
objects = bucket_obj.objects.filter(
Prefix=kwargs["prefix"]
).page_size(page_size)
else:
objects = bucket_obj.objects.page_size(page_size)

async for obj_summary in objects:
yield obj_summary, s3_client
except Exception as exception:
self._logger.warning(
Expand All @@ -153,11 +166,49 @@ async def get_bucket_region(self, bucket_name):
return region


class S3AdvancedRulesValidator(AdvancedRulesValidator):
    """Validator for the Amazon S3 connector's advanced sync rules.

    A rule set is a JSON array of objects; each object requires a non-empty
    ``bucket`` name and may optionally carry a ``prefix`` string and an
    ``extension`` array used to filter bucket objects.
    """

    RULES_OBJECT_SCHEMA_DEFINITION = {
        "type": "object",
        "properties": {
            "bucket": {"type": "string", "minLength": 1},
            "prefix": {"type": "string"},
            "extension": {"type": "array"},
        },
        "required": ["bucket"],
        "additionalProperties": False,
    }

    SCHEMA_DEFINITION = {"type": "array", "items": RULES_OBJECT_SCHEMA_DEFINITION}

    # Compiled once at import time; schema compilation is comparatively costly.
    SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION)

    def __init__(self, source):
        self.source = source

    async def validate(self, advanced_rules):
        """Return a SyncRuleValidationResult for *advanced_rules*.

        An empty rule set (e.g. the Kibana default) is treated as valid
        without consulting the schema.
        """
        if len(advanced_rules) == 0:
            return SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            )

        try:
            self.SCHEMA(advanced_rules)
        except JsonSchemaValueException as exception:
            return SyncRuleValidationResult(
                rule_id=SyncRuleValidationResult.ADVANCED_RULES,
                is_valid=False,
                validation_message=exception.message,
            )

        return SyncRuleValidationResult.valid_result(
            rule_id=SyncRuleValidationResult.ADVANCED_RULES
        )


class S3DataSource(BaseDataSource):
"""Amazon S3"""

name = "Amazon S3"
service_type = "s3"
advanced_rules_enabled = True

def __init__(self, configuration):
"""Set up the connection to the Amazon S3.
Expand All @@ -171,6 +222,9 @@ def __init__(self, configuration):
def _set_internal_logger(self):
    # Propagate this connector's logger to the wrapped S3 client so that
    # client-side messages are emitted through the connector's log channel.
    self.s3_client.set_logger(self._logger)

def advanced_rules_validators(self):
    """Return the validators applied to this source's advanced sync rules."""
    validator = S3AdvancedRulesValidator(self)
    return [validator]

async def ping(self):
"""Verify the connection with AWS"""
try:
Expand Down Expand Up @@ -203,6 +257,26 @@ async def format_document(self, bucket_name, bucket_object):
}
return document

async def advanced_sync(self, rule):
    """Yield (document, content_downloader) pairs for one advanced sync rule.

    Args:
        rule (dict): Advanced rule containing a required ``bucket`` name and
            optional ``prefix`` and ``extension`` filters.

    Yields:
        tuple: Formatted document and a partial for fetching its content.
    """
    bucket_name = rule["bucket"]
    extensions = rule.get("extension")

    object_stream = self.s3_client.get_bucket_objects(
        bucket=bucket_name, prefix=rule.get("prefix", "")
    )
    async for summary, client in object_stream:
        # No extension filter means every object matches; otherwise the
        # object's file extension must appear in the rule's list.
        if extensions and self.get_file_extension(summary.key) not in extensions:
            continue
        doc = await self.format_document(bucket_name=bucket_name, bucket_object=summary)
        yield doc, partial(self.get_content, doc=doc, s3_client=client)

async def get_docs(self, filtering=None):
"""Get documents from Amazon S3
Expand All @@ -212,19 +286,25 @@ async def get_docs(self, filtering=None):
Yields:
dictionary: Document from Amazon S3.
"""
bucket_list = await self.s3_client.get_bucket_list()
for bucket in bucket_list:
async for obj_summary, s3_client in self.s3_client.get_bucket_objects(
bucket
):
document = await self.format_document(
bucket_name=bucket, bucket_object=obj_summary
)
yield document, partial(
self.get_content,
doc=document,
s3_client=s3_client,
)
if filtering and filtering.has_advanced_rules():
for rule in filtering.get_advanced_rules():
async for document, attachment in self.advanced_sync(rule=rule):
yield document, attachment

else:
bucket_list = await self.s3_client.get_bucket_list()
for bucket in bucket_list:
async for obj_summary, s3_client in self.s3_client.get_bucket_objects(
bucket=bucket
):
document = await self.format_document(
bucket_name=bucket, bucket_object=obj_summary
)
yield document, partial(
self.get_content,
doc=document,
s3_client=s3_client,
)

async def get_content(self, doc, s3_client, timestamp=None, doit=None):
if not (doit):
Expand Down Expand Up @@ -272,6 +352,7 @@ def get_default_configuration(cls):
"display": "textarea",
"label": "AWS Buckets",
"order": 1,
"tooltip": "AWS Buckets are ignored when Advanced Sync Rules are used.",
"type": "list",
},
"aws_access_key_id": {
Expand Down
105 changes: 103 additions & 2 deletions tests/sources/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
from contextlib import asynccontextmanager
from datetime import datetime
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import aioboto3
import aiofiles
import pytest
from botocore.exceptions import ClientError, HTTPClientError

from connectors.filtering.validation import SyncRuleValidationResult
from connectors.protocol import Filter
from connectors.source import ConfigurableFieldValueError
from connectors.sources.s3 import S3DataSource
from connectors.sources.s3 import S3AdvancedRulesValidator, S3DataSource
from tests.sources.support import create_source

ADVANCED_SNIPPET = "advanced_snippet"


@asynccontextmanager
async def create_s3_source(use_text_extraction_service=False):
Expand Down Expand Up @@ -368,6 +372,43 @@ async def test_get_docs(mock_aws):
num += 1


@pytest.mark.parametrize(
    "filtering",
    [
        Filter(
            {
                ADVANCED_SNIPPET: {
                    "value": [
                        {"bucket": "bucket1"},
                    ]
                }
            }
        ),
    ],
)
@pytest.mark.asyncio
async def test_get_docs_with_advanced_rules(filtering):
    """Documents produced through advanced sync rules carry the expected ids."""
    expected_ids = {
        "70743168e14c18632702ee6e3e9b73fc",
        "9fbda540ca0a2441475aea7b8f37bdaf",
        "c5a8c684e7bbdc471a20613a6d8074e1",
        "e2819e8a4e921caaf0250548ffaddde4",
    }
    async with create_s3_source() as source:
        source.s3_client.get_bucket_location = mock.Mock(
            return_value=await create_fake_coroutine("ap-south-1")
        )
        with mock.patch(
            "aioboto3.resources.collection.AIOResourceCollection", AIOResourceCollection
        ), mock.patch("aiobotocore.client.AioBaseClient", S3Object), mock.patch(
            "aiobotocore.utils.AioInstanceMetadataFetcher.retrieve_iam_role_credentials",
            get_roles,
        ):
            async for document, _ in source.get_docs(filtering):
                assert document["_id"] in expected_ids


@pytest.mark.asyncio
async def test_get_bucket_list():
"""Test get_bucket_list method of S3Client"""
Expand Down Expand Up @@ -441,3 +482,63 @@ async def test_close_with_client_session():
await source.close()
with pytest.raises(HTTPClientError):
await source.ping()


@pytest.mark.parametrize(
    "advanced_rules, expected_validation_result",
    [
        pytest.param(
            [],
            SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            ),
            id="empty array is valid",
        ),
        pytest.param(
            {},
            SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            ),
            id="empty object (Kibana default) is valid",
        ),
        pytest.param(
            [{"bucket": "bucket1"}],
            SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            ),
            id="single custom pattern is valid",
        ),
        pytest.param(
            [{"bucket": "bucket1"}, {"bucket": "bucket2"}],
            SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            ),
            id="two custom patterns are valid",
        ),
        pytest.param(
            [{"bucket": "bucket1", "extension": ".jpg"}],
            SyncRuleValidationResult(
                SyncRuleValidationResult.ADVANCED_RULES,
                is_valid=False,
                validation_message=ANY,
            ),
            id="extension as string is invalid",
        ),
        pytest.param(
            {"bucket": ["a/b/c", ""]},
            SyncRuleValidationResult(
                SyncRuleValidationResult.ADVANCED_RULES,
                is_valid=False,
                validation_message=ANY,
            ),
            id="wrong top-level type is invalid",
        ),
    ],
)
@pytest.mark.asyncio
async def test_advanced_rules_validation(advanced_rules, expected_validation_result):
    """The validator accepts well-formed rule sets and rejects malformed ones."""
    async with create_source(S3DataSource) as source:
        validator = S3AdvancedRulesValidator(source)
        assert await validator.validate(advanced_rules) == expected_validation_result

0 comments on commit df52a50

Please sign in to comment.