From 03e2728088ce0e93a2185683ddbc123a45ae5693 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:09:08 +0200 Subject: [PATCH] Fix configuration migrations to work for peripheral connectors (#1830) --- connectors/protocol/connectors.py | 175 ++++++++++++------------------ tests/protocol/test_connectors.py | 147 +++++++++++++++++++++++-- 2 files changed, 206 insertions(+), 116 deletions(-) diff --git a/connectors/protocol/connectors.py b/connectors/protocol/connectors.py index 73454bf0e..72bf7cdad 100644 --- a/connectors/protocol/connectors.py +++ b/connectors/protocol/connectors.py @@ -743,35 +743,27 @@ async def prepare(self, config, sources): configured_connector_id = config.get("connector_id", "") configured_service_type = config.get("service_type", "") + is_main_connector = self.id == configured_connector_id - if self.id != configured_connector_id: - # check configuration for native and other peripheral connectors + if is_main_connector: + if not configured_service_type: + self.log_error("Service type is not configured") + raise ServiceTypeNotConfiguredError("Service type is not configured.") + + if configured_service_type not in sources: + raise ServiceTypeNotSupportedError(configured_service_type) + else: if self.service_type not in sources: self.log_debug( f"Peripheral connector has invalid service type {self.service_type}, cannot check configuration formatting." ) return - await self.validate_configuration_formatting( - sources[self.service_type], self.service_type - ) - - return - - if not configured_service_type: - self.log_error("Service type is not configured") - raise ServiceTypeNotConfiguredError("Service type is not configured.") - - if configured_service_type not in sources: - raise ServiceTypeNotSupportedError(configured_service_type) - - if self.service_type is not None and not self.configuration.is_empty(): - await self.validate_configuration_formatting( - sources[configured_service_type], configured_service_type - ) - - doc = {} - fqn = sources[configured_service_type] + fqn = ( + sources[configured_service_type] + if is_main_connector + else sources[self.service_type] + ) try: source_klass = get_source_klass(fqn) except Exception as e: @@ -780,25 +772,12 @@ async def prepare(self, config, sources): f"Could not instantiate {fqn} for {configured_service_type}" ) from e - if self.service_type is None: + doc = self.validated_doc(source_klass) + if is_main_connector and self.service_type is None: doc["service_type"] = configured_service_type self.log_debug(f"Populated service type {configured_service_type}") - simple_config = source_klass.get_simple_configuration() - current_config = self.configuration.to_dict() - missing_keys = simple_config.keys() - current_config.keys() - if self.configuration.is_empty(): - # sets the defaults and the flag to NEEDS_CONFIGURATION - doc["configuration"] = simple_config - doc["status"] = Status.NEEDS_CONFIGURATION.value - self.log_debug("Populated configuration") - elif missing_keys: - doc["configuration"] = self.updated_configuration( - missing_keys, current_config, simple_config - ) - # doc["status"] = Status.NEEDS_CONFIGURATION.value # not setting status, because it may be that default values are sufficient - - if self.features.features != source_klass.features(): + if is_main_connector and self.features.features != source_klass.features(): doc["features"] = source_klass.features() self.log_debug("Populated features") @@ -813,7 +792,41 @@ async def prepare(self, config, sources): ) await self.reload() - def updated_configuration( + def validated_doc(self, source_klass): + simple_config = source_klass.get_simple_configuration() + current_config = self.configuration.to_dict() + + if self.configuration.is_empty(): + # sets the defaults and the flag to NEEDS_CONFIGURATION + self.log_debug("Populated configuration") + return { + "configuration": simple_config, + "status": Status.NEEDS_CONFIGURATION.value, + } + + missing_fields = simple_config.keys() - current_config.keys() + fields_missing_properties = filter_nested_dict_by_keys( + DEFAULT_CONFIGURATION.keys(), current_config + ) + if not missing_fields and not fields_missing_properties: + return {} + + doc = {"configuration": {}} + if missing_fields: + doc["configuration"] = self.updated_configuration_fields( + missing_fields, current_config, simple_config + ) + if fields_missing_properties: + updated_config = self.updated_configuration_field_properties( + fields_missing_properties, simple_config + ) + doc["configuration"] = deep_merge_dicts( + doc["configuration"], updated_config + ) + + return doc + + def updated_configuration_fields( self, missing_keys, current_config, simple_default_config ): self.log_warning( @@ -834,6 +847,25 @@ def updated_configuration( draft_config[config_name] = draft_config_obj return draft_config + def updated_configuration_field_properties( + self, fields_missing_properties, simple_config + ): + """Checks the field properties for every field in a configuration. + If a field is missing field properties, add those field properties + with default values. + """ + self.log_info( + f"Connector {self.id} ({self.service_type}) is missing configuration field properties. Generating defaults." + ) + + # filter the default config by what fields we want to update, then merge the actual config into it + filtered_simple_config = { + key: value + for key, value in simple_config.items() + if key in fields_missing_properties.keys() + } + return deep_merge_dicts(filtered_simple_config, fields_missing_properties) + @with_concurrency_control() async def validate_filtering(self, validator): await self.reload() @@ -877,71 +909,6 @@ async def document_count(self): ) return result["count"] - async def validate_configuration_formatting(self, fqn, service_type): - """Wrapper function for validating configuration field properties. - - Args: - fqn (string): the source fqn for a service, from config file - service_type (string): service type of the connector - """ - try: - source_klass = get_source_klass(fqn) - except Exception as e: - self.log_critical(e, exc_info=True) - raise DataSourceError( - f"Could not instantiate {fqn} for {service_type}" - ) from e - - default_config = source_klass.get_simple_configuration() - current_config = self.configuration.to_dict() - - await self.add_missing_configuration_field_properties( - service_type, default_config, current_config - ) - - async def add_missing_configuration_field_properties( - self, service_type, default_config, current_config - ): - """Checks the field properties for every field in a configuration. - If a field is missing field properties, add those field properties - with default values. - If no field properties are missing, nothing is updated. - - Args: - service_type (string): service type of the connector - default_config (dict): the default configuration for the connector - current_config (dict): the currently existing configuration for the connector - """ - configs_missing_properties = filter_nested_dict_by_keys( - DEFAULT_CONFIGURATION.keys(), current_config - ) - if not configs_missing_properties: - return - - self.log_info( - f"Connector for {service_type} is missing configuration field properties. Generating defaults." - ) - - # filter the default config by what fields we want to update, then merge the actual config into it - filtered_default_config = { - key: value - for key, value in default_config.items() - if key in configs_missing_properties.keys() - } - doc = { - "configuration": deep_merge_dicts( - filtered_default_config, configs_missing_properties - ) - } - - await self.index.update( - doc_id=self.id, - doc=doc, - if_seq_no=self._seq_no, - if_primary_term=self._primary_term, - ) - await self.reload() - def _prefix(self): return f"[Connector id: {self.id}, index name: {self.index_name}]" diff --git a/tests/protocol/test_connectors.py b/tests/protocol/test_connectors.py index c02e31956..808c0d58b 100644 --- a/tests/protocol/test_connectors.py +++ b/tests/protocol/test_connectors.py @@ -901,7 +901,7 @@ def get_default_configuration(cls): @pytest.mark.asyncio -async def test_connector_prepare_different_id(): +async def test_connector_prepare_different_id_invalid_source(): doc_id = "1" seq_no = 1 primary_term = 2 @@ -924,13 +924,19 @@ async def test_connector_prepare_different_id(): index.update.assert_not_awaited() +@pytest.mark.parametrize( + "main_doc_id, this_doc_id", + [ + ("1", "1"), + ("1", "2"), + ], +) @pytest.mark.asyncio -async def test_connector_prepare_with_prepared_connector(): - doc_id = "1" +async def test_connector_prepare_with_prepared_connector(main_doc_id, this_doc_id): seq_no = 1 primary_term = 2 connector_doc = { - "_id": doc_id, + "_id": this_doc_id, "_seq_no": seq_no, "_primary_term": primary_term, "_source": { @@ -971,7 +977,7 @@ async def test_connector_prepare_with_prepared_connector(): }, } config = { - "connector_id": doc_id, + "connector_id": main_doc_id, "service_type": "banana", } sources = {"banana": "tests.protocol.test_connectors:Banana"} @@ -983,13 +989,128 @@ async def test_connector_prepare_with_prepared_connector(): index.update.assert_not_awaited() +@pytest.mark.parametrize( + "main_doc_id, this_doc_id", + [ + ("1", "1"), + ("1", "2"), + ], +) @pytest.mark.asyncio -async def test_connector_prepare_with_connector_missing_field_properties_creates_them(): - doc_id = "1" +async def test_connector_prepare_with_connector_empty_config_creates_default( + main_doc_id, this_doc_id +): seq_no = 1 primary_term = 2 connector_doc = { - "_id": doc_id, + "_id": this_doc_id, + "_seq_no": seq_no, + "_primary_term": primary_term, + "_source": { + "service_type": "banana", + "configuration": {}, + "features": Banana.features(), + }, + } + config = { + "connector_id": main_doc_id, + "service_type": "banana", + } + sources = {"banana": "tests.protocol.test_connectors:Banana"} + index = Mock() + index.fetch_response_by_id = AsyncMock(return_value=connector_doc) + index.update = AsyncMock() + connector = Connector(elastic_index=index, doc_source=connector_doc) + + expected = Banana.get_simple_configuration() + + # only updates fields with missing properties + await connector.prepare(config, sources) + index.update.assert_called_once_with( + doc_id=this_doc_id, + doc={"configuration": expected, "status": "needs_configuration"}, + if_seq_no=seq_no, + if_primary_term=primary_term, + ) + + +@pytest.mark.parametrize( + "main_doc_id, this_doc_id", + [ + ("1", "1"), + ("1", "2"), + ], +) +@pytest.mark.asyncio +async def test_connector_prepare_with_connector_missing_fields_creates_them( + main_doc_id, this_doc_id +): + seq_no = 1 + primary_term = 2 + connector_doc = { + "_id": this_doc_id, + "_seq_no": seq_no, + "_primary_term": primary_term, + "_source": { + "service_type": "banana", + "configuration": { + "two": { + "default_value": None, + "depends_on": [], + "display": "text", + "label": "", + "options": [], + "order": 1, + "required": True, + "sensitive": False, + "tooltip": None, + "type": "str", + "ui_restrictions": [], + "validations": [], + "value": "foobar", + }, + }, + "features": Banana.features(), + }, + } + config = { + "connector_id": main_doc_id, + "service_type": "banana", + } + sources = {"banana": "tests.protocol.test_connectors:Banana"} + index = Mock() + index.fetch_response_by_id = AsyncMock(return_value=connector_doc) + index.update = AsyncMock() + connector = Connector(elastic_index=index, doc_source=connector_doc) + + expected = Banana.get_simple_configuration() + del expected["two"] + + # only updates fields with missing properties + await connector.prepare(config, sources) + index.update.assert_called_once_with( + doc_id=this_doc_id, + doc={"configuration": expected}, + if_seq_no=seq_no, + if_primary_term=primary_term, + ) + + +@pytest.mark.parametrize( + "main_doc_id, this_doc_id", + [ + ("1", "1"), + ("1", "2"), + ], +) +@pytest.mark.asyncio +async def test_connector_prepare_with_connector_missing_field_properties_creates_them( + main_doc_id, this_doc_id +): + seq_no = 1 + primary_term = 2 + connector_doc = { + "_id": this_doc_id, "_seq_no": seq_no, "_primary_term": primary_term, "_source": { @@ -1016,7 +1137,7 @@ async def test_connector_prepare_with_connector_missing_field_properties_creates }, } config = { - "connector_id": doc_id, + "connector_id": main_doc_id, "service_type": "banana", } sources = {"banana": "tests.protocol.test_connectors:Banana"} @@ -1034,7 +1155,7 @@ async def test_connector_prepare_with_connector_missing_field_properties_creates await connector.prepare(config, sources) index.update.assert_called_once_with( - doc_id=doc_id, + doc_id=this_doc_id, doc={"configuration": expected}, if_seq_no=seq_no, if_primary_term=primary_term, @@ -1976,7 +2097,7 @@ def test_get_advanced_rules(filtering, expected_advanced_rules): assert Filter(filtering).get_advanced_rules() == expected_advanced_rules -def test_updated_configuration(): +def test_updated_configuration_fields(): current = { "tenant_id": {"label": "Tenant ID", "order": 1, "type": "str", "value": "foo"}, "tenant_name": { @@ -2027,7 +2148,9 @@ def test_updated_configuration(): } missing_configs = ["new_config"] connector = Connector(elastic_index=Mock(), doc_source={"_id": "test"}) - result = connector.updated_configuration(missing_configs, current, simple_default) + result = connector.updated_configuration_fields( + missing_configs, current, simple_default + ) # all keys included where there are changes (excludes 'tenant_name') assert result.keys() == set(