From e4da590f603dba69001203cbb26370959c4ca9b0 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Tue, 18 Jul 2023 09:42:57 -0400 Subject: [PATCH 01/13] docs: architecture --- docs/diagrams/architecture.md | 175 ++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 docs/diagrams/architecture.md diff --git a/docs/diagrams/architecture.md b/docs/diagrams/architecture.md new file mode 100644 index 0000000..5c2a864 --- /dev/null +++ b/docs/diagrams/architecture.md @@ -0,0 +1,175 @@ +# Architecture + +```mermaid +--- +title: Scan Overview +--- +flowchart LR + cloud-connector --> CloudProviders + CloudProviders --> ResourceTypes + ResourceTypes --> ASM + + subgraph CloudProviders + direction TB + AWS + Azure + GCP + end + +``` + +## AWS + +```mermaid +--- +title: AWS Scan +--- +flowchart LR + cloud-connector --> CloudProviders + AWS --> AwsResourceTypes + + aws-api-gateway --> DomainSeed + aws-ecs --> IpSeed + aws-eni --> IpSeed + aws-elb --> DomainSeed + aws-rds --> DomainSeed + aws-route53 --> DomainSeed + aws-s3 --> CloudAsset + + IpSeed --> ASMSeed + DomainSeed --> ASMSeed + CloudAsset --> ASMCloudAsset + + subgraph CloudProviders + direction TB + AWS + end + + subgraph ResourceTypes + subgraph AwsResourceTypes + direction TB + aws-api-gateway + aws-ecs + aws-eni + aws-elb + aws-rds + aws-route53 + aws-s3 + end + end + + subgraph CcSeedTypes + direction TB + DomainSeed + IpSeed + CloudAsset + end + + subgraph ASM + direction TB + ASMSeed["/v1/seeds"] + ASMCloudAsset["/beta/cloudConnector/addCloudAsset"] + end +``` + +## Azure + +```mermaid +--- +title: Azure Scan +--- +flowchart LR + cloud-connector --> CloudProviders + AZURE --> AzureResourceTypes + + az-container-groups --> IpSeed + az-dns-zones --> DomainSeed + az-public-ip-addresses --> IpSeed + az-sql-servers --> DomainSeed + az-storage-accounts --> CloudAsset + + IpSeed --> ASMSeed + DomainSeed --> ASMSeed + CloudAsset --> ASMCloudAsset + + subgraph CloudProviders + direction TB + AZURE + end + + subgraph ResourceTypes + subgraph AzureResourceTypes + direction TB + az-container-groups + az-dns-zones + az-public-ip-addresses + az-sql-servers + az-storage-accounts + end + end + + subgraph CcSeedTypes + direction TB + DomainSeed + IpSeed + CloudAsset + end + + subgraph ASM + direction TB + ASMSeed["/v1/seeds"] + ASMCloudAsset["/beta/cloudConnector/addCloudAsset"] + end +``` + +## GCP + +```mermaid +--- +title: GCP Scan +--- +flowchart LR + cloud-connector --> CloudProviders + GOOGLE --> GcpResourceTypes + + gcp-compute-instance --> IpSeed + gcp-compute-address --> IpSeed + gcp-container-cluster --> IpSeed + gcp-cloudsql-instance --> DomainSeed + gcp-dns-zone --> DomainSeed + gcp-storage-bucket --> CloudAsset + + IpSeed --> ASMSeed + DomainSeed --> ASMSeed + CloudAsset --> ASMCloudAsset + + subgraph CloudProviders + direction TB + GOOGLE + end + + subgraph ResourceTypes + subgraph GcpResourceTypes + direction TB + gcp-compute-instance + gcp-compute-address + gcp-container-cluster + gcp-cloudsql-instance + gcp-dns-zone + gcp-storage-bucket + end + end + + subgraph CcSeedTypes + direction TB + DomainSeed + IpSeed + CloudAsset + end + + subgraph ASM + direction TB + ASMSeed["/v1/seeds"] + ASMCloudAsset["/beta/cloudConnector/addCloudAsset"] + end +``` From bbf6c8b0789f724391860848b5d25b22d2aa6963 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Thu, 10 Aug 2023 10:41:43 -0400 Subject: [PATCH 02/13] WIP --- docs/diagrams/architecture.md | 175 ------------------ .../aws_connector/connector.py | 127 ++++++++++--- 
.../cloud_connectors/common/connector.py | 23 ++- .../cloud_connectors/plugins/aws_tags.py | 4 + 4 files changed, 120 insertions(+), 209 deletions(-) delete mode 100644 docs/diagrams/architecture.md diff --git a/docs/diagrams/architecture.md b/docs/diagrams/architecture.md deleted file mode 100644 index 5c2a864..0000000 --- a/docs/diagrams/architecture.md +++ /dev/null @@ -1,175 +0,0 @@ -# Architecture - -```mermaid ---- -title: Scan Overview ---- -flowchart LR - cloud-connector --> CloudProviders - CloudProviders --> ResourceTypes - ResourceTypes --> ASM - - subgraph CloudProviders - direction TB - AWS - Azure - GCP - end - -``` - -## AWS - -```mermaid ---- -title: AWS Scan ---- -flowchart LR - cloud-connector --> CloudProviders - AWS --> AwsResourceTypes - - aws-api-gateway --> DomainSeed - aws-ecs --> IpSeed - aws-eni --> IpSeed - aws-elb --> DomainSeed - aws-rds --> DomainSeed - aws-route53 --> DomainSeed - aws-s3 --> CloudAsset - - IpSeed --> ASMSeed - DomainSeed --> ASMSeed - CloudAsset --> ASMCloudAsset - - subgraph CloudProviders - direction TB - AWS - end - - subgraph ResourceTypes - subgraph AwsResourceTypes - direction TB - aws-api-gateway - aws-ecs - aws-eni - aws-elb - aws-rds - aws-route53 - aws-s3 - end - end - - subgraph CcSeedTypes - direction TB - DomainSeed - IpSeed - CloudAsset - end - - subgraph ASM - direction TB - ASMSeed["/v1/seeds"] - ASMCloudAsset["/beta/cloudConnector/addCloudAsset"] - end -``` - -## Azure - -```mermaid ---- -title: Azure Scan ---- -flowchart LR - cloud-connector --> CloudProviders - AZURE --> AzureResourceTypes - - az-container-groups --> IpSeed - az-dns-zones --> DomainSeed - az-public-ip-addresses --> IpSeed - az-sql-servers --> DomainSeed - az-storage-accounts --> CloudAsset - - IpSeed --> ASMSeed - DomainSeed --> ASMSeed - CloudAsset --> ASMCloudAsset - - subgraph CloudProviders - direction TB - AZURE - end - - subgraph ResourceTypes - subgraph AzureResourceTypes - direction TB - az-container-groups - az-dns-zones - az-public-ip-addresses - az-sql-servers - az-storage-accounts - end - end - - subgraph CcSeedTypes - direction TB - DomainSeed - IpSeed - CloudAsset - end - - subgraph ASM - direction TB - ASMSeed["/v1/seeds"] - ASMCloudAsset["/beta/cloudConnector/addCloudAsset"] - end -``` - -## GCP - -```mermaid ---- -title: GCP Scan ---- -flowchart LR - cloud-connector --> CloudProviders - GOOGLE --> GcpResourceTypes - - gcp-compute-instance --> IpSeed - gcp-compute-address --> IpSeed - gcp-container-cluster --> IpSeed - gcp-cloudsql-instance --> DomainSeed - gcp-dns-zone --> DomainSeed - gcp-storage-bucket --> CloudAsset - - IpSeed --> ASMSeed - DomainSeed --> ASMSeed - CloudAsset --> ASMCloudAsset - - subgraph CloudProviders - direction TB - GOOGLE - end - - subgraph ResourceTypes - subgraph GcpResourceTypes - direction TB - gcp-compute-instance - gcp-compute-address - gcp-container-cluster - gcp-cloudsql-instance - gcp-dns-zone - gcp-storage-bucket - end - end - - subgraph CcSeedTypes - direction TB - DomainSeed - IpSeed - CloudAsset - end - - subgraph ASM - direction TB - ASMSeed["/v1/seeds"] - ASMCloudAsset["/beta/cloudConnector/addCloudAsset"] - end -``` diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index ce9beb8..be7e06d 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -1,6 +1,7 @@ """AWS Cloud Connector.""" import contextlib from collections.abc import Generator, Sequence 
+from multiprocessing import Pool
 from typing import Any, Optional, TypeVar, Union
 
 import boto3
@@ -43,6 +44,7 @@
 
 VALID_RECORD_TYPES = ["A", "CNAME"]
 IGNORED_TAGS = ["censys-cloud-connector-ignore"]
+MAX_PROC = 4  # TODO: .env settings
 
 
 class AwsCloudConnector(CloudConnector):
@@ -92,17 +94,19 @@ def __init__(self, settings: Settings):
         self.ignored_tags = []
         self.global_ignored_tags: set[str] = set(IGNORED_TAGS)
 
-    def scan_seeds(self):
+    def scan_seeds(self, **kwargs):
         """Scan AWS for seeds."""
+        credential = kwargs.get("credential")
+        region = kwargs.get("region")
         self.logger.info(
-            f"Scanning AWS account {self.account_number} in region {self.region}"
+            f"Scanning AWS account {credential['account_number']} in region {region}"
         )
-        super().scan_seeds()
+        return super().scan_seeds(**kwargs)
 
-    def scan_cloud_assets(self):
+    def scan_cloud_assets(self, **kwargs):
         """Scan AWS for cloud assets."""
         self.logger.info(f"Scanning AWS account {self.account_number}")
-        super().scan_cloud_assets()
+        super().scan_cloud_assets(**kwargs)
 
     def scan_all(self):
         """Scan all configured AWS provider accounts."""
@@ -110,12 +114,15 @@ def scan_all(self):
             tuple, AwsSpecificSettings
         ] = self.settings.providers.get(self.provider, {})
 
+        pool = Pool(processes=MAX_PROC)
+
         for provider_setting in provider_settings.values():
             self.provider_settings = provider_setting
 
             for credential in self.provider_settings.get_credentials():
-                self.credential = credential
-                self.account_number = credential["account_number"]
+                # self.credential = credential
+                # self.account_number = credential["account_number"]
+                # TODO: this won't work if using a pool (no self!)
                 self.ignored_tags = self.get_ignored_tags(credential["ignore_tags"])
 
                 # for each account + region combination, run each seed scanner
@@ -128,13 +135,29 @@
                             provider_setting,
                             provider={
                                 "region": region,
-                                "account_number": self.account_number,
+                                "account_number": credential[
+                                    "account_number"
+                                ],  # self.account_number,
                             },
                         ):
-                            self.scan_seeds()
+                            self.logger.info(
+                                "starting pool account:%s region:%s",
+                                credential["account_number"],
+                                region,
+                            )
+                            # currently doesn't work because `self.` references clobber each other
+                            # example: region gets unset after first iteration and program crashes
+                            pool.apply_async(
+                                self.scan_seeds,
+                                kwds={
+                                    "credential": credential,
+                                    "region": region,
+                                },
+                            )
+                            # self.scan(**kwargs)
                     except Exception as e:
                         self.logger.error(
-                            f"Unable to scan account {self.account_number} in region {self.region}. Error: {e}"
+                            f"Unable to scan account {credential['account_number']} in region {region}. Error: {e}"
                         )
                         self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e)
                     self.region = None
@@ -148,27 +171,44 @@
                     provider_setting,
                     provider={"account_number": self.account_number},
                 ):
-                    self.scan_cloud_assets()
+                    # self.scan_cloud_assets()
+                    pool.apply_async(
+                        self.scan_cloud_assets,
+                        kwds={
+                            "credential": credential,
+                            "region": region,
+                        },
+                    )
             except Exception as e:
                 self.logger.error(
                     f"Unable to scan account {self.account_number}. Error: {e}"
                 )
                 self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e)
 
-    def format_label(self, service: AwsServices, region: Optional[str] = None) -> str:
+        pool.close()
+        pool.join()
+
+    def format_label(
+        self,
+        service: AwsServices,
+        region: Optional[str] = None,
+        account_number: Optional[str] = None,
+    ) -> str:
         """Format AWS label.
Args: service (AwsServices): AWS Service Type region (str): AWS Region override + account_number (str): AWS Account number Returns: str: Formatted label. """ - # TODO: rename this function to make it obvious that region is included in a label + # TODO: s/self.account_number/account_number <- use param + account = account_number or self.account_number region = region or self.region region_label = f"/{region}" if region != "" else "" - return f"AWS: {service} - {self.account_number}{region_label}" + return f"AWS: {service} - {account}{region_label}" def credentials(self) -> dict: """Generate required credentials for AWS. @@ -211,6 +251,9 @@ def get_aws_client( """ try: credentials = credentials or self.credentials() + credentials[ + "endpoint_url" + ] = "http://localhost.localstack.cloud:4566" # TODO: env settings if credentials.get("aws_access_key_id"): self.logger.debug(f"AWS Service {service} using access key credentials") return boto3.client(service, **credentials) # type: ignore @@ -328,6 +371,7 @@ def assume_role( role_session = ( self.credential["role_session_name"] or AwsDefaults.ROLE_SESSION_NAME.value ) + # TODO: s/self.account_number/cred[account_number] <- pass in account number role: dict[str, Any] = { "RoleArn": f"arn:aws:iam::{self.account_number}:role/{role_name}", "RoleSessionName": role_session, @@ -610,10 +654,26 @@ def _get_route53_zone_resources( .build_full_result() ) - def get_route53_zones(self): + def get_route53_zones(self, **kwargs): """Retrieve Route 53 Zones and emit seeds.""" - client: Route53Client = self.get_aws_client(service=AwsServices.ROUTE53_ZONES) - label = self.format_label(SeedLabel.ROUTE53_ZONES) + # TODO: how to pass in cred,region? + credential = kwargs["credential"] + region = kwargs["region"] + + botocred = self.boto_cred( + region_name=region, + access_key=credential["access_key"], + secret_key=credential["secret_key"], + ) + client: Route53Client = self.get_aws_client( + service=AwsServices.ROUTE53_ZONES, credentials=botocred + ) + label = self.format_label( + SeedLabel.ROUTE53_ZONES, + region=region, + account_number=credential["account_number"], + ) + has_added_seeds = False try: zones = self._get_route53_zone_hosts(client) @@ -631,8 +691,9 @@ def get_route53_zones(self): id = zone.get("Id") resource_sets = self._get_route53_zone_resources(client, id) for resource_set in resource_sets.get("ResourceRecordSets", []): - if resource_set.get("Type") not in VALID_RECORD_TYPES: - continue + # Note: localstack creates 2 entries per hosted zone. (remember this if stats are "off") + # if resource_set.get("Type") not in VALID_RECORD_TYPES: + # continue # turned off so localstack things show up domain_name = resource_set.get("Name").rstrip(".") with SuppressValidationError(): @@ -700,26 +761,42 @@ def get_s3_region(self, client: S3Client, bucket: str) -> str: location = client.get_bucket_location(Bucket=bucket)["LocationConstraint"] return location or "us-east-1" - def get_s3_instances(self): + def get_s3_instances(self, **kwargs): """Retrieve Simple Storage Service data and emit seeds.""" - client: S3Client = self.get_aws_client(service=AwsServices.STORAGE_BUCKET) + # TODO: how to pass in cred,region? 
+ credential = kwargs["credential"] + region = kwargs["region"] + + botocred = self.boto_cred( + region_name=region, + access_key=credential["access_key"], + secret_key=credential["secret_key"], + ) + client: S3Client = self.get_aws_client( + service=AwsServices.STORAGE_BUCKET, credentials=botocred + ) try: data = client.list_buckets().get("Buckets", []) + for bucket in data: bucket_name = bucket.get("Name") if not bucket_name: continue - region = self.get_s3_region(client, bucket_name) - label = self.format_label(SeedLabel.STORAGE_BUCKET, region) + lookup_region = self.get_s3_region(client, bucket_name) + label = self.format_label( + SeedLabel.STORAGE_BUCKET, + region=region, + account_number=credential["account_number"], + ) with SuppressValidationError(): bucket_asset = AwsStorageBucketAsset( - value=AwsStorageBucketAsset.url(bucket_name, region), + value=AwsStorageBucketAsset.url(bucket_name, lookup_region), uid=label, scan_data={ - "accountNumber": self.account_number, + "accountNumber": credential["account_number"], }, ) self.add_cloud_asset( diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py index 1e8bcd1..3f3943c 100644 --- a/src/censys/cloud_connectors/common/connector.py +++ b/src/censys/cloud_connectors/common/connector.py @@ -1,4 +1,5 @@ """Base class for all cloud connectors.""" +import time from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum @@ -67,6 +68,7 @@ def __init__(self, settings: Settings): self.cloud_assets = defaultdict(set) self.current_service = None + # TODO: how to pass in cred,region? (each scanner will have diff things to pass in) def delete_seeds_by_label(self, label: str): """Replace seeds for [label] with an empty list. @@ -81,7 +83,7 @@ def delete_seeds_by_label(self, label: str): self.logger.info(f"Deleted any seeds for label {label}.") self.dispatch_event(EventTypeEnum.SEEDS_DELETED, label=label) - def get_seeds(self) -> None: + def get_seeds(self, **kwargs) -> None: """Gather seeds.""" for seed_type, seed_scanner in self.seed_scanners.items(): self.current_service = seed_type @@ -92,10 +94,11 @@ def get_seeds(self) -> None: self.logger.debug(f"Skipping {seed_type}") continue self.logger.debug(f"Scanning {seed_type}") - seed_scanner() + seed_scanner(**kwargs) self.current_service = None - def get_cloud_assets(self) -> None: + # TODO: how to pass in cred,region? 
(each scanner will have diff things to pass in)
+    def get_cloud_assets(self, **kwargs) -> None:
         """Gather cloud assets."""
         for cloud_asset_type, cloud_asset_scanner in self.cloud_asset_scanners.items():
             self.current_service = cloud_asset_type
@@ -106,7 +109,7 @@
                 self.logger.debug(f"Skipping {cloud_asset_type}")
                 continue
             self.logger.debug(f"Scanning {cloud_asset_type}")
-            cloud_asset_scanner()
+            cloud_asset_scanner(**kwargs)
         self.current_service = None
 
     def get_event_context(
@@ -127,6 +130,7 @@
             "event_type": event_type,
             "connector": self,
             "provider": self.provider,
+            # service=None, this uses the self.current_service for the value
             "service": service or self.current_service,
         }
 
@@ -237,11 +241,11 @@ def submit_cloud_assets_wrapper(self):  # pragma: no cover
             self.submit_cloud_assets()
             self.clear()
 
-    def scan_seeds(self):
+    def scan_seeds(self, **kwargs):
         """Scan the seeds."""
         self.logger.info("Gathering seeds...")
         self.dispatch_event(EventTypeEnum.SCAN_STARTED)
-        self.get_seeds()
+        self.get_seeds(**kwargs)
         self.submit_seeds_wrapper()
         self.dispatch_event(EventTypeEnum.SCAN_FINISHED)
 
@@ -253,12 +257,13 @@ def scan_cloud_assets(self):
         self.submit_cloud_assets_wrapper()
         self.dispatch_event(EventTypeEnum.SCAN_FINISHED)
 
-    def scan(self):
+    # TODO: how to pass in cred,region? (each scanner will have diff things to pass in)
+    def scan(self, **kwargs):
         """Scan the seeds and cloud assets."""
         self.logger.info("Gathering seeds and cloud assets...")
         self.dispatch_event(EventTypeEnum.SCAN_STARTED)
-        self.get_seeds()
-        self.get_cloud_assets()
+        self.get_seeds(**kwargs)
+        self.get_cloud_assets(**kwargs)
         self.submit()
         self.dispatch_event(EventTypeEnum.SCAN_FINISHED)
 
diff --git a/src/censys/cloud_connectors/plugins/aws_tags.py b/src/censys/cloud_connectors/plugins/aws_tags.py
index 9cbb112..afdca62 100644
--- a/src/censys/cloud_connectors/plugins/aws_tags.py
+++ b/src/censys/cloud_connectors/plugins/aws_tags.py
@@ -134,6 +134,7 @@ def on_add_cloud_asset(
             AwsResourceTypes.STORAGE_BUCKET: self._get_storage_bucket_tags,
         }
         service: Optional[AwsResourceTypes] = context.get("service")  # type: ignore
+        # service = none, context.get('service') value isn't set
         if service in tag_retrieval_handlers:
             try:
                 tag_retrieval_handlers[service](context, cloud_asset, **kwargs)
@@ -434,10 +435,12 @@ def _get_storage_bucket_tags(
         bucket_name = kwargs.get("bucket_name")
         client: S3Client = kwargs.get("aws_client")  # type: ignore
         if not bucket_name or not client:
+            print(f"tags bucket: {bucket_name} len:0")
             return
 
         try:
             tag_set = client.get_bucket_tagging(Bucket=bucket_name).get("TagSet", [])
+            print(f"tags bucket: {bucket_name} len:{len(tag_set)}")
             if not tag_set:
                 return
 
@@ -450,6 +453,7 @@
             self.add_cloud_asset_tags(context, cloud_asset, filtered_tags)  # type: ignore
 
         except ClientError as e:
+            print(f"tags bucket: {bucket_name} len:0")
            # If there are no tag sets, it will raise a ClientError with the code "NoSuchTagSet"
             if e.response.get("Error", {}).get("Code") == "NoSuchTagSet":
                 return

From 86912ceb04626187a192b87c6989b822073a21b6 Mon Sep 17 00:00:00 2001
From: Eric Butera
Date: Thu, 10 Aug 2023 14:52:46 -0400
Subject: [PATCH 03/13] WIP

- aws_endpoint_url setting
- add max proc setting
- test scan_all pool fix

---
 .../cloud_connectors/aws_connector/connector.py | 16 ++++++++++------
 src/censys/cloud_connectors/common/settings.py  | 10 ++++++++++
 tests/data/default_settings.json                |  3 ++-
 tests/test_aws_connector.py                     | 10 ++++++++++
 4 files changed, 32 
insertions(+), 7 deletions(-) diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index be7e06d..c05ca0c 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -44,7 +44,6 @@ VALID_RECORD_TYPES = ["A", "CNAME"] IGNORED_TAGS = ["censys-cloud-connector-ignore"] -MAX_PROC = 4 # TODO: .env settings class AwsCloudConnector(CloudConnector): @@ -71,6 +70,7 @@ class AwsCloudConnector(CloudConnector): # Current set of ignored tags (combined set of user settings + overall settings) ignored_tags: list[str] global_ignored_tags: set[str] + pool: Pool def __init__(self, settings: Settings): """Initialize AWS Cloud Connectors. @@ -93,6 +93,7 @@ def __init__(self, settings: Settings): self.ignored_tags = [] self.global_ignored_tags: set[str] = set(IGNORED_TAGS) + self.pool = Pool(processes=settings.scan_concurrency) def scan_seeds(self, **kwargs): """Scan AWS for seeds.""" @@ -114,7 +115,9 @@ def scan_all(self): tuple, AwsSpecificSettings ] = self.settings.providers.get(self.provider, {}) - pool = Pool(processes=MAX_PROC) + self.logger.debug( + f"scanning AWS using {self.settings.scan_concurrency} processes" + ) for provider_setting in provider_settings.values(): self.provider_settings = provider_setting @@ -140,7 +143,7 @@ def scan_all(self): ], # self.account_number, }, ): - self.logger.info( + self.logger.debug( "starting pool account:%s region:%s", credential["account_number"], region, @@ -251,9 +254,10 @@ def get_aws_client( """ try: credentials = credentials or self.credentials() - credentials[ - "endpoint_url" - ] = "http://localhost.localstack.cloud:4566" # TODO: env settings + + if self.settings.aws_endpoint_url: + credentials["endpoint_url"] = self.settings.aws_endpoint_url + if credentials.get("aws_access_key_id"): self.logger.debug(f"AWS Service {service} using access key credentials") return boto3.client(service, **credentials) # type: ignore diff --git a/src/censys/cloud_connectors/common/settings.py b/src/censys/cloud_connectors/common/settings.py index 2544534..94f3a69 100644 --- a/src/censys/cloud_connectors/common/settings.py +++ b/src/censys/cloud_connectors/common/settings.py @@ -189,6 +189,16 @@ class Settings(BaseSettings): env="AZURE_REFRESH_ALL_REGIONS", description="Scan all available Azure regions", ) + aws_endpoint_url: str = Field( + default="", + env="AWS_ENDPOINT_URL", + description="AWS endpoint url override (for testing)", + ) + scan_concurrency: int = Field( + default=1, + env="SCAN_CONCURRENCY", + description="Maximum number of concurrent scans", + ) # Verification timeout validation_timeout: int = Field( diff --git a/tests/data/default_settings.json b/tests/data/default_settings.json index 08012f7..1ac6d33 100644 --- a/tests/data/default_settings.json +++ b/tests/data/default_settings.json @@ -4,5 +4,6 @@ "censys_cookies": { "session": "test-censys-session-xxxxxxxxxxxxxxxx" }, - "validation_timeout": 5 + "validation_timeout": 5, + "aws_endpoint_url": "" } diff --git a/tests/test_aws_connector.py b/tests/test_aws_connector.py index 25b060e..576efb3 100644 --- a/tests/test_aws_connector.py +++ b/tests/test_aws_connector.py @@ -121,6 +121,16 @@ def test_get_aws_client(self): aws_secret_access_key=self.connector.provider_settings.secret_key, ) + def test_endpoint_url_override(self): + endpoint_url = "test-endpoint-url" + expected = {"aws_access_key_id": "test", "endpoint_url": endpoint_url} + self.settings.aws_endpoint_url = 
endpoint_url
+        mock_client = self.mocker.patch("boto3.client", autospec=True)
+        self.connector.get_aws_client(
+            AwsServices.API_GATEWAY, {"aws_access_key_id": "test"}
+        )
+        mock_client.assert_called_with(AwsServices.API_GATEWAY, **expected)
+
     def test_get_aws_client_uses_override_credentials(self):
         service = AwsServices.API_GATEWAY
         expected = self.data["TEST_BOTO_CRED_FULL"]

From 74d2bfdee13a10b21a0b2354f693073f5654cbd9 Mon Sep 17 00:00:00 2001
From: Eric Butera
Date: Fri, 11 Aug 2023 18:39:13 -0400
Subject: [PATCH 04/13] WIP

- pool can't be pickled, need to figure out self.pool issue
- introduced AwsScanContext to manage state of outer loop in an easier way
  inside workers
- so many TODOs for refactoring class state into context

---
 .../aws_connector/connector.py | 254 +++++++++++++-----
 1 file changed, 190 insertions(+), 64 deletions(-)

diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py
index c05ca0c..00203fc 100644
--- a/src/censys/cloud_connectors/aws_connector/connector.py
+++ b/src/censys/cloud_connectors/aws_connector/connector.py
@@ -1,6 +1,7 @@
 """AWS Cloud Connector."""
 import contextlib
 from collections.abc import Generator, Sequence
+from dataclasses import dataclass
 from multiprocessing import Pool
 from typing import Any, Optional, TypeVar, Union
@@ -45,6 +46,26 @@
 VALID_RECORD_TYPES = ["A", "CNAME"]
 IGNORED_TAGS = ["censys-cloud-connector-ignore"]
 
+# TODO: fix self.{property} references:
+# This has to happen because if the worker pool spawns multiple account + regions,
+# each worker will change the self.{property} value, thus making each process scan
+# the SAME account.
+#
+# instead of changing everything everywhere, perhaps a data structure can handle this?
+# make a dictionary of provider-setting-key (which is account + region)
+# then inside scan use self.scan_contexts[provider-setting-key] = {...}
+
+
+@dataclass
+class AwsScanContext:
+    """Required configuration context for scan()."""
+
+    provider_settings: AwsSpecificSettings
+    temp_sts_cred: Optional[dict]
+    botocred: dict
+    credential: dict
+    account_number: str
+    region: str
+    ignored_tags: list[str]
+
 
 class AwsCloudConnector(CloudConnector):
     """AWS Cloud Connector.
@@ -55,10 +76,26 @@ class AwsCloudConnector(CloudConnector):
     """
 
     provider = ProviderEnum.AWS
+
+    # During a run this will be set to the current account being scanned
+    # It is common to have multiple top level accounts in providers.yml
     provider_settings: AwsSpecificSettings
 
+    # workaround for storing multiple configurations during a scan() call
+    # multiprocessing dictates that each worker runs scan in a different process
+    # each process will share the same AwsCloudConnector instance
+    # if a worker sets a self property, that is updated for _all_ workers
+    # therefore, make a dict that each worker can reference its unique
+    # account+region configuration
+    #
+    # each scan_contexts entry will have a unique key so that multiple accounts
+    # and regions can be scanned in parallel
+    # scan_config_entry = {
+    #   "temp_sts_cred": {}, "account_number": "", "region": "", "ignored_tags":[], credential: {}
+    # }
+    scan_contexts: dict[str, AwsScanContext] = {}
+
     # Temporary STS credentials created with Assume Role will be stored here during
     # a connector scan.
+    # TODO: fix self.temp_sts_cred
     temp_sts_cred: Optional[dict] = None
 
     # When scanning, the current loaded credential will be set here.
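
The comment block above captures the central constraint driving this refactor. A minimal, runnable sketch of the pattern it describes — illustrative names only, not part of the patch (`TaskContext` stands in for `AwsScanContext`), assuming fork-based multiprocessing:

```python
# Sketch: pass one immutable context per task instead of mutating shared
# self.* attributes. Note that submitting a bound method also pickles its
# instance, so an instance that owns a Pool cannot be submitted ("pool
# objects cannot be pickled") -- keeping the Pool local and the worker a
# plain function avoids both problems.
from dataclasses import dataclass
from multiprocessing import Pool


@dataclass(frozen=True)
class TaskContext:  # stand-in for AwsScanContext
    account_number: str
    region: str


def scan(ctx: TaskContext) -> str:
    # each worker reads only its own context; no shared mutable state
    return f"scanned {ctx.account_number}/{ctx.region}"


if __name__ == "__main__":
    tasks = [TaskContext("123456789012", r) for r in ("us-east-1", "us-west-2")]
    with Pool(processes=2) as pool:
        print(pool.map(scan, tasks))
```
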
@@ -70,7 +107,7 @@ class AwsCloudConnector(CloudConnector): # Current set of ignored tags (combined set of user settings + overall settings) ignored_tags: list[str] global_ignored_tags: set[str] - pool: Pool + # pool: Pool def __init__(self, settings: Settings): """Initialize AWS Cloud Connectors. @@ -91,22 +128,35 @@ def __init__(self, settings: Settings): AwsResourceTypes.STORAGE_BUCKET: self.get_s3_instances, } + # TODO: fix self.ignored_tags self.ignored_tags = [] self.global_ignored_tags: set[str] = set(IGNORED_TAGS) - self.pool = Pool(processes=settings.scan_concurrency) + # self.pool = Pool(processes=settings.scan_concurrency) def scan_seeds(self, **kwargs): - """Scan AWS for seeds.""" - credential = kwargs.get("credential") - region = kwargs.get("region") + """Scan AWS.""" + # credential = kwargs.get("credential") + # x region = kwargs.get("region") + # self.logger.info( + # f"Scanning AWS account {credential['account_number']} in region {region}" + # ) + + scan_context_key = kwargs["scan_context_key"] + scan_context = kwargs["scan_context"] + # this is here because setting it outside of scan was causing a race condition where it didn't exist when accessed + # must be something to do with pool & async add + self.scan_contexts[scan_context_key] = scan_context self.logger.info( - f"Scanning AWS account {credential['account_number']} in region {region}" + f"Scanning AWS account {scan_context['account_number']} in region {scan_context['region']}" ) - return super().scan_seeds(**kwargs) + super().scan_seeds(**kwargs) def scan_cloud_assets(self, **kwargs): """Scan AWS for cloud assets.""" - self.logger.info(f"Scanning AWS account {self.account_number}") + scan_context_key = kwargs["scan_context_key"] + scan_context = kwargs["scan_context"] + self.scan_contexts[scan_context_key] = scan_context + self.logger.info(f"Scanning AWS account {scan_context['account_number']}") super().scan_cloud_assets(**kwargs) def scan_all(self): @@ -119,18 +169,29 @@ def scan_all(self): f"scanning AWS using {self.settings.scan_concurrency} processes" ) + pool = Pool(processes=self.settings.scan_concurrency) + for provider_setting in provider_settings.values(): + # `provider_setting` is a specific top-level AwsAccount entry in providers.yml + # TODO: provider_settings should really be passed into scan :/ self.provider_settings = provider_setting + self.scan_contexts = {} for credential in self.provider_settings.get_credentials(): + # TODO: fix self.credential # self.credential = credential + # TODO: fix self.account_number # self.account_number = credential["account_number"] - # TODO: this wont work if using a pool (no self!) - self.ignored_tags = self.get_ignored_tags(credential["ignore_tags"]) + # TODO: fix self.ignored_tags + ignored_tags = self.get_ignored_tags(credential["ignore_tags"]) + # TODO: this wont work if using a pool (no self!) + self.ignored_tags = ignored_tags # for each account + region combination, run each seed scanner for region in self.provider_settings.regions: + # TODO: fix self.temp_sts_cred self.temp_sts_cred = None + # TODO: fix self.region self.region = region try: with Healthcheck( @@ -148,15 +209,42 @@ def scan_all(self): credential["account_number"], region, ) - # currently doesn't work because `self.` references clobber each other - # example: region gets unset after first iteration and program crashes + + # TODO: this might not work (how does timeout/renewal of creds work?) 
+ # i really don't like this, put it in the scan_contexts[provider-setting-key] = {...} + botocred = self.boto_cred( + region_name=region, + access_key=credential["access_key"], + secret_key=credential["secret_key"], + # TODO: what is session_token again? + # session_token=provider_setting.session_token, + ) + + scan_context_key = credential["account_number"] + region + scan_context = AwsScanContext( + provider_settings=provider_setting, + temp_sts_cred=None, + credential=credential, + botocred=botocred, + account_number=credential["account_number"], + region=region, + ignored_tags=ignored_tags, + ) + + # self.pool.apply_async( pool.apply_async( self.scan_seeds, kwds={ - "credential": credential, - "region": region, + # TODO remove all of this except `scan_context_key` + # "provider_setting": provider_setting, + # "botocred": botocred, + # "credential": credential, + # "region": region, + "scan_context_key": scan_context_key, + "scan_context": scan_context, }, ) + # self.logger.info(f"asyn res {x}") # self.scan(**kwargs) except Exception as e: self.logger.error( @@ -209,6 +297,7 @@ def format_label( """ # TODO: s/self.account_number/account_number <- use param account = account_number or self.account_number + # TODO: fix self.region region = region or self.region region_label = f"/{region}" if region != "" else "" return f"AWS: {service} - {account}{region_label}" @@ -225,12 +314,14 @@ def credentials(self) -> dict: # Role name is the credential field which causes STS to activate. # Once activated the temporary STS creds will be used by all # subsequent AWS service client calls. + # TODO: fix self.credential if role_name := self.credential.get("role_name"): self.logger.debug(f"Using STS for role {role_name}") return self.get_assume_role_credentials(role_name) self.logger.debug("Using provider settings credentials") return self.boto_cred( + # TODO: fix self.region self.region, self.provider_settings.access_key, self.provider_settings.secret_key, @@ -286,12 +377,15 @@ def get_assume_role_credentials(self, role_name: Optional[str] = None) -> dict: Raises: Exception: If the credentials could not be created. """ + # TODO: fix self.temp_sts_cred if self.temp_sts_cred: self.logger.debug("Using cached temporary STS credentials") else: try: temp_creds = self.assume_role(role_name) + # TODO: fix self.temp_sts_cred self.temp_sts_cred = self.boto_cred( + # TODO: fix self.region self.region, temp_creds["AccessKeyId"], temp_creds["SecretAccessKey"], @@ -304,6 +398,7 @@ def get_assume_role_credentials(self, role_name: Optional[str] = None) -> dict: self.logger.error(f"Failed to assume role: {e}") raise + # TODO: fix self.temp_sts_cred return self.temp_sts_cred def boto_cred( @@ -360,8 +455,10 @@ def assume_role( Returns: CredentialsTypeDef: Temporary credentials. 
""" + # TODO: verify this works with worker pool change- Always use the primary account credentials to query STS # use primary account's credentials to query STS for temp creds credentials = self.boto_cred( + # TODO: fix self.region self.region, self.provider_settings.access_key, self.provider_settings.secret_key, @@ -373,7 +470,9 @@ def assume_role( ) role_session = ( - self.credential["role_session_name"] or AwsDefaults.ROLE_SESSION_NAME.value + # FIX self.credential + self.credential["role_session_name"] + or AwsDefaults.ROLE_SESSION_NAME.value ) # TODO: s/self.account_number/cred[account_number] <- pass in account number role: dict[str, Any] = { @@ -388,13 +487,19 @@ def assume_role( ) return temp_creds["Credentials"] - def get_api_gateway_domains_v1(self): + def get_api_gateway_domains_v1(self, **kwargs): """Retrieve all API Gateway V1 domains and emit seeds.""" - client: APIGatewayClient = self.get_aws_client(service=AwsServices.API_GATEWAY) + key = kwargs["scan_context_key"] + ctx: AwsScanContext = self.scan_contexts[key] + + client: APIGatewayClient = self.get_aws_client( + service=AwsServices.API_GATEWAY, credentials=ctx.botocred + ) label = self.format_label(SeedLabel.API_GATEWAY) try: apis = client.get_rest_apis() for domain in apis.get("items", []): + # TODO: fix self.region domain_name = f"{domain['id']}.execute-api.{self.region}.amazonaws.com" # TODO: emit log when a seeds is dropped due to validation error with SuppressValidationError(): @@ -403,10 +508,13 @@ def get_api_gateway_domains_v1(self): except ClientError as e: self.logger.error(f"Could not connect to API Gateway V1. Error: {e}") - def get_api_gateway_domains_v2(self): + def get_api_gateway_domains_v2(self, **kwargs): """Retrieve API Gateway V2 domains and emit seeds.""" + key = kwargs["scan_context_key"] + ctx: AwsScanContext = self.scan_contexts[key] + client: ApiGatewayV2Client = self.get_aws_client( - service=AwsServices.API_GATEWAY_V2 + service=AwsServices.API_GATEWAY_V2, credentials=ctx.botocred ) label = self.format_label(SeedLabel.API_GATEWAY) try: @@ -419,18 +527,22 @@ def get_api_gateway_domains_v2(self): except ClientError as e: self.logger.error(f"Could not connect to API Gateway V2. Error: {e}") - def get_api_gateway_domains(self): + def get_api_gateway_domains(self, **kwargs): """Retrieve all versions of Api Gateway data and emit seeds.""" - self.get_api_gateway_domains_v1() - self.get_api_gateway_domains_v2() + self.get_api_gateway_domains_v1(**kwargs) + self.get_api_gateway_domains_v2(**kwargs) label = self.format_label(SeedLabel.API_GATEWAY) if not self.seeds.get(label): self.delete_seeds_by_label(label) - def get_load_balancers_v1(self): + def get_load_balancers_v1(self, **kwargs): """Retrieve Elastic Load Balancers (ELB) V1 data and emit seeds.""" + key = kwargs["scan_context_key"] + ctx: AwsScanContext = self.scan_contexts[key] + client: ElasticLoadBalancingClient = self.get_aws_client( - service=AwsServices.LOAD_BALANCER + service=AwsServices.LOAD_BALANCER, + credentials=ctx.botocred, ) label = self.format_label(SeedLabel.LOAD_BALANCER) try: @@ -443,10 +555,13 @@ def get_load_balancers_v1(self): except ClientError as e: self.logger.error(f"Could not connect to ELB V1. 
Error: {e}") - def get_load_balancers_v2(self): + def get_load_balancers_v2(self, **kwargs): """Retrieve Elastic Load Balancers (ELB) V2 data and emit seeds.""" + key = kwargs["scan_context_key"] + ctx: AwsScanContext = self.scan_contexts[key] + client: ElasticLoadBalancingv2Client = self.get_aws_client( - service=AwsServices.LOAD_BALANCER_V2 + service=AwsServices.LOAD_BALANCER_V2, credentials=ctx.botocred ) label = self.format_label(SeedLabel.LOAD_BALANCER) try: @@ -459,24 +574,33 @@ def get_load_balancers_v2(self): except ClientError as e: self.logger.error(f"Could not connect to ELB V2. Error: {e}") - def get_load_balancers(self): + def get_load_balancers(self, **kwargs): """Retrieve Elastic Load Balancers (ELB) data and emit seeds.""" - self.get_load_balancers_v1() - self.get_load_balancers_v2() + self.get_load_balancers_v1(**kwargs) + self.get_load_balancers_v2(**kwargs) label = self.format_label(SeedLabel.LOAD_BALANCER) if not self.seeds.get(label): self.delete_seeds_by_label(label) - def get_network_interfaces(self): + def get_network_interfaces(self, **kwargs): """Retrieve EC2 Elastic Network Interfaces (ENI) data and emit seeds.""" + key = kwargs["scan_context_key"] + ctx: AwsScanContext = self.scan_contexts[key] + try: - interfaces = self.describe_network_interfaces() + interfaces = self.describe_network_interfaces(ctx.botocred) except ClientError as e: self.logger.error(f"Could not connect to ENI Service. Error: {e}") return label = self.format_label(SeedLabel.NETWORK_INTERFACE) has_added_seeds = False - instance_tags, instance_tag_sets = self.get_resource_tags() + + interfaces = self.describe_network_interfaces() + # this looks like a bug not passing in a resource type + ( + instance_tags, + instance_tag_sets, + ) = self.get_resource_tags(ctx.botocred) for ip_address, record in interfaces.items(): instance_id = record["InstanceId"] @@ -494,7 +618,7 @@ def get_network_interfaces(self): if not has_added_seeds: self.delete_seeds_by_label(label) - def describe_network_interfaces(self) -> dict: + def describe_network_interfaces(self, botocred: dict) -> dict: """Retrieve EC2 Elastic Network Interfaces (ENI) data. Raises: @@ -503,7 +627,8 @@ def describe_network_interfaces(self) -> dict: Returns: dict: Network Interfaces. """ - ec2: EC2Client = self.get_aws_client(AwsServices.EC2) + # TODO pass in scan_contexts + ec2: EC2Client = self.get_aws_client(AwsServices.EC2, credentials=botocred) interfaces: dict[str, dict[str, Union[None, str, list]]] = {} # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.Client.describe_network_interfaces @@ -536,7 +661,7 @@ def describe_network_interfaces(self) -> dict: return interfaces def get_resource_tags_paginated( - self, resource_types: Optional[list[str]] = None + self, botocred: dict, resource_types: Optional[list[str]] = None ) -> Generator[TagDescriptionTypeDef, None, None]: """Retrieve EC2 resource tags paginated. @@ -546,7 +671,8 @@ def get_resource_tags_paginated( Yields: Generator[TagDescriptionTypeDef]: Tags. """ - ec2: EC2Client = self.get_aws_client(AwsServices.EC2) + # TODO pass in ctx + ec2: EC2Client = self.get_aws_client(AwsServices.EC2, credentials=botocred) paginator = ec2.get_paginator( "describe_tags", ) @@ -560,7 +686,7 @@ def get_resource_tags_paginated( yield from tags def get_resource_tags( - self, resource_types: Optional[list[str]] = None + self, botocred: dict, resource_types: Optional[list[str]] = None ) -> tuple[dict, dict]: """Get EC2 resource tags based on resource types. 
@@ -573,7 +699,7 @@
         resource_tags: dict = {}
         resource_tag_sets: dict = {}
 
-        for tag in self.get_resource_tags_paginated(resource_types):
+        for tag in self.get_resource_tags_paginated(botocred, resource_types):
             # Tags come in two formats:
             # 1. Tag = { Key = "Name", Value = "actual-tag-name" }
             # 2. Tag = { Key = "actual-key-name", Value = "tag-value-that-is-unused-here"}
@@ -605,11 +731,16 @@
         tags = self.extract_tags_from_tagset(tag_set)
         return self.has_ignored_tag(tags)
 
-    def get_rds_instances(self):
+    def get_rds_instances(self, **kwargs):
         """Retrieve Relational Database Services (RDS) data and emit seeds."""
-        client: RDSClient = self.get_aws_client(service=AwsServices.RDS)
+        key = kwargs["scan_context_key"]
+        ctx: AwsScanContext = self.scan_contexts[key]
+        client: RDSClient = self.get_aws_client(
+            service=AwsServices.RDS, credentials=ctx.botocred
+        )
         label = self.format_label(SeedLabel.RDS)
         has_added_seeds = False
+
         try:
             data = client.describe_db_instances()
             for instance in data.get("DBInstances", []):
@@ -661,21 +792,16 @@ def _get_route53_zone_resources(
             .build_full_result()
         )
 
     def get_route53_zones(self, **kwargs):
         """Retrieve Route 53 Zones and emit seeds."""
         # TODO: how to pass in cred,region?
-        credential = kwargs["credential"]
-        region = kwargs["region"]
+        key = kwargs["scan_context_key"]
+        ctx: AwsScanContext = self.scan_contexts[key]
 
-        botocred = self.boto_cred(
-            region_name=region,
-            access_key=credential["access_key"],
-            secret_key=credential["secret_key"],
-        )
         client: Route53Client = self.get_aws_client(
-            service=AwsServices.ROUTE53_ZONES, credentials=botocred
+            service=AwsServices.ROUTE53_ZONES, credentials=ctx.botocred
         )
         label = self.format_label(
             SeedLabel.ROUTE53_ZONES,
-            region=region,
-            account_number=credential["account_number"],
+            region=ctx.region,
+            account_number=ctx.credential["account_number"],
         )
 
         has_added_seeds = False
@@ -837,12 +837,16 @@
         except ClientError as e:
             self.logger.error(f"Could not connect to Route 53 Zones. Error: {e}")
 
-    def get_ecs_instances(self):
+    def get_ecs_instances(self, **kwargs):
         """Retrieve Elastic Container Service data and emit seeds."""
-        ecs: ECSClient = self.get_aws_client(AwsServices.ECS)
-        ec2: EC2Client = self.get_aws_client(AwsServices.EC2)
+        key = kwargs["scan_context_key"]
+        ctx: AwsScanContext = self.scan_contexts[key]
+
+        ecs: ECSClient = self.get_aws_client(AwsServices.ECS, credentials=ctx.botocred)
+        ec2: EC2Client = self.get_aws_client(AwsServices.EC2, credentials=ctx.botocred)
         label = self.format_label(SeedLabel.ECS)
         has_added_seeds = False
+
         try:
             clusters = ecs.list_clusters()
             for cluster in clusters.get("clusterArns", []):
@@ -768,16 +898,11 @@
     def get_s3_region(self, client: S3Client, bucket: str) -> str:
         location = client.get_bucket_location(Bucket=bucket)["LocationConstraint"]
         return location or "us-east-1"
 
     def get_s3_instances(self, **kwargs):
         """Retrieve Simple Storage Service data and emit seeds."""
         # TODO: how to pass in cred,region?
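
For context on the `get_s3_region` helper shown above: S3's documented GetBucketLocation API reports a null `LocationConstraint` for buckets created in us-east-1, which is why the code falls back to that region. A minimal sketch (the function name is illustrative, not from the patch):

```python
# Sketch: look up a bucket's region, mirroring the connector's fallback --
# a None/empty LocationConstraint means the bucket lives in us-east-1.
import boto3


def bucket_region(bucket_name: str) -> str:
    s3 = boto3.client("s3")
    location = s3.get_bucket_location(Bucket=bucket_name)["LocationConstraint"]
    return location or "us-east-1"
```
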
- credential = kwargs["credential"] - region = kwargs["region"] + key = kwargs["scan_context_key"] + ctx: AwsScanContext = self.scan_contexts[key] - botocred = self.boto_cred( - region_name=region, - access_key=credential["access_key"], - secret_key=credential["secret_key"], - ) client: S3Client = self.get_aws_client( - service=AwsServices.STORAGE_BUCKET, credentials=botocred + service=AwsServices.STORAGE_BUCKET, credentials=ctx.botocred ) try: @@ -791,8 +916,8 @@ def get_s3_instances(self, **kwargs): lookup_region = self.get_s3_region(client, bucket_name) label = self.format_label( SeedLabel.STORAGE_BUCKET, - region=region, - account_number=credential["account_number"], + region=ctx.region, + account_number=ctx.account_number, ) with SuppressValidationError(): @@ -800,7 +925,7 @@ def get_s3_instances(self, **kwargs): value=AwsStorageBucketAsset.url(bucket_name, lookup_region), uid=label, scan_data={ - "accountNumber": credential["account_number"], + "accountNumber": ctx.account_number, }, ) self.add_cloud_asset( @@ -835,6 +960,7 @@ def has_ignored_tag(self, tags: list[str]) -> bool: Returns: bool: If the list contains an ignored tag. """ + # TODO: fix self.ignored_tags return any(tag in self.ignored_tags for tag in tags) def extract_tags_from_tagset(self, tag_set: list[TagTypeDef]) -> list[str]: From dbfa3c268ba9cccf9ca3972616c23dc9121b900a Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Mon, 14 Aug 2023 09:01:46 -0400 Subject: [PATCH 05/13] WIP - scan context everywhere --- .../aws_connector/connector.py | 380 +++++++++--------- .../cloud_connectors/common/connector.py | 4 +- .../cloud_connectors/common/healthcheck.py | 1 + 3 files changed, 190 insertions(+), 195 deletions(-) diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index 00203fc..f8f21e6 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -2,6 +2,7 @@ import contextlib from collections.abc import Generator, Sequence from dataclasses import dataclass +from logging import Logger from multiprocessing import Pool from typing import Any, Optional, TypeVar, Union @@ -38,6 +39,7 @@ from censys.cloud_connectors.common.context import SuppressValidationError from censys.cloud_connectors.common.enums import EventTypeEnum, ProviderEnum from censys.cloud_connectors.common.healthcheck import Healthcheck +from censys.cloud_connectors.common.logger import get_logger from censys.cloud_connectors.common.seed import DomainSeed, IpSeed from censys.cloud_connectors.common.settings import Settings @@ -53,18 +55,23 @@ # make a dictionary of provider-setting-key (which is account + region) # then inside scan use self.scan_contexts[provider-setting-key] = {...} +# TODO: logging changes: +# add account + region, current resource-type + @dataclass class AwsScanContext: """Required configuration context for scan().""" provider_settings: AwsSpecificSettings - temp_sts_cred: Optional[dict] - botocred: dict + # temp_sts_cred: Optional[dict] credential: dict account_number: str region: str + + # ignored tags on for the current credential (aws account) ignored_tags: list[str] + logger: Logger class AwsCloudConnector(CloudConnector): @@ -77,10 +84,6 @@ class AwsCloudConnector(CloudConnector): provider = ProviderEnum.AWS - # During a run this will be set to the current account being scanned - # It is common to have multiple top level accounts in providers.yml - provider_settings: AwsSpecificSettings - # 
workaround for storing multiple configurations during a scan() call
     # multiprocessing dictates that each worker runs scan in a different process
     # each process will share the same AwsCloudConnector instance
     # if a worker sets a self property, that is updated for _all_ workers
     # therefore, make a dict that each worker can reference its unique account+region configuration
     #
     # each scan_contexts entry will have a unique key so that multiple accounts and regions can be scanned in parallel
     # scan_config_entry = {
     #   "temp_sts_cred": {}, "account_number": "", "region": "", "ignored_tags":[], credential: {}
     # }
     scan_contexts: dict[str, AwsScanContext] = {}
 
     # Temporary STS credentials created with Assume Role will be stored here during
     # a connector scan.
-    # TODO: fix self.temp_sts_cred
-    temp_sts_cred: Optional[dict] = None
-
-    # When scanning, the current loaded credential will be set here.
-    credential: dict = {}
-
-    account_number: str
-    region: Optional[str]
-
-    # Current set of ignored tags (combined set of user settings + overall settings)
-    ignored_tags: list[str]
+    # temp_sts_cred: Optional[dict] = None
+    # During a run this will be set to the current account being scanned
+    # It is common to have multiple top level accounts in providers.yml
+    # provider_settings: AwsSpecificSettings
+    # credential: dict = {}
+    # account_number: str
+    # region: Optional[str]
+    # ignored_tags: list[str]
     # pool: Pool
 
-    global_ignored_tags: set[str]
+    global_ignored_tags: set[
+        str
+    ]  # this can remain self. as it is global across all accounts
 
     def __init__(self, settings: Settings):
         """Initialize AWS Cloud Connectors.
@@ -130,40 +130,40 @@
         self.cloud_asset_scanners = {
             AwsResourceTypes.STORAGE_BUCKET: self.get_s3_instances,
         }
-
-        # TODO: fix self.ignored_tags
-        self.ignored_tags = []
         self.global_ignored_tags: set[str] = set(IGNORED_TAGS)
+        # self.ignored_tags = []
         # self.pool = Pool(processes=settings.scan_concurrency)
 
     def scan_seeds(self, **kwargs):
-        """Scan AWS for seeds."""
-        # credential = kwargs.get("credential")
-        # x region = kwargs.get("region")
-        # self.logger.info(
-        #     f"Scanning AWS account {credential['account_number']} in region {region}"
-        # )
+        """Scan AWS."""
+        # when scan() is called, it has been forked into a separate process (from scan_all)
 
         scan_context_key = kwargs["scan_context_key"]
-        scan_context = kwargs["scan_context"]
-        # this is here because setting it outside of scan was causing a race condition where it didn't exist when accessed
-        # must be something to do with pool & async add
-        self.scan_contexts[scan_context_key] = scan_context
+        scan_context: AwsScanContext = kwargs["scan_context"]
+        # scan_context must be set within scan(), otherwise race conditions happen where it doesn't exist when accessed
+
+        # multiprocessing requires separate logger instances per process
+        # TODO: there is still something odd going on here, notice logger isn't being set to self.logger, but by even calling get_logger it "fixes" the sub-process log level from being WARNING to .env's DEBUG
+        logger = get_logger(
+            log_name=f"{self.provider.lower()}_cloud_connector",
+            level=self.settings.logging_level,
+        )
+        scan_context.logger = logger
+
         self.logger.info(
-            f"Scanning AWS account {scan_context['account_number']} in region {scan_context['region']}"
+            f"Scanning AWS - account:{scan_context.account_number} region:{scan_context.region}"
         )
+
+        self.scan_contexts[scan_context_key] = scan_context
         super().scan_seeds(**kwargs)
 
     def scan_cloud_assets(self, **kwargs):
         """Scan AWS for cloud assets."""
+        # TODO: pull scan_seeds changes in after rebase
         scan_context_key = kwargs["scan_context_key"]
         scan_context = kwargs["scan_context"]
         self.scan_contexts[scan_context_key] = scan_context
-        self.logger.info(f"Scanning AWS account {scan_context['account_number']}")
+        self.logger.info(f"Scanning AWS account {scan_context.account_number}")
         super().scan_cloud_assets(**kwargs)
 
     def 
scan_all(self): @@ -172,97 +179,90 @@ def scan_all(self): pool = Pool(processes=self.settings.scan_concurrency) for provider_setting in provider_settings.values(): - # `provider_setting` is a specific top-level AwsAccount entry in providers.yml - # TODO: provider_settings should really be passed into scan :/ + # `provider_setting` represents a specific top-level AwsAccount entry in providers.yml + # + # DO NOT use provider_settings anywhere in this class! + # provider_settings exists for the parent CloudConnector self.provider_settings = provider_setting self.scan_contexts = {} - for credential in self.provider_settings.get_credentials(): - # TODO: fix self.credential + # for credential in self.provider_settings.get_credentials(): + for credential in provider_settings.get_credentials(): # self.credential = credential - # TODO: fix self.account_number # self.account_number = credential["account_number"] - - # TODO: fix self.ignored_tags + account_number = credential["account_number"] ignored_tags = self.get_ignored_tags(credential["ignore_tags"]) - # TODO: this wont work if using a pool (no self!) - self.ignored_tags = ignored_tags + # self.ignored_tags = ignored_tags + # for each account + region combination, run each seed scanner - for region in self.provider_settings.regions: - # TODO: fix self.temp_sts_cred - self.temp_sts_cred = None - # TODO: fix self.region - self.region = region + for region in provider_settings.regions: + # self.temp_sts_cred = None + # self.region = region try: with Healthcheck( self.settings, - provider_setting, + provider_settings, provider={ "region": region, - "account_number": credential[ - "account_number" - ], # self.account_number, + "account_number": account_number, + # self.account_number, }, ): self.logger.debug( "starting pool account:%s region:%s", - credential["account_number"], + account_number, region, ) - # TODO: this might not work (how does timeout/renewal of creds work?) - # i really don't like this, put it in the scan_contexts[provider-setting-key] = {...} - botocred = self.boto_cred( - region_name=region, - access_key=credential["access_key"], - secret_key=credential["secret_key"], - # TODO: what is session_token again? - # session_token=provider_setting.session_token, - ) - - scan_context_key = credential["account_number"] + region + # Credentials aren't exactly obvious how they work + # Assume role flow: + # - use the "primary account" access + secret key to "connect" + # - call STS assume role to get "temporary credentials" + # - temporary credentials can be used for all resource types in an account + region + # - note: creds expire after N time (hours?) 
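
A sketch of the assume-role flow the comment block above outlines, using the documented STS AssumeRole API; the session name mirrors the connector's documented default ("CensysCloudConnectorSession"), and the other values are placeholders rather than code from this patch:

```python
# Sketch: exchange the primary account's long-lived keys for temporary
# credentials in the target account via STS AssumeRole. The response's
# Credentials block carries an Expiration timestamp, which is what the
# "creds expire" note above refers to.
import boto3


def temporary_credentials(
    access_key: str, secret_key: str, account_number: str, role_name: str
) -> dict:
    sts = boto3.client(
        "sts",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    resp = sts.assume_role(
        RoleArn=f"arn:aws:iam::{account_number}:role/{role_name}",
        RoleSessionName="CensysCloudConnectorSession",
    )
    creds = resp["Credentials"]
    return {
        "aws_access_key_id": creds["AccessKeyId"],
        "aws_secret_access_key": creds["SecretAccessKey"],
        "aws_session_token": creds["SessionToken"],
    }
```
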
+ scan_context_key = f"{account_number}_{region}" scan_context = AwsScanContext( - provider_settings=provider_setting, - temp_sts_cred=None, + provider_settings=provider_settings, credential=credential, - botocred=botocred, - account_number=credential["account_number"], + account_number=account_number, region=region, ignored_tags=ignored_tags, + logger=None, + # temp_sts_cred=None, ) - # self.pool.apply_async( + # scan workflow: + # - get seeds + cloud-assets + tags-plugin + # - submit seeds + cloud-assets + print(f"scan_all logging level {self.logger.level}") pool.apply_async( self.scan_seeds, kwds={ # TODO remove all of this except `scan_context_key` # "provider_setting": provider_setting, - # "botocred": botocred, # "credential": credential, # "region": region, "scan_context_key": scan_context_key, "scan_context": scan_context, }, ) - # self.logger.info(f"asyn res {x}") # self.scan(**kwargs) except Exception as e: self.logger.error( f"Unable to scan account {credential['account_number']} in region {region}. Error: {e}" ) self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e) - self.region = None + # self.region = None # for each account, run each cloud asset scanner try: - self.temp_sts_cred = None - self.region = None + # self.temp_sts_cred = None + # self.region = None with Healthcheck( self.settings, provider_setting, provider={"account_number": self.account_number}, ): - # self.scan_cloud_assets() pool.apply_async( self.scan_cloud_assets, kwds={ @@ -282,8 +282,8 @@ def scan_all(self): def format_label( self, service: AwsServices, + account_number: str, region: Optional[str] = None, - account_number: Optional[str] = None, ) -> str: """Format AWS label. @@ -295,14 +295,10 @@ def format_label( Returns: str: Formatted label. """ - # TODO: s/self.account_number/account_number <- use param - account = account_number or self.account_number - # TODO: fix self.region - region = region or self.region region_label = f"/{region}" if region != "" else "" - return f"AWS: {service} - {account}{region_label}" + return f"AWS: {service} - {account_number}{region_label}" - def credentials(self) -> dict: + def credentials(self, ctx: AwsScanContext) -> dict: """Generate required credentials for AWS. This method will attempt to use any active STS sessions before falling @@ -311,25 +307,31 @@ def credentials(self) -> dict: Returns: dict: Boto Credential format. """ + # Note: original design of credentials() was to lazily-load a connection if one + # didn't exist using a previous STS temporary credential. + # Role name is the credential field which causes STS to activate. # Once activated the temporary STS creds will be used by all # subsequent AWS service client calls. 
+ # TODO: fix self.credential - if role_name := self.credential.get("role_name"): + if role_name := ctx.credential.get("role_name"): self.logger.debug(f"Using STS for role {role_name}") - return self.get_assume_role_credentials(role_name) + return self.get_assume_role_credentials(ctx) # (account_number, role_name) self.logger.debug("Using provider settings credentials") return self.boto_cred( - # TODO: fix self.region - self.region, - self.provider_settings.access_key, - self.provider_settings.secret_key, - self.provider_settings.session_token, + ctx.region, + ctx.provider_settings.access_key, + ctx.provider_settings.secret_key, + ctx.provider_settings.session_token, ) def get_aws_client( - self, service: AwsServices, credentials: Optional[dict] = None + self, + service: AwsServices, + ctx: AwsScanContext, + credentials: Optional[dict] = None, ) -> T: """Creates an AWS client for the provided service. @@ -344,7 +346,7 @@ def get_aws_client( T: An AWS boto3 client. """ try: - credentials = credentials or self.credentials() + credentials = credentials or self.credentials(ctx) if self.settings.aws_endpoint_url: credentials["endpoint_url"] = self.settings.aws_endpoint_url @@ -365,7 +367,13 @@ def get_aws_client( ) raise - def get_assume_role_credentials(self, role_name: Optional[str] = None) -> dict: + def get_assume_role_credentials( + self, + ctx: AwsScanContext, + # account_number: Optional[str] = None, # TODO: pass this in + # role_name: Optional[str] = None, + # role_session_name: Optional[str] = None, + ) -> dict: """Acquire temporary STS credentials and cache them for the duration of the scan. Args: @@ -377,29 +385,35 @@ def get_assume_role_credentials(self, role_name: Optional[str] = None) -> dict: Raises: Exception: If the credentials could not be created. 
""" - # TODO: fix self.temp_sts_cred - if self.temp_sts_cred: - self.logger.debug("Using cached temporary STS credentials") - else: - try: - temp_creds = self.assume_role(role_name) - # TODO: fix self.temp_sts_cred - self.temp_sts_cred = self.boto_cred( - # TODO: fix self.region - self.region, - temp_creds["AccessKeyId"], - temp_creds["SecretAccessKey"], - temp_creds["SessionToken"], - ) - self.logger.debug( - f"Created temporary STS credentials for role {role_name}" - ) - except Exception as e: - self.logger.error(f"Failed to assume role: {e}") - raise + role_name = ctx.credential["role_name"] + role_session_name = ctx.credential["role_session_name"] + # TODO: temp_sts_cred is removed, make sure this works (and doesnt "assume role" per seed - should only be on the resource type level) # TODO: fix self.temp_sts_cred - return self.temp_sts_cred + # if self.temp_sts_cred: + # self.logger.debug("Using cached temporary STS credentials") + # else: + try: + temp_creds = self.assume_role( + # account_number, region, role_name, role_session_name + ctx, + role_name=role_name, + role_session_name=role_session_name, + ) + # TODO: fix self.temp_sts_cred + # self.temp_sts_cred = self.boto_cred( + self.logger.debug(f"Created temporary STS credentials for role {role_name}") + return self.boto_cred( + ctx.region, + temp_creds["AccessKeyId"], + temp_creds["SecretAccessKey"], + temp_creds["SessionToken"], + ) + except Exception as e: + self.logger.error(f"Failed to assume role: {e}") + raise + # TODO: fix self.temp_sts_cred + # return self.temp_sts_cred def boto_cred( self, @@ -442,7 +456,12 @@ def boto_cred( return cred def assume_role( - self, role_name: Optional[str] = AwsDefaults.ROLE_NAME.value + self, + ctx: AwsScanContext, + # account_number: str, + # region: str, + role_name: Optional[str] = AwsDefaults.ROLE_NAME.value, + role_session_name: Optional[str] = AwsDefaults.ROLE_SESSION_NAME.value, ) -> CredentialsTypeDef: """Acquire temporary credentials generated by Secure Token Service (STS). @@ -450,7 +469,9 @@ def assume_role( the STS service. Args: + account_number(str): AWS account number. role_name (str, optional): Role name to assume. Defaults to "CensysCloudConnectorRole". + role_session_name (str, optional): Role session name. Defaults to "CensysCloudConnectorSession". Returns: CredentialsTypeDef: Temporary credentials. 
@@ -458,26 +479,19 @@
         # TODO: verify this works with worker pool change- Always use the primary account credentials to query STS
         # use primary account's credentials to query STS for temp creds
         credentials = self.boto_cred(
-            # TODO: fix self.region
-            self.region,
-            self.provider_settings.access_key,
-            self.provider_settings.secret_key,
-            self.provider_settings.session_token,
+            ctx.region,
+            ctx.provider_settings.access_key,
+            ctx.provider_settings.secret_key,
+            ctx.provider_settings.session_token,
         )

         client: STSClient = self.get_aws_client(
             AwsServices.SECURE_TOKEN_SERVICE,
             credentials=credentials,
         )

-        role_session = (
-            # FIX self.credential
-            self.credential["role_session_name"]
-            or AwsDefaults.ROLE_SESSION_NAME.value
-        )
-
-        # TODO: s/self.account_number/cred[account_number] <- pass in account number
         role: dict[str, Any] = {
-            "RoleArn": f"arn:aws:iam::{self.account_number}:role/{role_name}",
-            "RoleSessionName": role_session,
+            "RoleArn": f"arn:aws:iam::{ctx.account_number}:role/{role_name}",
+            "RoleSessionName": role_session_name,
         }

         temp_creds = client.assume_role(**role)
@@ -492,16 +506,13 @@ def get_api_gateway_domains_v1(self, **kwargs):
         key = kwargs["scan_context_key"]
         ctx: AwsScanContext = self.scan_contexts[key]

-        client: APIGatewayClient = self.get_aws_client(
-            service=AwsServices.API_GATEWAY, credentials=ctx.botocred
-        )
-        label = self.format_label(SeedLabel.API_GATEWAY)
+        client: APIGatewayClient = self.get_aws_client(AwsServices.API_GATEWAY, ctx)
+        label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region)
+
         try:
             apis = client.get_rest_apis()
             for domain in apis.get("items", []):
-                # TODO: fix self.region
-                domain_name = f"{domain['id']}.execute-api.{self.region}.amazonaws.com"
-                # TODO: emit log when a seed is dropped due to validation error
+                domain_name = f"{domain['id']}.execute-api.{ctx.region}.amazonaws.com"
                 with SuppressValidationError():
                     domain_seed = DomainSeed(value=domain_name, label=label)
                     self.add_seed(domain_seed, api_gateway_res=domain)
@@ -514,9 +525,10 @@ def get_api_gateway_domains_v2(self, **kwargs):
         ctx: AwsScanContext = self.scan_contexts[key]

         client: ApiGatewayV2Client = self.get_aws_client(
-            service=AwsServices.API_GATEWAY_V2, credentials=ctx.botocred
+            AwsServices.API_GATEWAY_V2, ctx
         )
-        label = self.format_label(SeedLabel.API_GATEWAY)
+        label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region)
+
         try:
             apis = client.get_apis()
             for domain in apis.get("Items", []):
@@ -541,10 +553,10 @@ def get_load_balancers_v1(self, **kwargs):
         ctx: AwsScanContext = self.scan_contexts[key]

         client: ElasticLoadBalancingClient = self.get_aws_client(
-            service=AwsServices.LOAD_BALANCER,
-            credentials=ctx.botocred,
+            AwsServices.LOAD_BALANCER, ctx
         )
-        label = self.format_label(SeedLabel.LOAD_BALANCER)
+        label = self.format_label(SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region)
+
         try:
             data = client.describe_load_balancers()
             for elb in data.get("LoadBalancerDescriptions", []):
@@ -561,9 +573,10 @@ def get_load_balancers_v2(self, **kwargs):
         ctx: AwsScanContext = self.scan_contexts[key]

         client: ElasticLoadBalancingv2Client = self.get_aws_client(
-            service=AwsServices.LOAD_BALANCER_V2, credentials=ctx.botocred
+            AwsServices.LOAD_BALANCER_V2, ctx
         )
-        label = self.format_label(SeedLabel.LOAD_BALANCER)
+        label = self.format_label(SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region)
+
         try:
             data = client.describe_load_balancers()
             for elb in data.get("LoadBalancers", []):
@@ -587,25 +600,14 @@ def 
get_network_interfaces(self, **kwargs): key = kwargs["scan_context_key"] ctx: AwsScanContext = self.scan_contexts[key] - try: - interfaces = self.describe_network_interfaces(ctx.botocred) - except ClientError as e: - self.logger.error(f"Could not connect to ENI Service. Error: {e}") - return - label = self.format_label(SeedLabel.NETWORK_INTERFACE) - has_added_seeds = False - - interfaces = self.describe_network_interfaces() - # this looks like a bug not passing in a resource type - ( - instance_tags, - instance_tag_sets, - ) = self.get_resource_tags(ctx.botocred) + label = self.format_label(SeedLabel.NETWORK_INTERFACE, ctx.account_number, ctx.region) + interfaces = self.describe_network_interfaces(ctx) + instance_tags, instance_tag_sets = self.get_resource_tags(ctx) for ip_address, record in interfaces.items(): instance_id = record["InstanceId"] tags = instance_tags.get(instance_id) - if tags and self.has_ignored_tag(tags): + if tags and self.has_ignored_tag(ctx.ignored_tags, tags): self.logger.debug( f"Skipping ignored tag for network instance {ip_address}" ) @@ -618,7 +620,7 @@ def get_network_interfaces(self, **kwargs): if not has_added_seeds: self.delete_seeds_by_label(label) - def describe_network_interfaces(self, botocred: dict) -> dict: + def describe_network_interfaces(self, ctx: AwsScanContext) -> dict: """Retrieve EC2 Elastic Network Interfaces (ENI) data. Raises: @@ -627,8 +629,7 @@ def describe_network_interfaces(self, botocred: dict) -> dict: Returns: dict: Network Interfaces. """ - # TODO pass in scan_contexts - ec2: EC2Client = self.get_aws_client(AwsServices.EC2, credentials=botocred) + ec2: EC2Client = self.get_aws_client(AwsServices.EC2, ctx) interfaces: dict[str, dict[str, Union[None, str, list]]] = {} # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.Client.describe_network_interfaces @@ -642,7 +643,7 @@ def describe_network_interfaces(self, botocred: dict) -> dict: network_interface_id = network.get("NetworkInterfaceId") instance_id = network.get("Attachment", {}).get("InstanceId") - if self.network_interfaces_ignored_tags(network): + if self.network_interfaces_ignored_tags(ctx.ignored_tags, network): self.logger.debug( f"Skipping ignored tag for network interface {network_interface_id}" ) @@ -661,7 +662,9 @@ def describe_network_interfaces(self, botocred: dict) -> dict: return interfaces def get_resource_tags_paginated( - self, botocred: dict, resource_types: Optional[list[str]] = None + self, + ctx: AwsScanContext, + resource_types: Optional[list[str]] = None, ) -> Generator[TagDescriptionTypeDef, None, None]: """Retrieve EC2 resource tags paginated. @@ -671,11 +674,8 @@ def get_resource_tags_paginated( Yields: Generator[TagDescriptionTypeDef]: Tags. """ - # TODO pass in ctx - ec2: EC2Client = self.get_aws_client(AwsServices.EC2, credentials=botocred) - paginator = ec2.get_paginator( - "describe_tags", - ) + ec2: EC2Client = self.get_aws_client(AwsServices.EC2, ctx) + paginator = ec2.get_paginator("describe_tags") for page in paginator.paginate( Filters=[ @@ -686,7 +686,9 @@ def get_resource_tags_paginated( yield from tags def get_resource_tags( - self, botocred: dict, resource_types: Optional[list[str]] = None + self, + ctx: AwsScanContext, + resource_types: Optional[list[str]] = None, ) -> tuple[dict, dict]: """Get EC2 resource tags based on resource types. 
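[Editor's note, not part of the patch: `get_resource_tags_paginated` above wraps boto3's `describe_tags` paginator. A self-contained sketch of that call pattern follows, assuming a plain boto3 EC2 client; the region and filter values are placeholders.]

```python
# Sketch of EC2 tag pagination with a resource-type filter.
import boto3


def iter_instance_tags(ec2):
    """Yield tag descriptions for EC2 instances, one page at a time."""
    paginator = ec2.get_paginator("describe_tags")
    pages = paginator.paginate(
        Filters=[{"Name": "resource-type", "Values": ["instance"]}]
    )
    for page in pages:
        # Each description carries ResourceId, ResourceType, Key, and Value.
        yield from page.get("Tags", [])


if __name__ == "__main__":
    ec2 = boto3.client("ec2", region_name="us-east-1")
    for tag in iter_instance_tags(ec2):
        print(tag["ResourceId"], tag["Key"], tag["Value"])
```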
@@ -699,7 +701,7 @@ def get_resource_tags(
         resource_tags: dict = {}
         resource_tag_sets: dict = {}

-        for tag in self.get_resource_tags_paginated(resource_types, botocred):
+        for tag in self.get_resource_tags_paginated(ctx, resource_types):
             # Tags come in two formats:
             # 1. Tag = { Key = "Name", Value = "actual-tag-name" }
             # 2. Tag = { Key = "actual-key-name", Value = "tag-value-that-is-unused-here"}
@@ -718,7 +720,9 @@ def get_resource_tags(

         return resource_tags, resource_tag_sets

-    def network_interfaces_ignored_tags(self, data: NetworkInterfaceTypeDef) -> bool:
+    def network_interfaces_ignored_tags(
+        self, ignored_tags: list[str], data: NetworkInterfaceTypeDef
+    ) -> bool:
         """Check if network interface has ignored tags.

         Args:
@@ -729,17 +733,15 @@ def network_interfaces_ignored_tags(self, data: NetworkInterfaceTypeDef) -> bool
         """
         tag_set = data.get("TagSet", [])
         tags = self.extract_tags_from_tagset(tag_set)
-        return self.has_ignored_tag(tags)
+        return self.has_ignored_tag(ignored_tags, tags)

     def get_rds_instances(self, **kwargs):
         """Retrieve Relational Database Services (RDS) data and emit seeds."""
         key = kwargs["scan_context_key"]
         ctx: AwsScanContext = self.scan_contexts[key]
-        client: RDSClient = self.get_aws_client(
-            service=AwsServices.RDS, credentials=ctx.botocred
-        )
-        label = self.format_label(SeedLabel.RDS)
-        has_added_seeds = False
+
+        client: RDSClient = self.get_aws_client(AwsServices.RDS, ctx)
+        label = self.format_label(SeedLabel.RDS, ctx.account_number, ctx.region)

         try:
             data = client.describe_db_instances()
@@ -791,17 +793,12 @@ def _get_route53_zone_resources(

     def get_route53_zones(self, **kwargs):
         """Retrieve Route 53 Zones and emit seeds."""
-        # TODO: how to pass in cred,region?
         key = kwargs["scan_context_key"]
         ctx: AwsScanContext = self.scan_contexts[key]

-        client: Route53Client = self.get_aws_client(
-            service=AwsServices.ROUTE53_ZONES, credentials=ctx.botocred
-        )
+        client: Route53Client = self.get_aws_client(AwsServices.ROUTE53_ZONES, ctx)
         label = self.format_label(
-            SeedLabel.ROUTE53_ZONES,
-            region=ctx.region,
-            account_number=ctx.credential["account_number"],
+            SeedLabel.ROUTE53_ZONES, ctx.account_number, ctx.region
         )
         has_added_seeds = False

@@ -842,10 +839,9 @@ def get_ecs_instances(self, **kwargs):
         key = kwargs["scan_context_key"]
         ctx: AwsScanContext = self.scan_contexts[key]

-        ecs: ECSClient = self.get_aws_client(AwsServices.ECS, credentials=ctx.botocred)
-        ec2: EC2Client = self.get_aws_client(AwsServices.EC2, credentials=ctx.botocred)
-        label = self.format_label(SeedLabel.ECS)
-        has_added_seeds = False
+        ecs: ECSClient = self.get_aws_client(AwsServices.ECS, ctx)
+        ec2: EC2Client = self.get_aws_client(AwsServices.EC2, ctx)
+        label = self.format_label(SeedLabel.ECS, ctx.account_number, ctx.region)

         try:
             clusters = ecs.list_clusters()
@@ -901,9 +897,7 @@ def get_s3_instances(self, **kwargs):
         key = kwargs["scan_context_key"]
         ctx: AwsScanContext = self.scan_contexts[key]

-        client: S3Client = self.get_aws_client(
-            service=AwsServices.STORAGE_BUCKET, credentials=ctx.botocred
-        )
+        client: S3Client = self.get_aws_client(AwsServices.STORAGE_BUCKET, ctx)

         try:
             data = client.list_buckets().get("Buckets", [])
@@ -916,8 +910,8 @@ def get_s3_instances(self, **kwargs):
                 lookup_region = self.get_s3_region(client, bucket_name)
                 label = self.format_label(
                     SeedLabel.STORAGE_BUCKET,
-                    region=ctx.region,
-                    account_number=ctx.account_number,
+                    ctx.account_number,
+                    ctx.region,
                 )

                 with SuppressValidationError():
@@ -951,17 +945,17 @@ def get_ignored_tags(self, tags: 
Optional[list[str]] = None): return list(ignored) - def has_ignored_tag(self, tags: list[str]) -> bool: + def has_ignored_tag(self, ignored_tags: list[str], tags: list[str]) -> bool: """Check if a list of tags contains an ignored tag. Args: - tags (list[str]): Tags on the current resource. + ignored_tags (list[str]): Ignored tags for the current AwsScanContext + tags (list[str]): Tags on the current AWS resource. Returns: bool: If the list contains an ignored tag. """ - # TODO: fix self.ignored_tags - return any(tag in self.ignored_tags for tag in tags) + return any(tag in ignored_tags for tag in tags) def extract_tags_from_tagset(self, tag_set: list[TagTypeDef]) -> list[str]: """Extract tags from tagset. diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py index 3f3943c..6a77b45 100644 --- a/src/censys/cloud_connectors/common/connector.py +++ b/src/censys/cloud_connectors/common/connector.py @@ -249,11 +249,11 @@ def scan_seeds(self, **kwargs): self.submit_seeds_wrapper() self.dispatch_event(EventTypeEnum.SCAN_FINISHED) - def scan_cloud_assets(self): + def scan_cloud_assets(self, **kwargs): """Scan the cloud assets.""" self.logger.info("Gathering cloud assets...") self.dispatch_event(EventTypeEnum.SCAN_STARTED) - self.get_cloud_assets() + self.get_cloud_assets(**kwargs) self.submit_cloud_assets_wrapper() self.dispatch_event(EventTypeEnum.SCAN_FINISHED) diff --git a/src/censys/cloud_connectors/common/healthcheck.py b/src/censys/cloud_connectors/common/healthcheck.py index 6b2a7d8..784a421 100644 --- a/src/censys/cloud_connectors/common/healthcheck.py +++ b/src/censys/cloud_connectors/common/healthcheck.py @@ -19,6 +19,7 @@ class Healthcheck: def __init__( self, settings: Settings, + # TODO: make sure this still works (should use AwsScanContext) provider_specific_settings: ProviderSpecificSettings, provider: Optional[dict] = None, exception_map: Optional[dict[Exception, ErrorCodes]] = None, From 071599f5ef29989d269535199c104643f6d61af8 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Wed, 16 Aug 2023 12:22:55 -0400 Subject: [PATCH 06/13] WIP - introduce cloud events - emit payloads - hacky Aurora client - research log line showing provider info --- poetry.lock | 33 ++++++- pyproject.toml | 1 + .../aws_connector/connector.py | 50 ++++++---- src/censys/cloud_connectors/common/aurora.py | 53 +++++++++++ .../cloud_connectors/common/connector.py | 93 ++++++++++++++++++- src/censys/cloud_connectors/common/logger.py | 16 +++- 6 files changed, 222 insertions(+), 24 deletions(-) create mode 100644 src/censys/cloud_connectors/common/aurora.py diff --git a/poetry.lock b/poetry.lock index 5508160..371b2da 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2554,6 +2554,23 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cloudevents" +version = "1.9.0" +description = "CloudEvents Python SDK" +optional = false +python-versions = "*" +files = [ + {file = "cloudevents-1.9.0-py3-none-any.whl", hash = "sha256:1011459d56d8f0184a46456f5d72632a2565f18171e51b33e06f643e723d30c9"}, + {file = "cloudevents-1.9.0.tar.gz", hash = "sha256:8beb27503f97e215f886f73c17671012e96bb6268137fb3b2f9ef552727ab5b1"}, +] + +[package.dependencies] +deprecation = ">=2.0,<3.0" + +[package.extras] +pydantic = ["pydantic (>=1.0.0,<2.0)"] + [[package]] name = "colorama" version = "0.4.6" @@ -2715,6 +2732,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest (<5)", "PyTest-Cov", "PyTest-Cov 
(<2.6)", "bump2version (<1)", "configparser (<5)", "importlib-metadata (<3)", "importlib-resources (<4)", "sphinx (<2)", "sphinxcontrib-websupport (<2)", "tox", "zipp (<2)"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "distlib" version = "0.3.6" @@ -5587,4 +5618,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "1c36716ca7aa69cbee68bdeb39ee9c7a7ecbbed87c43b1180a1faafa115182cb" +content-hash = "fa3a761a2df073f297e1044a6f54c62bb14ba4055b1a438b043707c32b696e79" diff --git a/pyproject.toml b/pyproject.toml index 26daf32..edbc433 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ inquirerpy = "^0.3.3" pydantic = {extras = ["dotenv", "email"], version = "^1.9.0"} requests = "^2.30.0" rich = "^13.3.5" +cloudevents = "^1.9.0" protobuf = "^4.23.4" [tool.poetry.group.dev.dependencies] diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index f8f21e6..be847d7 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -104,7 +104,6 @@ class AwsCloudConnector(CloudConnector): # provider_settings: AwsSpecificSettings # credential: dict = {} # account_number: str - # region: Optional[str] # ignored_tags: list[str] # pool: Pool @@ -141,12 +140,14 @@ def scan_seeds(self, **kwargs): scan_context_key = kwargs["scan_context_key"] scan_context: AwsScanContext = kwargs["scan_context"] # scan_context must be set within scan(), otherwise race conditions happen where it doesn't exist when accessed + self.scan_contexts[scan_context_key] = scan_context # multiprocessing requires separate logger instances per process # TODO: there is still something odd going on here, notice logger isnt being set to self.logger, but by even calling get_logger it "fixes" the sub-process log level from being WARNING to .env's DEBUG logger = get_logger( log_name=f"{self.provider.lower()}_cloud_connector", level=self.settings.logging_level, + # TODO: extra - add account + region to log lines ) scan_context.logger = logger @@ -155,7 +156,9 @@ def scan_seeds(self, **kwargs): ) self.scan_contexts[scan_context_key] = scan_context + # TODO: self.dispatch_event(EventTypeEnum.SCAN_STARTED) super().scan_seeds(**kwargs) + # TODO: self.dispatch_event(EventTypeEnum.SCAN_FINISHED) def scan_cloud_assets(self, **kwargs): """Scan AWS for cloud assets.""" @@ -337,7 +340,8 @@ def get_aws_client( Args: service (AwsServices): The AWS service name. - credentials (dict): Override credentials instead of using the default. + ctx (AwsScanContext): The scan context. + credentials (dict): Override credentials instead of using the default. Typically used with STS Raises: Exception: If the client could not be created. 
@@ -460,8 +464,8 @@ def assume_role( ctx: AwsScanContext, # account_number: str, # region: str, - role_name: Optional[str] = AwsDefaults.ROLE_NAME.value, - role_session_name: Optional[str] = AwsDefaults.ROLE_SESSION_NAME.value, + role_name: Optional[str] = None, + role_session_name: Optional[str] = None, ) -> CredentialsTypeDef: """Acquire temporary credentials generated by Secure Token Service (STS). @@ -476,6 +480,9 @@ def assume_role( Returns: CredentialsTypeDef: Temporary credentials. """ + role_name = role_name or AwsDefaults.ROLE_NAME.value + role_session_name = role_session_name or AwsDefaults.ROLE_SESSION_NAME.value + # TODO: verify this works with worker pool change- Always use the primary account credentials to query STS # use primary account's credentials to query STS for temp creds credentials = self.boto_cred( @@ -487,6 +494,7 @@ def assume_role( client: STSClient = self.get_aws_client( AwsServices.SECURE_TOKEN_SERVICE, credentials=credentials, + ctx=ctx, ) role: dict[str, Any] = { @@ -515,7 +523,8 @@ def get_api_gateway_domains_v1(self, **kwargs): domain_name = f"{domain['id']}.execute-api.{ctx.region}.amazonaws.com" with SuppressValidationError(): domain_seed = DomainSeed(value=domain_name, label=label) - self.add_seed(domain_seed, api_gateway_res=domain) + # self.add_seed(domain_seed, api_gateway_res=domain) + self.emit_seed(ctx, domain_seed, api_gateway_res=domain) except ClientError as e: self.logger.error(f"Could not connect to API Gateway V1. Error: {e}") @@ -535,7 +544,8 @@ def get_api_gateway_domains_v2(self, **kwargs): domain_name = domain["ApiEndpoint"].split("//")[1] with SuppressValidationError(): domain_seed = DomainSeed(value=domain_name, label=label) - self.add_seed(domain_seed, api_gateway_res=domain) + # self.add_seed(domain_seed, api_gateway_res=domain) + self.emit_seed(ctx, domain_seed, api_gateway_res=domain) except ClientError as e: self.logger.error(f"Could not connect to API Gateway V2. Error: {e}") @@ -563,7 +573,8 @@ def get_load_balancers_v1(self, **kwargs): if value := elb.get("DNSName"): with SuppressValidationError(): domain_seed = DomainSeed(value=value, label=label) - self.add_seed(domain_seed, elb_res=elb, aws_client=client) + # self.add_seed(domain_seed, elb_res=elb, aws_client=client) + self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client) except ClientError as e: self.logger.error(f"Could not connect to ELB V1. Error: {e}") @@ -583,7 +594,8 @@ def get_load_balancers_v2(self, **kwargs): if value := elb.get("DNSName"): with SuppressValidationError(): domain_seed = DomainSeed(value=value, label=label) - self.add_seed(domain_seed, elb_res=elb, aws_client=client) + # self.add_seed(domain_seed, elb_res=elb, aws_client=client) + self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client) except ClientError as e: self.logger.error(f"Could not connect to ELB V2. Error: {e}") @@ -615,10 +627,8 @@ def get_network_interfaces(self, **kwargs): with SuppressValidationError(): ip_seed = IpSeed(value=ip_address, label=label) - self.add_seed(ip_seed, tags=instance_tag_sets.get(instance_id)) - has_added_seeds = True - if not has_added_seeds: - self.delete_seeds_by_label(label) + # self.add_seed(ip_seed, tags=instance_tag_sets.get(instance_id)) + self.emit_seed(ctx, ip_seed, tags=instance_tag_sets.get(instance_id)) def describe_network_interfaces(self, ctx: AwsScanContext) -> dict: """Retrieve EC2 Elastic Network Interfaces (ENI) data. 
@@ -742,6 +752,7 @@ def get_rds_instances(self, **kwargs): client: RDSClient = self.get_aws_client(service=AwsServices.RDS, ctx) label = self.format_label(SeedLabel.RDS, ctx.account_number, ctx.region) + has_added_seeds = False try: data = client.describe_db_instances() @@ -752,8 +763,9 @@ def get_rds_instances(self, **kwargs): if domain_name := instance.get("Endpoint", {}).get("Address"): with SuppressValidationError(): domain_seed = DomainSeed(value=domain_name, label=label) - self.add_seed(domain_seed, rds_res=instance) + # self.add_seed(domain_seed, rds_res=instance) has_added_seeds = True + self.emit_seed(ctx, domain_seed, rds_res=instance) if not has_added_seeds: self.delete_seeds_by_label(label) except ClientError as e: @@ -812,8 +824,9 @@ def get_route53_zones(self, **kwargs): domain_name = zone.get("Name").rstrip(".") with SuppressValidationError(): domain_seed = DomainSeed(value=domain_name, label=label) - self.add_seed(domain_seed, route53_zone_res=zone, aws_client=client) has_added_seeds = True + # self.add_seed(domain_seed, route53_zone_res=zone, aws_client=client) + self.emit_seed(ctx, domain_seed, route53_zone_res=zone) id = zone.get("Id") resource_sets = self._get_route53_zone_resources(client, id) @@ -843,6 +856,7 @@ def get_ecs_instances(self, **kwargs): ec2: EC2Client = self.get_aws_client(AwsServices.EC2, ctx) label = self.format_label(SeedLabel.ECS, ctx.account_number, ctx.region) + has_added_seeds = False try: clusters = ecs.list_clusters() for cluster in clusters.get("clusterArns", []): @@ -871,8 +885,9 @@ def get_ecs_instances(self, **kwargs): with SuppressValidationError(): ip_seed = IpSeed(value=ip_address, label=label) - self.add_seed(ip_seed, ecs_res=instance) + # self.add_seed(ip_seed, ecs_res=instance) has_added_seeds = True + self.emit_seed(ctx, ip_seed, ecs_res=instance) if not has_added_seeds: self.delete_seeds_by_label(label) except ClientError as e: @@ -922,8 +937,9 @@ def get_s3_instances(self, **kwargs): "accountNumber": ctx.account_number, }, ) - self.add_cloud_asset( - bucket_asset, bucket_name=bucket_name, aws_client=client + # self.add_cloud_asset(bucket_asset, bucket_name=bucket_name, aws_client=client) + self.emit_cloud_asset( + ctx, bucket_asset, bucket_name=bucket_name, aws_client=client ) except ClientError as e: self.logger.error(f"Could not connect to S3. Error: {e}") diff --git a/src/censys/cloud_connectors/common/aurora.py b/src/censys/cloud_connectors/common/aurora.py new file mode 100644 index 0000000..b275622 --- /dev/null +++ b/src/censys/cloud_connectors/common/aurora.py @@ -0,0 +1,53 @@ +"""Interact with Aurora API.""" +from cloudevents.conversion import to_structured + +from censys.asm import Seeds + + +class Aurora(Seeds): + """Aurora API class.""" + + # TODO base_path = "/api/v1/payload/enqueue" + base_path = "/api/payload/enqueue" + + def emit(self, payload: dict) -> None: + """Emit a payload to ASM. + + Args: + payload (dict): Payload to emit. 
+ """ + + # TODO: contact aurora + # https://github.com/cloudevents/sdk-python#structured-http-cloudevent + headers, body = to_structured(payload) + # requests.post("", data=body, headers=headers) + + # data = {"payload": payload} + # censys python is forcing json encoding, whereas we want to send a binary cloud event + # data_workaround = {"body": payload} + # return self._post( + # self.base_path, data=body, headers=headers + # ) # , **data_workaround) + + request_kwargs = {"timeout": self.timeout, "data": body, "headers": headers} + + # TODO @backoff_wrapper + url = f"{self._api_url}{self.base_path}" + resp = self._call_method(self._session.post, url, request_kwargs) + # TODO: handle response + print(f"resp: {resp}") + + def emit_batch(self, payloads) -> None: + """Emit a payload to ASM. + + Args: + payload (dict): Payload to emit. + """ + + # TODO: contact aurora + # https://github.com/cloudevents/sdk-python#structured-http-cloudevent + # headers, body = to_structured(event) + # requests.post("", data=body, headers=headers) + + data = {"payloads": payloads} + return self._post(self.base_path, data=data) diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py index 6a77b45..ce2b7c5 100644 --- a/src/censys/cloud_connectors/common/connector.py +++ b/src/censys/cloud_connectors/common/connector.py @@ -6,9 +6,13 @@ from logging import Logger from typing import Callable, Optional, Union +from cloudevents.http import CloudEvent + from censys.asm import Beta, Seeds from censys.common.exceptions import CensysAsmException +from censys.cloud_connectors.common.aurora import Aurora + from .cloud_asset import CloudAsset from .enums import EventTypeEnum, ProviderEnum from .logger import get_logger @@ -27,6 +31,7 @@ class CloudConnector(ABC): logger: Logger seeds_api: Seeds beta_api: Beta + aurora_api: Aurora seeds: dict[str, set[Seed]] cloud_assets: dict[str, set[CloudAsset]] seed_scanners: dict[str, Callable[[], None]] @@ -63,12 +68,17 @@ def __init__(self, settings: Settings): user_agent=settings.censys_user_agent, cookies=settings.censys_cookies, ) + self.aurora_api = Aurora( + settings.censys_api_key, + url=settings.censys_asm_api_base_url, + user_agent=settings.censys_user_agent, + cookies=settings.censys_cookies, + ) self.seeds = defaultdict(set) self.cloud_assets = defaultdict(set) self.current_service = None - # TODO: how to pass in cred,region? (each scanner will have diff things to pass in) def delete_seeds_by_label(self, label: str): """Replace seeds for [label] with an empty list. @@ -94,10 +104,15 @@ def get_seeds(self, **kwargs) -> None: self.logger.debug(f"Skipping {seed_type}") continue self.logger.debug(f"Scanning {seed_type}") + + start = time.time() seed_scanner(**kwargs) + duration = time.time() - start + self.logger.debug( + f"Scan seed-type:{seed_type} count:{len(self.seeds)} duration:{duration:.2f}" + ) self.current_service = None - # TODO: how to pass in cred,region? 
(each scanner will have diff things to pass in) def get_cloud_assets(self, **kwargs) -> None: """Gather cloud assets.""" for cloud_asset_type, cloud_asset_scanner in self.cloud_asset_scanners.items(): @@ -109,7 +124,13 @@ def get_cloud_assets(self, **kwargs) -> None: self.logger.debug(f"Skipping {cloud_asset_type}") continue self.logger.debug(f"Scanning {cloud_asset_type}") + + start = time.time() cloud_asset_scanner(**kwargs) + duration = time.time() - start + self.logger.debug( + f"Scan cloud-asset-type:{cloud_asset_type} count:{len(self.seeds)} duration:{duration:.2f}" + ) self.current_service = None def get_event_context( @@ -150,6 +171,60 @@ def dispatch_event( context = self.get_event_context(event_type, service) CloudConnectorPluginRegistry.dispatch_event(context=context, **kwargs) + def emit(self, payload: dict): + """Send a payload to Aurora + + Args: + payload (dict): Payload. + """ + self.logger.debug(f"Sending payload: {payload}") + + # TODO: emit in batches!!!! + self.aurora_api.emit(payload) + + def emit_seed(self, ctx, seed: Seed, **kwargs): + """Emit a seed payload. + + Args: + seed (Seed): The seed to emit. + **kwargs: Additional data for event dispatching. + """ + # self.logger.debug(f"Found Seed: {seed.to_dict()}") + self.dispatch_event(EventTypeEnum.SEED_FOUND, seed=seed, **kwargs) + + attributes = { + "type": "com.censys.cloud-connector.seed", + "source": "cc-user-agent", + } + data = { + "seed": seed.to_dict(), + } + payload = CloudEvent(attributes, data) + self.emit(payload) + + def emit_cloud_asset(self, ctx, cloud_asset: CloudAsset, **kwargs): + """Emit a cloud asset payload. + + Args: + cloud_asset (CloudAsset): The cloud asset to emit. + **kwargs: Additional data for event dispatching. + """ + # self.logger.debug(f"Found Cloud Asset: {cloud_asset.to_dict()}") + self.dispatch_event( + EventTypeEnum.CLOUD_ASSET_FOUND, cloud_asset=cloud_asset, **kwargs + ) + + attributes = { + "type": "com.censys.cloud-connector.cloud-asset", + "source": "cc-user-agent", + } + data = { + "asset": cloud_asset.to_dict(), + } + payload = CloudEvent(attributes, data) + + self.emit(payload) + def add_seed(self, seed: Seed, **kwargs): """Add a seed. @@ -210,17 +285,25 @@ def submit_cloud_assets(self): def clear(self): """Clear the seeds and cloud assets.""" + # TODO: it's possible this clobbers seeds & cloudassets from other process' account + regions + + # hm.. do seeds and cloud-assets even work in multiprocessing? + # should they be contained in scanner context? + self.logger.debug(f"Clearing {len(self.seeds)} seeds") self.seeds.clear() + + self.logger.debug(f"Clearing {len(self.cloud_assets)} cloud assets") self.cloud_assets.clear() - def submit(self): # pragma: no cover + def submit(self, **kwargs): # pragma: no cover """Submit the seeds and cloud assets to Censys ASM.""" if self.settings.dry_run: self.logger.info("Dry run enabled. 
Skipping submission.") else: self.logger.info("Submitting seeds and cloud assets...") - self.submit_seeds() - self.submit_cloud_assets() + self.submit_seeds(**kwargs) + self.submit_cloud_assets(**kwargs) + self.clear() def submit_seeds_wrapper(self): # pragma: no cover diff --git a/src/censys/cloud_connectors/common/logger.py b/src/censys/cloud_connectors/common/logger.py index c3da211..fc77d04 100644 --- a/src/censys/cloud_connectors/common/logger.py +++ b/src/censys/cloud_connectors/common/logger.py @@ -4,7 +4,7 @@ def get_logger( - log_name: Optional[str] = "cloud_connector", level: str = "INFO" + log_name: Optional[str] = "cloud_connector", level: str = "INFO", **kwargs ) -> logging.Logger: """Returns a custom logger. @@ -18,12 +18,26 @@ def get_logger( """ logger = logging.getLogger(log_name) if not logger.hasHandlers(): + formatter = logging.Formatter( fmt="%(asctime)s:%(levelname)s:%(name)s: %(message)s" + # fmt="%(asctime)s:%(levelname)s:%(name)s:%(provider)s: %(message)s" ) handler = logging.StreamHandler() handler.setFormatter(formatter) logger.addHandler(handler) + + # TODO - add provider (AWS=account+region, GCP=org+project, AZURE=subid) to log record + # + # https://stackoverflow.com/a/57820456/19351735 + # old_factory = logging.getLogRecordFactory() + # def record_factory(*args, **kwargs): + # record = old_factory(*args, **kwargs) + # provider = kwargs.get("provider", "") + # record.provider = provider + # return record + + # logging.setLogRecordFactory(record_factory) logger.setLevel(level) return logger From b482c3e5d58a76ce93b8e0738ce8d9c94676b839 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Fri, 18 Aug 2023 13:30:50 -0400 Subject: [PATCH 07/13] WIP: - add provider to log (aws only for now) - change add_seed to use a list and submit_seed_payload - change add_cloud_asset to use a map + submit_cloud_asset_payload - rough and ready aurora client --- .../aws_connector/connector.py | 183 +++++++++++++++--- src/censys/cloud_connectors/common/aurora.py | 60 +++--- .../cloud_connectors/common/connector.py | 141 ++++++++------ src/censys/cloud_connectors/common/logger.py | 30 +-- 4 files changed, 279 insertions(+), 135 deletions(-) diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index be847d7..b9d7a25 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -147,7 +147,7 @@ def scan_seeds(self, **kwargs): logger = get_logger( log_name=f"{self.provider.lower()}_cloud_connector", level=self.settings.logging_level, - # TODO: extra - add account + region to log lines + provider=f"{self.provider}_{scan_context.account_number}_{scan_context.region}", ) scan_context.logger = logger @@ -300,6 +300,7 @@ def format_label( """ region_label = f"/{region}" if region != "" else "" return f"AWS: {service} - {account_number}{region_label}" + # technically this should use self.provider.label() instead of hardcoding "AWS" def credentials(self, ctx: AwsScanContext) -> dict: """Generate required credentials for AWS. 
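[Editor's note, not part of the patch: the `provider=` string wired into `get_logger` above exists so multiprocessing workers can be told apart in shared log output. A hedged alternative sketch follows that yields the same record field without a process-global `setLogRecordFactory`; names and the provider string format are illustrative.]

```python
# Sketch: per-worker provider tag via a stdlib LoggerAdapter.
import logging


def get_provider_logger(name: str, provider: str, level: str = "INFO"):
    base = logging.getLogger(name)
    if not base.hasHandlers():
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s:%(levelname)s:%(name)s:%(provider)s: %(message)s")
        )
        base.addHandler(handler)
    base.setLevel(level)
    # The adapter injects `provider` into every record it emits, so two
    # workers with different adapters cannot clobber each other's value.
    return logging.LoggerAdapter(base, {"provider": provider})


log = get_provider_logger("aws_cloud_connector", "AWS_123456789012_us-east-1")
log.info("scanning")
```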
@@ -518,13 +519,21 @@ def get_api_gateway_domains_v1(self, **kwargs):
         label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region)

         try:
+            seeds = []
             apis = client.get_rest_apis()
             for domain in apis.get("items", []):
                 domain_name = f"{domain['id']}.execute-api.{ctx.region}.amazonaws.com"
                 with SuppressValidationError():
-                    domain_seed = DomainSeed(value=domain_name, label=label)
+                    # domain_seed = DomainSeed(value=domain_name, label=label)
                     # self.add_seed(domain_seed, api_gateway_res=domain)
-                    self.emit_seed(ctx, domain_seed, api_gateway_res=domain)
+                    # self.emit_seed(ctx, domain_seed, api_gateway_res=domain)
+                    seed = self.process_seed(
+                        DomainSeed(value=domain_name, label=label),
+                        api_gateway_res=domain,
+                    )
+                    seeds.append(seed)
+            self.submit_seed_payload(label, seeds)
         except ClientError as e:
             self.logger.error(f"Could not connect to API Gateway V1. Error: {e}")

@@ -539,13 +548,21 @@ def get_api_gateway_domains_v2(self, **kwargs):
         label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region)

         try:
+            seeds = []
             apis = client.get_apis()
             for domain in apis.get("Items", []):
                 domain_name = domain["ApiEndpoint"].split("//")[1]
                 with SuppressValidationError():
-                    domain_seed = DomainSeed(value=domain_name, label=label)
+                    # domain_seed = DomainSeed(value=domain_name, label=label)
                     # self.add_seed(domain_seed, api_gateway_res=domain)
-                    self.emit_seed(ctx, domain_seed, api_gateway_res=domain)
+                    # self.emit_seed(ctx, domain_seed, api_gateway_res=domain)
+                    seed = self.process_seed(
+                        DomainSeed(value=domain_name, label=label),
+                        api_gateway_res=domain,
+                    )
+                    seeds.append(seed)
+            self.submit_seed_payload(label, seeds)
         except ClientError as e:
             self.logger.error(f"Could not connect to API Gateway V2. Error: {e}")

@@ -568,13 +585,21 @@ def get_load_balancers_v1(self, **kwargs):
         label = self.format_label(SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region)

         try:
+            seeds = []
             data = client.describe_load_balancers()
             for elb in data.get("LoadBalancerDescriptions", []):
                 if value := elb.get("DNSName"):
                     with SuppressValidationError():
-                        domain_seed = DomainSeed(value=value, label=label)
+                        # domain_seed = DomainSeed(value=value, label=label)
                         # self.add_seed(domain_seed, elb_res=elb, aws_client=client)
-                        self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client)
+                        # self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client)
+                        seed = self.process_seed(
+                            DomainSeed(value=value, label=label),
+                            elb_res=elb,
+                            aws_client=client,
+                        )
+                        seeds.append(seed)
+            self.submit_seed_payload(label, seeds)
         except ClientError as e:
             self.logger.error(f"Could not connect to ELB V1. Error: {e}")

@@ -589,13 +614,21 @@ def get_load_balancers_v2(self, **kwargs):
         label = self.format_label(SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region)

         try:
+            seeds = []
             data = client.describe_load_balancers()
             for elb in data.get("LoadBalancers", []):
                 if value := elb.get("DNSName"):
                     with SuppressValidationError():
-                        domain_seed = DomainSeed(value=value, label=label)
+                        # domain_seed = DomainSeed(value=value, label=label)
                         # self.add_seed(domain_seed, elb_res=elb, aws_client=client)
-                        self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client)
+                        # self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client)
+                        seed = self.process_seed(
+                            DomainSeed(value=value, label=label),
+                            elb_res=elb,
+                            aws_client=client,
+                        )
+                        seeds.append(seed)
+            self.submit_seed_payload(label, seeds)
         except ClientError as e:
             self.logger.error(f"Could not connect to ELB V2. 
Error: {e}") @@ -616,6 +649,7 @@ def get_network_interfaces(self, **kwargs): interfaces = self.describe_network_interfaces(ctx) instance_tags, instance_tag_sets = self.get_resource_tags(ctx) + seeds = [] for ip_address, record in interfaces.items(): instance_id = record["InstanceId"] tags = instance_tags.get(instance_id) @@ -626,9 +660,16 @@ def get_network_interfaces(self, **kwargs): continue with SuppressValidationError(): - ip_seed = IpSeed(value=ip_address, label=label) + # ip_seed = IpSeed(value=ip_address, label=label) # self.add_seed(ip_seed, tags=instance_tag_sets.get(instance_id)) - self.emit_seed(ctx, ip_seed, tags=instance_tag_sets.get(instance_id)) + # self.emit_seed(ctx, ip_seed, tags=instance_tag_sets.get(instance_id)) + seed = self.process_seed( + IpSeed(value=ip_address, label=label), + tags=instance_tag_sets.get(instance_id), + ) + seeds.append(seed) + + self.submit_seed_payload(label, seeds) def describe_network_interfaces(self, ctx: AwsScanContext) -> dict: """Retrieve EC2 Elastic Network Interfaces (ENI) data. @@ -755,6 +796,7 @@ def get_rds_instances(self, **kwargs): has_added_seeds = False try: + seeds = [] data = client.describe_db_instances() for instance in data.get("DBInstances", []): if not instance.get("PubliclyAccessible"): @@ -762,10 +804,16 @@ def get_rds_instances(self, **kwargs): if domain_name := instance.get("Endpoint", {}).get("Address"): with SuppressValidationError(): - domain_seed = DomainSeed(value=domain_name, label=label) + # domain_seed = DomainSeed(value=domain_name, label=label) # self.add_seed(domain_seed, rds_res=instance) + # self.emit_seed(ctx, domain_seed, rds_res=instance) has_added_seeds = True - self.emit_seed(ctx, domain_seed, rds_res=instance) + seed = self.process_seed( + DomainSeed(value=domain_name, label=label), rds_res=instance + ) + seeds.append(seed) + + self.submit_seed_payload(label, seeds) if not has_added_seeds: self.delete_seeds_by_label(label) except ClientError as e: @@ -814,7 +862,15 @@ def get_route53_zones(self, **kwargs): ) has_added_seeds = False + # TODO: potentially send seeds with empty values to remove "stale" seeds + + # Notice add_seed has extra keyword arguments - these were piped into add_seed for dispatch_event + # - add_seed cannot use self.seeds anymore because concurrency + # - add_seed dispatched_event PER seed, but seeds were later submitted in submit_seeds + # - for now i split dispatch_event, add_seed, and submit_seed_payloads into separate calls + try: + seeds = [] zones = self._get_route53_zone_hosts(client) for zone in zones.get("HostedZones", []): if zone.get("Config", {}).get("PrivateZone"): @@ -823,10 +879,22 @@ def get_route53_zones(self, **kwargs): # Add the zone itself as a seed domain_name = zone.get("Name").rstrip(".") with SuppressValidationError(): - domain_seed = DomainSeed(value=domain_name, label=label) + # domain_seed = DomainSeed(value=domain_name, label=label) has_added_seeds = True - # self.add_seed(domain_seed, route53_zone_res=zone, aws_client=client) - self.emit_seed(ctx, domain_seed, route53_zone_res=zone) + seed = self.process_seed( + DomainSeed(value=domain_name, label=label), + route53_zone_res=zone, + aws_client=client, + ) + seeds.append(seed) + # self.emit_seed(ctx, domain_seed, route53_zone_res=zone) + # self.dispatch_event( + # EventTypeEnum.SEED_FOUND, + # seed=domain_seed, + # route53_zone_res=zone, + # aws_client=client, + # ) + # seeds.append(domain_seed) id = zone.get("Id") resource_sets = self._get_route53_zone_resources(client, id) @@ -837,11 +905,33 @@ def 
get_route53_zones(self, **kwargs): domain_name = resource_set.get("Name").rstrip(".") with SuppressValidationError(): - domain_seed = DomainSeed(value=domain_name, label=label) - self.add_seed( - domain_seed, route53_zone_res=zone, aws_client=client + # domain_seed = DomainSeed(value=domain_name, label=label) + # + # TODO: label is for this entire loop, emitting per item will make more requests than necessary! + # seeds[seed.label].push(seed) + # self.add_seed(domain_seed, route53_zone_res=zone, aws_client=client) + # TODO: add_seed quadratic time - loops here, then loops to submit seed + # self.emit_seed( + # ctx, domain_seed, route53_zone_res=zone, aws_client=client + # ) + # + # self.dispatch_event( + # EventTypeEnum.SEED_FOUND, + # seed=domain_seed, + # route53_zone_res=zone, + # aws_client=client, + # ) + # seeds.append(domain_seed) + + seed = self.process_seed( + DomainSeed(value=domain_name, label=label), + route53_zone_res=zone, + aws_client=client, ) + seeds.append(seed) has_added_seeds = True + + self.submit_seed_payload(label, seeds) if not has_added_seeds: self.delete_seeds_by_label(label) except ClientError as e: @@ -858,6 +948,7 @@ def get_ecs_instances(self, **kwargs): has_added_seeds = False try: + seeds = [] clusters = ecs.list_clusters() for cluster in clusters.get("clusterArns", []): cluster_instances = ecs.list_container_instances(cluster=cluster) @@ -884,10 +975,22 @@ def get_ecs_instances(self, **kwargs): continue with SuppressValidationError(): - ip_seed = IpSeed(value=ip_address, label=label) + # ip_seed = IpSeed(value=ip_address, label=label) + # TODO: don't use add_seed + # instead, emit Payload + # modifying add seed would require managing account+region or use AwsScanContext which requires more time than available # self.add_seed(ip_seed, ecs_res=instance) + # self.emit_seed(ctx, ip_seed, ecs_res=instance) + # or maybe self.enqueue(seed) + # would be best to async queue these + # but we are in a pool already... + seed = self.process_seed( + IpSeed(value=ip_address, label=label), ecs_res=instance + ) + seeds.append(seed) has_added_seeds = True - self.emit_seed(ctx, ip_seed, ecs_res=instance) + + self.submit_seed_payload(label, seeds) if not has_added_seeds: self.delete_seeds_by_label(label) except ClientError as e: @@ -908,7 +1011,6 @@ def get_s3_region(self, client: S3Client, bucket: str) -> str: def get_s3_instances(self, **kwargs): """Retrieve Simple Storage Service data and emit seeds.""" - # TODO: how to pass in cred,region? key = kwargs["scan_context_key"] ctx: AwsScanContext = self.scan_contexts[key] @@ -917,30 +1019,57 @@ def get_s3_instances(self, **kwargs): try: data = client.list_buckets().get("Buckets", []) + # TODO: this should actually be a set of buckets, not a list (no dupes) + # findings = { 'uid1=AWS: 123456789012/us-east-1': [asset,...], 'uid2=AWS: 123456789012/us-west-1': [asset,...]} + findings: dict[str, list[AwsStorageBucketAsset]] = {} + for bucket in data: bucket_name = bucket.get("Name") if not bucket_name: continue + # TODO: figure out correct value for region + # if we use lookup_region, then the submit_cloud_asset_payload call will need to be adjusted + # it shouldn't be submitting a payload PER bucket; it should be payload per account + region lookup_region = self.get_s3_region(client, bucket_name) label = self.format_label( SeedLabel.STORAGE_BUCKET, ctx.account_number, - ctx.region, + # oh this is interesting.... lookup_region OR ctx.region.. which one? 
+ # pretty sure it's lookup_region, otherwise whats the point of looking up the bucket's region? + lookup_region, + # ctx.region, ) + # TODO: this isnt right + # assets = [] + with SuppressValidationError(): - bucket_asset = AwsStorageBucketAsset( + asset = AwsStorageBucketAsset( value=AwsStorageBucketAsset.url(bucket_name, lookup_region), uid=label, scan_data={ "accountNumber": ctx.account_number, }, ) - # self.add_cloud_asset(bucket_asset, bucket_name=bucket_name, aws_client=client) - self.emit_cloud_asset( - ctx, bucket_asset, bucket_name=bucket_name, aws_client=client + # self.add_cloud_asset(asset, bucket_name=bucket_name, aws_client=client) + # self.emit_cloud_asset( + # ctx, asset, bucket_name=bucket_name, aws_client=client + # ) + asset = self.process_cloud_asset( + asset, bucket_name=bucket_name, aws_client=client ) + # assets.append(asset) + if label not in findings: + findings[label] = [] + findings[label].append(asset) + + # TODO convert this to findings below + # self.submit_cloud_asset_payload(label, assets) + + # TODO: submit findings map here + for label, assets in findings.items(): + self.submit_cloud_asset_payload(label, assets) except ClientError as e: self.logger.error(f"Could not connect to S3. Error: {e}") diff --git a/src/censys/cloud_connectors/common/aurora.py b/src/censys/cloud_connectors/common/aurora.py index b275622..8d893eb 100644 --- a/src/censys/cloud_connectors/common/aurora.py +++ b/src/censys/cloud_connectors/common/aurora.py @@ -1,53 +1,41 @@ -"""Interact with Aurora API.""" +"""Aurora API client.""" from cloudevents.conversion import to_structured +from cloudevents.http import CloudEvent from censys.asm import Seeds class Aurora(Seeds): - """Aurora API class.""" + """Aurora API client.""" - # TODO base_path = "/api/v1/payload/enqueue" - base_path = "/api/payload/enqueue" + base_path = "/api" - def emit(self, payload: dict) -> None: - """Emit a payload to ASM. + # TODO @backoff_wrapper + def enqueue_payload(self, payload: CloudEvent) -> None: + """Enqueue a payload for later processing. Args: - payload (dict): Payload to emit. + payload (CloudEvent): Payload. """ - # TODO: contact aurora - # https://github.com/cloudevents/sdk-python#structured-http-cloudevent headers, body = to_structured(payload) - # requests.post("", data=body, headers=headers) - - # data = {"payload": payload} - # censys python is forcing json encoding, whereas we want to send a binary cloud event - # data_workaround = {"body": payload} - # return self._post( - # self.base_path, data=body, headers=headers - # ) # , **data_workaround) - request_kwargs = {"timeout": self.timeout, "data": body, "headers": headers} - # TODO @backoff_wrapper - url = f"{self._api_url}{self.base_path}" + # url = f"{self._api_url}{self.base_path}" + url = f"{self._api_url}{self.base_path}/payload/enqueue" resp = self._call_method(self._session.post, url, request_kwargs) - # TODO: handle response - print(f"resp: {resp}") - def emit_batch(self, payloads) -> None: - """Emit a payload to ASM. - - Args: - payload (dict): Payload to emit. 
- """ - - # TODO: contact aurora - # https://github.com/cloudevents/sdk-python#structured-http-cloudevent - # headers, body = to_structured(event) - # requests.post("", data=body, headers=headers) - - data = {"payloads": payloads} - return self._post(self.base_path, data=data) + # TODO: handle response + # TODO: read enqueue response `event ID` (for status tracking) + print(f"TODO resp: {resp}") + + if resp.ok: + try: + json_data = resp.json() + # if "error" not in json_data: + # return json_data + return json_data + except ValueError: + return {"code": resp.status_code, "status": resp.reason} + + return {} diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py index ce2b7c5..7b27573 100644 --- a/src/censys/cloud_connectors/common/connector.py +++ b/src/censys/cloud_connectors/common/connector.py @@ -171,60 +171,6 @@ def dispatch_event( context = self.get_event_context(event_type, service) CloudConnectorPluginRegistry.dispatch_event(context=context, **kwargs) - def emit(self, payload: dict): - """Send a payload to Aurora - - Args: - payload (dict): Payload. - """ - self.logger.debug(f"Sending payload: {payload}") - - # TODO: emit in batches!!!! - self.aurora_api.emit(payload) - - def emit_seed(self, ctx, seed: Seed, **kwargs): - """Emit a seed payload. - - Args: - seed (Seed): The seed to emit. - **kwargs: Additional data for event dispatching. - """ - # self.logger.debug(f"Found Seed: {seed.to_dict()}") - self.dispatch_event(EventTypeEnum.SEED_FOUND, seed=seed, **kwargs) - - attributes = { - "type": "com.censys.cloud-connector.seed", - "source": "cc-user-agent", - } - data = { - "seed": seed.to_dict(), - } - payload = CloudEvent(attributes, data) - self.emit(payload) - - def emit_cloud_asset(self, ctx, cloud_asset: CloudAsset, **kwargs): - """Emit a cloud asset payload. - - Args: - cloud_asset (CloudAsset): The cloud asset to emit. - **kwargs: Additional data for event dispatching. - """ - # self.logger.debug(f"Found Cloud Asset: {cloud_asset.to_dict()}") - self.dispatch_event( - EventTypeEnum.CLOUD_ASSET_FOUND, cloud_asset=cloud_asset, **kwargs - ) - - attributes = { - "type": "com.censys.cloud-connector.cloud-asset", - "source": "cc-user-agent", - } - data = { - "asset": cloud_asset.to_dict(), - } - payload = CloudEvent(attributes, data) - - self.emit(payload) - def add_seed(self, seed: Seed, **kwargs): """Add a seed. @@ -232,6 +178,7 @@ def add_seed(self, seed: Seed, **kwargs): seed (Seed): The seed to add. **kwargs: Additional data for event dispatching. """ + # TODO: not compatible with multiprocessing if not seed.label.startswith(self.label_prefix): seed.label = self.label_prefix + seed.label self.seeds[seed.label].add(seed) @@ -245,6 +192,7 @@ def add_cloud_asset(self, cloud_asset: CloudAsset, **kwargs): cloud_asset (CloudAsset): The cloud asset to add. **kwargs: Additional data for event dispatching. 
""" + # TODO: not compatible with multiprocessing if not cloud_asset.uid.startswith(self.label_prefix): cloud_asset.uid = self.label_prefix + cloud_asset.uid self.cloud_assets[cloud_asset.uid].add(cloud_asset) @@ -255,6 +203,7 @@ def add_cloud_asset(self, cloud_asset: CloudAsset, **kwargs): def submit_seeds(self): """Submit the seeds to Censys ASM.""" + # TODO: not compatible with multiprocessing submitted_seeds = 0 for label, seeds in self.seeds.items(): try: @@ -269,6 +218,7 @@ def submit_seeds(self): def submit_cloud_assets(self): """Submit the cloud assets to Censys ASM.""" + # TODO: not compatible with multiprocessing submitted_assets = 0 for uid, cloud_assets in self.cloud_assets.items(): try: @@ -283,12 +233,87 @@ def submit_cloud_assets(self): EventTypeEnum.CLOUD_ASSETS_SUBMITTED, count=submitted_assets ) + def process_seed(self, seed: Seed, **kwargs) -> Seed: + """Prepare a seed for submission. Also dispatch events. + + Args: + seed (Seed): Seed. + + Returns: + Seed: Processed seed. + """ + if not seed.label.startswith(self.label_prefix): + seed.label = self.label_prefix + seed.label + + self.logger.debug(f"Found Seed: {seed.to_dict()}") + self.dispatch_event(EventTypeEnum.SEED_FOUND, seed=seed, **kwargs) + return seed + + def process_cloud_asset(self, cloud_asset: CloudAsset, **kwargs) -> CloudAsset: + """Prepare a cloud asset for submission. + + Args: + cloud_asset (CloudAsset): The cloud asset to add. + **kwargs: Additional data for event dispatching. + """ + if not cloud_asset.uid.startswith(self.label_prefix): + cloud_asset.uid = self.label_prefix + cloud_asset.uid + + self.logger.debug(f"Found Cloud Asset: {cloud_asset.to_dict()}") + self.dispatch_event( + EventTypeEnum.CLOUD_ASSET_FOUND, cloud_asset=cloud_asset, **kwargs + ) + return cloud_asset + + def submit_seed_payload(self, label: str, seeds: list[Seeds]): + """Submit a seed payload. + + Args: + label (str): Label for the seeds. + seeds (list[Seeds]): List of seeds. + """ + # seed = DomainSeed(type='DOMAIN_NAME', value='example-2.com', label='AWS: Route53/Zones - 001111111112/us-west-1') + + # TODO: constants for attributes + attributes = { + "type": "com.censys.cloud-connector.seed", + "source": "cc-user-agent", + } + data = { + "label": label, + "seeds": [seed.to_dict() for seed in seeds], + } + payload = CloudEvent(attributes, data) + result = self.aurora_api.enqueue_payload(payload) + self.logger.debug(f"submit seed payload {payload}") + # TODO handle result + print(f"result {result}") + + def submit_cloud_asset_payload(self, uid: str, cloud_assets: list[CloudAsset]): + """Submit a cloud asset payload. + + Args: + uid (str): Unique identifier for the cloud asset. + cloud_assets (list[CloudAsset]): List of cloud assets. + """ + # TODO: constants for attributes + attributes = { + "type": "com.censys.cloud-connector.cloud-asset", + "source": "cc-user-agent", + } + data = { + "uid": uid, + "assets": [asset.to_dict() for asset in cloud_assets], + } + payload = CloudEvent(attributes, data) + result = self.aurora_api.enqueue_payload(payload) + self.logger.debug(f"submit asset payload {payload}") + # TODO handle result + print(f"result {result}") + def clear(self): """Clear the seeds and cloud assets.""" - # TODO: it's possible this clobbers seeds & cloudassets from other process' account + regions - - # hm.. do seeds and cloud-assets even work in multiprocessing? - # should they be contained in scanner context? 
+ # TODO: not compatible with multiprocessing self.logger.debug(f"Clearing {len(self.seeds)} seeds") self.seeds.clear() diff --git a/src/censys/cloud_connectors/common/logger.py b/src/censys/cloud_connectors/common/logger.py index fc77d04..2999c92 100644 --- a/src/censys/cloud_connectors/common/logger.py +++ b/src/censys/cloud_connectors/common/logger.py @@ -4,7 +4,10 @@ def get_logger( - log_name: Optional[str] = "cloud_connector", level: str = "INFO", **kwargs + log_name: Optional[str] = "cloud_connector", + level: str = "INFO", + provider: str = "", + **kwargs, ) -> logging.Logger: """Returns a custom logger. @@ -20,24 +23,23 @@ def get_logger( if not logger.hasHandlers(): formatter = logging.Formatter( - fmt="%(asctime)s:%(levelname)s:%(name)s: %(message)s" - # fmt="%(asctime)s:%(levelname)s:%(name)s:%(provider)s: %(message)s" + # fmt="%(asctime)s:%(levelname)s:%(name)s: %(message)s" + fmt="%(asctime)s:%(levelname)s:%(name)s:%(provider)s: %(message)s" ) handler = logging.StreamHandler() handler.setFormatter(formatter) logger.addHandler(handler) - # TODO - add provider (AWS=account+region, GCP=org+project, AZURE=subid) to log record - # - # https://stackoverflow.com/a/57820456/19351735 - # old_factory = logging.getLogRecordFactory() - # def record_factory(*args, **kwargs): - # record = old_factory(*args, **kwargs) - # provider = kwargs.get("provider", "") - # record.provider = provider - # return record - - # logging.setLogRecordFactory(record_factory) + logging.setLogRecordFactory( + lambda *args, **kwargs: CustomLogRecord(*args, provider=provider, **kwargs) + ) logger.setLevel(level) return logger + + +# TODO: see if there is an easier way to do this (dont like the provider arg) +class CustomLogRecord(logging.LogRecord): + def __init__(self, *args, provider: str = "", **kwargs): + super().__init__(*args, **kwargs) + self.provider = provider From 711415642f8c564e8b04fa1ccc5383d9649eb63c Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Tue, 22 Aug 2023 16:07:08 -0400 Subject: [PATCH 08/13] WIP: add payload source --- .../aws_connector/connector.py | 1 - src/censys/cloud_connectors/common/aurora.py | 12 +-- .../cloud_connectors/common/connector.py | 75 +++++++++++++------ src/censys/cloud_connectors/common/enums.py | 7 ++ .../cloud_connectors/common/settings.py | 5 ++ 5 files changed, 65 insertions(+), 35 deletions(-) diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index b9d7a25..04b6b89 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -464,7 +464,6 @@ def assume_role( self, ctx: AwsScanContext, # account_number: str, - # region: str, role_name: Optional[str] = None, role_session_name: Optional[str] = None, ) -> CredentialsTypeDef: diff --git a/src/censys/cloud_connectors/common/aurora.py b/src/censys/cloud_connectors/common/aurora.py index 8d893eb..3cf8d33 100644 --- a/src/censys/cloud_connectors/common/aurora.py +++ b/src/censys/cloud_connectors/common/aurora.py @@ -10,32 +10,24 @@ class Aurora(Seeds): base_path = "/api" - # TODO @backoff_wrapper + # TODO @_backoff_wrapper def enqueue_payload(self, payload: CloudEvent) -> None: """Enqueue a payload for later processing. Args: payload (CloudEvent): Payload. 
""" - headers, body = to_structured(payload) request_kwargs = {"timeout": self.timeout, "data": body, "headers": headers} - # url = f"{self._api_url}{self.base_path}" url = f"{self._api_url}{self.base_path}/payload/enqueue" resp = self._call_method(self._session.post, url, request_kwargs) - # TODO: handle response - # TODO: read enqueue response `event ID` (for status tracking) - print(f"TODO resp: {resp}") - if resp.ok: try: json_data = resp.json() - # if "error" not in json_data: - # return json_data return json_data except ValueError: return {"code": resp.status_code, "status": resp.reason} - return {} + raise Exception(f"Invalid response: {resp.text}") diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py index 7b27573..27b5803 100644 --- a/src/censys/cloud_connectors/common/connector.py +++ b/src/censys/cloud_connectors/common/connector.py @@ -14,7 +14,7 @@ from censys.cloud_connectors.common.aurora import Aurora from .cloud_asset import CloudAsset -from .enums import EventTypeEnum, ProviderEnum +from .enums import EventTypeEnum, PayloadTypes, ProviderEnum from .logger import get_logger from .plugins import CloudConnectorPluginRegistry, EventContext from .seed import Seed @@ -265,29 +265,62 @@ def process_cloud_asset(self, cloud_asset: CloudAsset, **kwargs) -> CloudAsset: ) return cloud_asset - def submit_seed_payload(self, label: str, seeds: list[Seeds]): + def get_payload_source(self): + """Generate the CloudEvent source value. + + Returns: + str: The CloudEvent source value. + """ + # see: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#source-1 + return f"https://github.com/censys/censys-cloud-connector/releases/tag/v{self.settings.cloud_connector_version}" + + def payload(self, payload_type: PayloadTypes, data: dict) -> CloudEvent: + """Generate a CloudEvent payload. + + Args: + type (PayloadTypes): The CloudEvent type. + data (dict): Payload data. + + Returns: + CloudEvent: The CloudEvent payload. + """ + attributes = { + "type": payload_type.value, + "source": self.get_payload_source(), + } + return CloudEvent(attributes, data) + + def enqueue_payload(self, payload: CloudEvent) -> str: + """Enqueue a CloudEvent payload. + + Args: + payload (CloudEvent): The CloudEvent payload. + + Returns: + str: Event ID. + """ + result = self.aurora_api.enqueue_payload(payload) + event_id = result.get("eventId", "ERROR") + return event_id + + def submit_seed_payload(self, label: str, seeds: list[Seeds]) -> str: """Submit a seed payload. Args: label (str): Label for the seeds. seeds (list[Seeds]): List of seeds. - """ - # seed = DomainSeed(type='DOMAIN_NAME', value='example-2.com', label='AWS: Route53/Zones - 001111111112/us-west-1') - # TODO: constants for attributes - attributes = { - "type": "com.censys.cloud-connector.seed", - "source": "cc-user-agent", - } + Returns: + str: Event ID. + """ data = { "label": label, "seeds": [seed.to_dict() for seed in seeds], } - payload = CloudEvent(attributes, data) - result = self.aurora_api.enqueue_payload(payload) - self.logger.debug(f"submit seed payload {payload}") - # TODO handle result - print(f"result {result}") + payload = self.payload(PayloadTypes.PAYLOAD_SEED, data) + event_id = self.enqueue_payload(payload) + self.logger.debug(f"seed payload {payload} event_id:{event_id}") + return event_id def submit_cloud_asset_payload(self, uid: str, cloud_assets: list[CloudAsset]): """Submit a cloud asset payload. 
@@ -296,20 +329,14 @@ def submit_cloud_asset_payload(self, uid: str, cloud_assets: list[CloudAsset]): uid (str): Unique identifier for the cloud asset. cloud_assets (list[CloudAsset]): List of cloud assets. """ - # TODO: constants for attributes - attributes = { - "type": "com.censys.cloud-connector.cloud-asset", - "source": "cc-user-agent", - } data = { "uid": uid, "assets": [asset.to_dict() for asset in cloud_assets], } - payload = CloudEvent(attributes, data) - result = self.aurora_api.enqueue_payload(payload) - self.logger.debug(f"submit asset payload {payload}") - # TODO handle result - print(f"result {result}") + payload = self.payload(PayloadTypes.PAYLOAD_CLOUD_ASSET, data) + event_id = self.enqueue_payload(payload) + self.logger.debug(f"cloud asset payload {payload} event_id:{event_id}") + return event_id def clear(self): """Clear the seeds and cloud assets.""" diff --git a/src/censys/cloud_connectors/common/enums.py b/src/censys/cloud_connectors/common/enums.py index f98bdc8..7a9ad4a 100644 --- a/src/censys/cloud_connectors/common/enums.py +++ b/src/censys/cloud_connectors/common/enums.py @@ -62,3 +62,10 @@ class EventTypeEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta): SEEDS_SUBMITTED = "SEEDS_SUBMITTED" CLOUD_ASSETS_SUBMITTED = "CLOUD_ASSETS_SUBMITTED" SEEDS_DELETED = "SEEDS_DELETED" + + +class PayloadTypes(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Payload types supported by Censys.""" + + PAYLOAD_SEED = "com.censys.cloud-connector.seed" + PAYLOAD_CLOUD_ASSET = "com.censys.cloud-connector.cloud-asset" diff --git a/src/censys/cloud_connectors/common/settings.py b/src/censys/cloud_connectors/common/settings.py index 94f3a69..4e901bd 100644 --- a/src/censys/cloud_connectors/common/settings.py +++ b/src/censys/cloud_connectors/common/settings.py @@ -157,6 +157,11 @@ class Settings(BaseSettings): censys_cookies: dict = Field( default={}, env="CENSYS_COOKIES", description="Censys Cookies" ) + cloud_connector_version: str = Field( + default=censys_cloud_connectors_version, + env="CLOUD_CONNECTOR_VERSION", + description="Cloud Connector Version", + ) # Optional providers_config_file: str = Field( From 60604fe000d8defd6b225fb08be7d9b682ad50b5 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Mon, 2 Oct 2023 13:53:10 -0400 Subject: [PATCH 09/13] fix: issues from rebase --- .../aws_connector/connector.py | 65 ++++++++++++++----- 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index 04b6b89..ea2644b 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -156,7 +156,7 @@ def scan_seeds(self, **kwargs): ) self.scan_contexts[scan_context_key] = scan_context - # TODO: self.dispatch_event(EventTypeEnum.SCAN_STARTED) + # TODO: self.dispatch_event(EventTypeEnum.SCAN_STARTED) <-- make sure events are working (and not broken by the worker pool change) super().scan_seeds(**kwargs) # TODO: self.dispatch_event(EventTypeEnum.SCAN_FINISHED) @@ -190,7 +190,7 @@ def scan_all(self): self.scan_contexts = {} # for credential in self.provider_settings.get_credentials(): - for credential in provider_settings.get_credentials(): + for credential in provider_setting.get_credentials(): # self.credential = credential # self.account_number = credential["account_number"] account_number = credential["account_number"] @@ -198,13 +198,13 @@ def scan_all(self): # self.ignored_tags = ignored_tags # for each 
account + region combination, run each seed scanner - for region in provider_settings.regions: + for region in provider_setting.regions: # self.temp_sts_cred = None # self.region = region try: with Healthcheck( self.settings, - provider_settings, + provider_setting, provider={ "region": region, "account_number": account_number, @@ -212,11 +212,10 @@ def scan_all(self): }, ): self.logger.debug( - "starting pool account:%s region:%s", + "starting seed pool account:%s region:%s", account_number, region, ) - # Credentials aren't exactly obvious how they work # Assume role flow: # - use the "primary account" access + secret key to "connect" @@ -225,7 +224,7 @@ def scan_all(self): # - note: creds expire after N time (hours?) scan_context_key = f"{account_number}_{region}" scan_context = AwsScanContext( - provider_settings=provider_settings, + provider_settings=provider_setting, credential=credential, account_number=account_number, region=region, @@ -237,7 +236,6 @@ def scan_all(self): # scan workflow: # - get seeds + cloud-assets + tags-plugin # - submit seeds + cloud-assets - print(f"scan_all logging level {self.logger.level}") pool.apply_async( self.scan_seeds, kwds={ @@ -252,7 +250,7 @@ def scan_all(self): # self.scan(**kwargs) except Exception as e: self.logger.error( - f"Unable to scan account {credential['account_number']} in region {region}. Error: {e}" + f"Unable to scan account {account_number} in region {region}. Error: {e}" ) self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e) # self.region = None @@ -264,13 +262,34 @@ def scan_all(self): with Healthcheck( self.settings, provider_setting, - provider={"account_number": self.account_number}, + provider={"account_number": account_number}, ): + self.logger.debug( + "starting cloud-asset pool account:%s region:%s", + account_number, + region, + ) + # Credentials aren't exactly obvious how they work + # Assume role flow: + # - use the "primary account" access + secret key to "connect" + # - call STS assume role to get "temporary credentials" + # - temporary credentials can be used for all resource types in an account + region + # - note: creds expire after N time (hours?) 
+ scan_context_key = f"{account_number}_{region}" + scan_context = AwsScanContext( + provider_settings=provider_setting, + credential=credential, + account_number=account_number, + region=region, + ignored_tags=ignored_tags, + logger=None, + # temp_sts_cred=None, + ) pool.apply_async( self.scan_cloud_assets, kwds={ - "credential": credential, - "region": region, + "scan_context_key": scan_context_key, + "scan_context": scan_context, }, ) except Exception as e: @@ -514,7 +533,7 @@ def get_api_gateway_domains_v1(self, **kwargs): key = kwargs["scan_context_key"] ctx: AwsScanContext = self.scan_contexts[key] - client: APIGatewayClient = self.get_aws_client(service=AwsServices.API_GATEWAY, ctx) + client: APIGatewayClient = self.get_aws_client(AwsServices.API_GATEWAY, ctx) label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region) try: @@ -569,6 +588,8 @@ def get_api_gateway_domains(self, **kwargs): """Retrieve all versions of Api Gateway data and emit seeds.""" self.get_api_gateway_domains_v1(**kwargs) self.get_api_gateway_domains_v2(**kwargs) + + # TODO: this won't work anymore since self.seeds is no longer used label = self.format_label(SeedLabel.API_GATEWAY) if not self.seeds.get(label): self.delete_seeds_by_label(label) @@ -579,9 +600,11 @@ def get_load_balancers_v1(self, **kwargs): ctx: AwsScanContext = self.scan_contexts[key] client: ElasticLoadBalancingClient = self.get_aws_client( - service=AwsServices.LOAD_BALANCER, ctx + AwsServices.LOAD_BALANCER, ctx + ) + label = self.format_label( + SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region ) - label = self.format_label(SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region) try: seeds = [] @@ -610,7 +633,9 @@ def get_load_balancers_v2(self, **kwargs): client: ElasticLoadBalancingv2Client = self.get_aws_client( AwsServices.LOAD_BALANCER_V2, ctx ) - label = self.format_label(SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region) + label = self.format_label( + SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region + ) try: seeds = [] @@ -635,6 +660,8 @@ def get_load_balancers(self, **kwargs): """Retrieve Elastic Load Balancers (ELB) data and emit seeds.""" self.get_load_balancers_v1(**kwargs) self.get_load_balancers_v2(**kwargs) + + # TODO: this won't work anymore since self.seeds is no longer used label = self.format_label(SeedLabel.LOAD_BALANCER) if not self.seeds.get(label): self.delete_seeds_by_label(label) @@ -644,7 +671,9 @@ def get_network_interfaces(self, **kwargs): key = kwargs["scan_context_key"] ctx: AwsScanContext = self.scan_contexts[key] - label = self.format_label(SeedLabel.NETWORK_INTERFACE, ctx.account_number, ctx.region) + label = self.format_label( + SeedLabel.NETWORK_INTERFACE, ctx.account_number, ctx.region + ) interfaces = self.describe_network_interfaces(ctx) instance_tags, instance_tag_sets = self.get_resource_tags(ctx) @@ -790,7 +819,7 @@ def get_rds_instances(self, **kwargs): key = kwargs["scan_context_key"] ctx: AwsScanContext = self.scan_contexts[key] - client: RDSClient = self.get_aws_client(service=AwsServices.RDS, ctx) + client: RDSClient = self.get_aws_client(AwsServices.RDS, ctx) label = self.format_label(SeedLabel.RDS, ctx.account_number, ctx.region) has_added_seeds = False From 0dd13f0180ecd730173960b1063a3077e3345282 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Tue, 3 Oct 2023 10:01:34 -0400 Subject: [PATCH 10/13] WIP: - remove unused comment code - add temp_sts_credential to ctx --- .../aws_connector/connector.py | 203 ++++-------------- 1 file changed, 46 insertions(+), 
157 deletions(-) diff --git a/src/censys/cloud_connectors/aws_connector/connector.py b/src/censys/cloud_connectors/aws_connector/connector.py index ea2644b..7b6aa52 100644 --- a/src/censys/cloud_connectors/aws_connector/connector.py +++ b/src/censys/cloud_connectors/aws_connector/connector.py @@ -48,13 +48,10 @@ VALID_RECORD_TYPES = ["A", "CNAME"] IGNORED_TAGS = ["censys-cloud-connector-ignore"] +# TODO: potentially send seeds with empty values to remove "stale" seeds # TODO: fix self.{property} references: # This has to happen because if the worker pool spawns multiple account + regions, each worker will change the self.{property} value, thus making each process scan the SAME account. # -# instead of changing everything everywhere, perhaps a data structure can handle this? -# make a dictionary of provider-setting-key (which is account + region) -# then inside scan use self.scan_contexts[provider-setting-key] = {...} - # TODO: logging changes: # add account + region, current resource-type @@ -64,7 +61,7 @@ class AwsScanContext: """Required configuration context for scan().""" provider_settings: AwsSpecificSettings - # temp_sts_cred: Optional[dict] + temp_sts_cred: Optional[dict] credential: dict account_number: str region: str @@ -84,29 +81,6 @@ class AwsCloudConnector(CloudConnector): provider = ProviderEnum.AWS - # workaround for storing multiple configurations during a scan() call - # multiprocessing dictates that each worker runs scan in a different process - # each process will share the same AwsCloudConnector instance - # if a worker sets a self property, that is updated for _all_ workers - # therefore, make a dict that each worker can reference it's unique account+region configuration - # - # each scan_contexts entry will have a unique key so that multiple accounts and regions can be scanned in parallel - # scan_config_entry = { - # "temp_sts_cred": {}, "account_number": "", "region": "", "ignored_tags":[], credential: {} - # } - scan_contexts: dict[str, AwsScanContext] = {} - - # Temporary STS credentials created with Assume Role will be stored here during - # a connector scan. - # temp_sts_cred: Optional[dict] = None - # During a run this will be set to the current account being scanned - # It is common to have multiple top level accounts in providers.yml - # provider_settings: AwsSpecificSettings - # credential: dict = {} - # account_number: str - # ignored_tags: list[str] - # pool: Pool - global_ignored_tags: set[ str ] # this can remain self. 
as it is global across all accounts @@ -130,17 +104,10 @@ def __init__(self, settings: Settings): AwsResourceTypes.STORAGE_BUCKET: self.get_s3_instances, } self.global_ignored_tags: set[str] = set(IGNORED_TAGS) - # self.ignored_tags = [] - # self.pool = Pool(processes=settings.scan_concurrency) def scan_seeds(self, **kwargs): """Scan AWS.""" - # when scan() is called, it has been forked into a separate process (from scan_all) - - scan_context_key = kwargs["scan_context_key"] scan_context: AwsScanContext = kwargs["scan_context"] - # scan_context must be set within scan(), otherwise race conditions happen where it doesn't exist when accessed - self.scan_contexts[scan_context_key] = scan_context # multiprocessing requires separate logger instances per process # TODO: there is still something odd going on here, notice logger isnt being set to self.logger, but by even calling get_logger it "fixes" the sub-process log level from being WARNING to .env's DEBUG @@ -155,18 +122,24 @@ def scan_seeds(self, **kwargs): f"Scanning AWS - account:{scan_context.account_number} region:{scan_context.region}" ) - self.scan_contexts[scan_context_key] = scan_context # TODO: self.dispatch_event(EventTypeEnum.SCAN_STARTED) <-- make sure events are working (and not broken by the worker pool change) super().scan_seeds(**kwargs) # TODO: self.dispatch_event(EventTypeEnum.SCAN_FINISHED) def scan_cloud_assets(self, **kwargs): """Scan AWS for cloud assets.""" - # TODO: pull scan_seeds changes in after rebase - scan_context_key = kwargs["scan_context_key"] + # TODO: pull self.scan_seeds changes in after rebase + scan_context = kwargs["scan_context"] - self.scan_contexts[scan_context_key] = scan_context - self.logger.info(f"Scanning AWS account {scan_context.account_number}") + + logger = get_logger( + log_name=f"{self.provider.lower()}_cloud_connector", + level=self.settings.logging_level, + provider=f"{self.provider}_{scan_context.account_number}", + ) + scan_context.logger = logger + + self.logger.info(f"Scanning AWS - account:{scan_context.account_number}") super().scan_cloud_assets(**kwargs) def scan_all(self): @@ -187,20 +160,16 @@ def scan_all(self): # DO NOT use provider_settings anywhere in this class! # provider_settings exists for the parent CloudConnector self.provider_settings = provider_setting - self.scan_contexts = {} # for credential in self.provider_settings.get_credentials(): for credential in provider_setting.get_credentials(): - # self.credential = credential - # self.account_number = credential["account_number"] account_number = credential["account_number"] ignored_tags = self.get_ignored_tags(credential["ignore_tags"]) - # self.ignored_tags = ignored_tags # for each account + region combination, run each seed scanner for region in provider_setting.regions: + # TODO: verify sts still works and caches # self.temp_sts_cred = None - # self.region = region try: with Healthcheck( self.settings, @@ -208,21 +177,14 @@ def scan_all(self): provider={ "region": region, "account_number": account_number, - # self.account_number, }, ): self.logger.debug( - "starting seed pool account:%s region:%s", + "starting pool account:%s region:%s", account_number, region, ) - # Credentials aren't exactly obvious how they work - # Assume role flow: - # - use the "primary account" access + secret key to "connect" - # - call STS assume role to get "temporary credentials" - # - temporary credentials can be used for all resource types in an account + region - # - note: creds expire after N time (hours?) 
- scan_context_key = f"{account_number}_{region}" + scan_context = AwsScanContext( provider_settings=provider_setting, credential=credential, @@ -230,52 +192,30 @@ def scan_all(self): region=region, ignored_tags=ignored_tags, logger=None, - # temp_sts_cred=None, + temp_sts_cred=None, ) - # scan workflow: - # - get seeds + cloud-assets + tags-plugin - # - submit seeds + cloud-assets pool.apply_async( self.scan_seeds, kwds={ - # TODO remove all of this except `scan_context_key` - # "provider_setting": provider_setting, - # "credential": credential, - # "region": region, - "scan_context_key": scan_context_key, "scan_context": scan_context, }, ) - # self.scan(**kwargs) except Exception as e: self.logger.error( f"Unable to scan account {account_number} in region {region}. Error: {e}" ) self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e) - # self.region = None + # TODO: find a way to combine parts of this with seeds above (scan context) # for each account, run each cloud asset scanner try: - # self.temp_sts_cred = None - # self.region = None + # TODO: verify sts still works and caches with Healthcheck( self.settings, provider_setting, provider={"account_number": account_number}, ): - self.logger.debug( - "starting cloud-asset pool account:%s region:%s", - account_number, - region, - ) - # Credentials aren't exactly obvious how they work - # Assume role flow: - # - use the "primary account" access + secret key to "connect" - # - call STS assume role to get "temporary credentials" - # - temporary credentials can be used for all resource types in an account + region - # - note: creds expire after N time (hours?) - scan_context_key = f"{account_number}_{region}" scan_context = AwsScanContext( provider_settings=provider_setting, credential=credential, @@ -283,18 +223,17 @@ def scan_all(self): region=region, ignored_tags=ignored_tags, logger=None, - # temp_sts_cred=None, + temp_sts_cred=None, ) pool.apply_async( self.scan_cloud_assets, kwds={ - "scan_context_key": scan_context_key, "scan_context": scan_context, }, ) except Exception as e: self.logger.error( - f"Unable to scan account {self.account_number}. Error: {e}" + f"Unable to scan account {account_number}. Error: {e}" ) self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e) @@ -337,7 +276,6 @@ def credentials(self, ctx: AwsScanContext) -> dict: # Once activated the temporary STS creds will be used by all # subsequent AWS service client calls. 
- # TODO: fix self.credential if role_name := ctx.credential.get("role_name"): self.logger.debug(f"Using STS for role {role_name}") return self.get_assume_role_credentials(ctx) # (account_number, role_name) @@ -417,6 +355,10 @@ def get_assume_role_credentials( # if self.temp_sts_cred: # self.logger.debug("Using cached temporary STS credentials") # else: + if ctx.temp_sts_cred: + self.logger.debug("Using cached temporary STS credential") + return ctx.temp_sts_cred + try: temp_creds = self.assume_role( # account_number, region, role_name, role_session_name @@ -424,10 +366,11 @@ def get_assume_role_credentials( role_name=role_name, role_session_name=role_session_name, ) + self.logger.debug(f"Created temporary STS credentials for role {role_name}") + # TODO: fix self.temp_sts_cred # self.temp_sts_cred = self.boto_cred( - self.logger.debug(f"Created temporary STS credentials for role {role_name}") - return self.boto_cred( + ctx.temp_sts_cred = self.boto_cred( ctx.region, temp_creds["AccessKeyId"], temp_creds["SecretAccessKey"], @@ -436,8 +379,8 @@ def get_assume_role_credentials( except Exception as e: self.logger.error(f"Failed to assume role: {e}") raise - # TODO: fix self.temp_sts_cred - # return self.temp_sts_cred + + return ctx.temp_sts_cred def boto_cred( self, @@ -530,8 +473,7 @@ def assume_role( def get_api_gateway_domains_v1(self, **kwargs): """Retrieve all API Gateway V1 domains and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: APIGatewayClient = self.get_aws_client(AwsServices.API_GATEWAY, ctx) label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region) @@ -544,7 +486,6 @@ def get_api_gateway_domains_v1(self, **kwargs): with SuppressValidationError(): # domain_seed = DomainSeed(value=domain_name, label=label) # self.add_seed(domain_seed, api_gateway_res=domain) - # self.emit_seed(ctx, domain_seed, api_gateway_res=domain) seed = self.process_seed( DomainSeed( value=domain_name, label=label, api_gateway_res=domain @@ -557,8 +498,7 @@ def get_api_gateway_domains_v1(self, **kwargs): def get_api_gateway_domains_v2(self, **kwargs): """Retrieve API Gateway V2 domains and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: ApiGatewayV2Client = self.get_aws_client( AwsServices.API_GATEWAY_V2, ctx @@ -573,7 +513,6 @@ def get_api_gateway_domains_v2(self, **kwargs): with SuppressValidationError(): # domain_seed = DomainSeed(value=domain_name, label=label) # self.add_seed(domain_seed, api_gateway_res=domain) - # self.emit_seed(ctx, domain_seed, api_gateway_res=domain) seed = self.process_seed( DomainSeed( value=domain_name, label=label, api_gateway_res=domain @@ -590,14 +529,14 @@ def get_api_gateway_domains(self, **kwargs): self.get_api_gateway_domains_v2(**kwargs) # TODO: this won't work anymore since self.seeds is no longer used - label = self.format_label(SeedLabel.API_GATEWAY) + ctx: AwsScanContext = kwargs["scan_context"] + label = self.format_label(SeedLabel.API_GATEWAY, ctx.account_number, ctx.region) if not self.seeds.get(label): self.delete_seeds_by_label(label) def get_load_balancers_v1(self, **kwargs): """Retrieve Elastic Load Balancers (ELB) V1 data and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: ElasticLoadBalancingClient = self.get_aws_client( 
AwsServices.LOAD_BALANCER, ctx @@ -614,7 +553,6 @@ def get_load_balancers_v1(self, **kwargs): with SuppressValidationError(): # domain_seed = DomainSeed(value=value, label=label) # self.add_seed(domain_seed, elb_res=elb, aws_client=client) - # self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client) seed = self.process_seed( DomainSeed(value=value, label=label), elb_res=elb, @@ -627,8 +565,7 @@ def get_load_balancers_v1(self, **kwargs): def get_load_balancers_v2(self, **kwargs): """Retrieve Elastic Load Balancers (ELB) V2 data and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: ElasticLoadBalancingv2Client = self.get_aws_client( AwsServices.LOAD_BALANCER_V2, ctx @@ -645,7 +582,6 @@ def get_load_balancers_v2(self, **kwargs): with SuppressValidationError(): # domain_seed = DomainSeed(value=value, label=label) # self.add_seed(domain_seed, elb_res=elb, aws_client=client) - # self.emit_seed(ctx, domain_seed, elb_res=elb, aws_client=client) seed = self.process_seed( DomainSeed(value=value, label=label), elb_res=elb, @@ -662,14 +598,16 @@ def get_load_balancers(self, **kwargs): self.get_load_balancers_v2(**kwargs) # TODO: this won't work anymore since self.seeds is no longer used - label = self.format_label(SeedLabel.LOAD_BALANCER) + ctx: AwsScanContext = kwargs["scan_context"] + label = self.format_label( + SeedLabel.LOAD_BALANCER, ctx.account_number, ctx.region + ) if not self.seeds.get(label): self.delete_seeds_by_label(label) def get_network_interfaces(self, **kwargs): """Retrieve EC2 Elastic Network Interfaces (ENI) data and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] label = self.format_label( SeedLabel.NETWORK_INTERFACE, ctx.account_number, ctx.region @@ -690,7 +628,6 @@ def get_network_interfaces(self, **kwargs): with SuppressValidationError(): # ip_seed = IpSeed(value=ip_address, label=label) # self.add_seed(ip_seed, tags=instance_tag_sets.get(instance_id)) - # self.emit_seed(ctx, ip_seed, tags=instance_tag_sets.get(instance_id)) seed = self.process_seed( IpSeed(value=ip_address, label=label), tags=instance_tag_sets.get(instance_id), @@ -780,7 +717,7 @@ def get_resource_tags( resource_tags: dict = {} resource_tag_sets: dict = {} - for tag in self.get_resource_tags_paginated(resource_types, ctx): + for tag in self.get_resource_tags_paginated(ctx, resource_types): # Tags come in two formats: # 1. Tag = { Key = "Name", Value = "actual-tag-name" } # 2. 
Tag = { Key = "actual-key-name", Value = "tag-value-that-is-unused-here"} @@ -816,8 +753,7 @@ def network_interfaces_ignored_tags( def get_rds_instances(self, **kwargs): """Retrieve Relational Database Services (RDS) data and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: RDSClient = self.get_aws_client(AwsServices.RDS, ctx) label = self.format_label(SeedLabel.RDS, ctx.account_number, ctx.region) @@ -834,7 +770,6 @@ def get_rds_instances(self, **kwargs): with SuppressValidationError(): # domain_seed = DomainSeed(value=domain_name, label=label) # self.add_seed(domain_seed, rds_res=instance) - # self.emit_seed(ctx, domain_seed, rds_res=instance) has_added_seeds = True seed = self.process_seed( DomainSeed(value=domain_name, label=label), rds_res=instance @@ -881,8 +816,7 @@ def _get_route53_zone_resources( def get_route53_zones(self, **kwargs): """Retrieve Route 53 Zones and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: Route53Client = self.get_aws_client(AwsServices.ROUTE53_ZONES, ctx) label = self.format_label( @@ -890,8 +824,6 @@ def get_route53_zones(self, **kwargs): ) has_added_seeds = False - # TODO: potentially send seeds with empty values to remove "stale" seeds - # Notice add_seed has extra keyword arguments - these were piped into add_seed for dispatch_event # - add_seed cannot use self.seeds anymore because concurrency # - add_seed dispatched_event PER seed, but seeds were later submitted in submit_seeds @@ -915,14 +847,6 @@ def get_route53_zones(self, **kwargs): aws_client=client, ) seeds.append(seed) - # self.emit_seed(ctx, domain_seed, route53_zone_res=zone) - # self.dispatch_event( - # EventTypeEnum.SEED_FOUND, - # seed=domain_seed, - # route53_zone_res=zone, - # aws_client=client, - # ) - # seeds.append(domain_seed) id = zone.get("Id") resource_sets = self._get_route53_zone_resources(client, id) @@ -934,23 +858,7 @@ def get_route53_zones(self, **kwargs): domain_name = resource_set.get("Name").rstrip(".") with SuppressValidationError(): # domain_seed = DomainSeed(value=domain_name, label=label) - # - # TODO: label is for this entire loop, emitting per item will make more requests than necessary! 
- # seeds[seed.label].push(seed) # self.add_seed(domain_seed, route53_zone_res=zone, aws_client=client) - # TODO: add_seed quadratic time - loops here, then loops to submit seed - # self.emit_seed( - # ctx, domain_seed, route53_zone_res=zone, aws_client=client - # ) - # - # self.dispatch_event( - # EventTypeEnum.SEED_FOUND, - # seed=domain_seed, - # route53_zone_res=zone, - # aws_client=client, - # ) - # seeds.append(domain_seed) - seed = self.process_seed( DomainSeed(value=domain_name, label=label), route53_zone_res=zone, @@ -967,8 +875,7 @@ def get_route53_zones(self, **kwargs): def get_ecs_instances(self, **kwargs): """Retrieve Elastic Container Service data and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] ecs: ECSClient = self.get_aws_client(AwsServices.ECS, ctx) ec2: EC2Client = self.get_aws_client(AwsServices.EC2, ctx) @@ -1004,14 +911,7 @@ def get_ecs_instances(self, **kwargs): with SuppressValidationError(): # ip_seed = IpSeed(value=ip_address, label=label) - # TODO: don't use add_seed - # instead, emit Payload - # modifying add seed would require managing account+region or use AwsScanContext which requires more time than available # self.add_seed(ip_seed, ecs_res=instance) - # self.emit_seed(ctx, ip_seed, ecs_res=instance) - # or maybe self.enqueue(seed) - # would be best to async queue these - # but we are in a pool already... seed = self.process_seed( IpSeed(value=ip_address, label=label), ecs_res=instance ) @@ -1039,8 +939,7 @@ def get_s3_region(self, client: S3Client, bucket: str) -> str: def get_s3_instances(self, **kwargs): """Retrieve Simple Storage Service data and emit seeds.""" - key = kwargs["scan_context_key"] - ctx: AwsScanContext = self.scan_contexts[key] + ctx: AwsScanContext = kwargs["scan_context"] client: S3Client = self.get_aws_client(AwsServices.STORAGE_BUCKET, ctx) @@ -1063,15 +962,10 @@ def get_s3_instances(self, **kwargs): label = self.format_label( SeedLabel.STORAGE_BUCKET, ctx.account_number, - # oh this is interesting.... lookup_region OR ctx.region.. which one? - # pretty sure it's lookup_region, otherwise whats the point of looking up the bucket's region? lookup_region, # ctx.region, ) - # TODO: this isnt right - # assets = [] - with SuppressValidationError(): asset = AwsStorageBucketAsset( value=AwsStorageBucketAsset.url(bucket_name, lookup_region), @@ -1087,15 +981,10 @@ def get_s3_instances(self, **kwargs): asset = self.process_cloud_asset( asset, bucket_name=bucket_name, aws_client=client ) - # assets.append(asset) if label not in findings: findings[label] = [] findings[label].append(asset) - # TODO convert this to findings below - # self.submit_cloud_asset_payload(label, assets) - - # TODO: submit findings map here for label, assets in findings.items(): self.submit_cloud_asset_payload(label, assets) except ClientError as e: From 4e626effe61b47afe5d9c6a1578811d1e8cce49c Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Fri, 20 Oct 2023 12:33:58 -0400 Subject: [PATCH 11/13] build: poetry lock --- poetry.lock | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index 371b2da..40e594c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. 
[[package]] name = "adal" @@ -3063,12 +3063,12 @@ files = [ google-auth = ">=2.14.1,<3.0dev" googleapis-common-protos = ">=1.56.2,<2.0dev" grpcio = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" requests = ">=2.18.0,<3.0.0dev" @@ -3135,8 +3135,8 @@ google-cloud-org-policy = ">=0.1.2,<2.0.0" google-cloud-os-config = ">=1.0.0,<2.0.0dev" grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" proto-plus = [ - {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""}, + {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" @@ -3154,8 +3154,8 @@ files = [ [package.dependencies] google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} proto-plus = [ - {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""}, + {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" @@ -3173,8 +3173,8 @@ files = [ [package.dependencies] google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} proto-plus = [ - {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""}, + {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" @@ -3193,8 +3193,8 @@ files = [ google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" proto-plus = [ - {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""}, + {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" @@ -5618,4 +5618,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = 
"fa3a761a2df073f297e1044a6f54c62bb14ba4055b1a438b043707c32b696e79" +content-hash = "e879fb7d689b0a6c227a150d860ff13466c00fe8646d01aa051a91b14ee8c6fb" From 417f8561e2bb404ad20b6c3e42df6963345646e1 Mon Sep 17 00:00:00 2001 From: Eric Butera Date: Wed, 25 Oct 2023 09:02:43 -0400 Subject: [PATCH 12/13] feat: azure multiprocessing --- .../azure_connector/connector.py | 497 +++++++++++++----- src/censys/cloud_connectors/common/aurora.py | 4 +- .../cloud_connectors/common/connector.py | 49 +- 3 files changed, 400 insertions(+), 150 deletions(-) diff --git a/src/censys/cloud_connectors/azure_connector/connector.py b/src/censys/cloud_connectors/azure_connector/connector.py index 320b7a7..3d55927 100644 --- a/src/censys/cloud_connectors/azure_connector/connector.py +++ b/src/censys/cloud_connectors/azure_connector/connector.py @@ -1,5 +1,7 @@ """Azure Cloud Connector.""" from collections.abc import Generator +from dataclasses import dataclass +from multiprocessing import Pool from typing import Optional from azure.core.exceptions import ( @@ -23,6 +25,7 @@ from censys.cloud_connectors.common.context import SuppressValidationError from censys.cloud_connectors.common.enums import EventTypeEnum, ProviderEnum from censys.cloud_connectors.common.healthcheck import Healthcheck +from censys.cloud_connectors.common.logger import get_logger from censys.cloud_connectors.common.seed import DomainSeed, IpSeed from censys.cloud_connectors.common.settings import Settings @@ -30,6 +33,18 @@ from .settings import AzureSpecificSettings +@dataclass +class AzureScanContext: + """Required configuration context for Azure scans.""" + + label_prefix: str + provider_settings: AzureSpecificSettings + subscription_id: str + credentials: ClientSecretCredential + possible_labels: set[str] # = set() # TODO: verify this works + scan_all_regions: bool + + class AzureCloudConnector(CloudConnector): """Azure Cloud Connector.""" @@ -56,72 +71,148 @@ def __init__(self, settings: Settings): self.cloud_asset_scanners = { AzureResourceTypes.STORAGE_ACCOUNTS: self.get_storage_containers, } - self.possible_labels = set() + # self.possible_labels = set() self.scan_all_regions = settings.azure_refresh_all_regions - def get_all_labels(self): + def get_all_labels(self, subscription_id: str, label_prefix: str) -> set[str]: """Get Azure labels.""" subscription_client = SubscriptionClient(self.credentials) - locations = subscription_client.subscriptions.list_locations( - self.subscription_id - ) + locations = subscription_client.subscriptions.list_locations(subscription_id) - self.possible_labels.clear() + # self.possible_labels.clear() + possible_labels = set() for location in locations: - self.possible_labels.add( - f"{self.label_prefix}{self.subscription_id}/{location.name}" + # self.possible_labels.add( + possible_labels.add( + # f"{self.label_prefix}{self.subscription_id}/{location.name}" + f"{label_prefix}{subscription_id}/{location.name}" ) - def scan(self): - """Scan Azure Subscription.""" + return possible_labels + + def scan(self, **kwargs): + """Scan Azure Subscription. + + Args: + **kwargs: Keyword arguments. + scan_context (AzureScanContext): Azure scan context. 
+ """ + ctx: AzureScanContext = kwargs["scan_context"] + + ctx.credentials = ClientSecretCredential( + tenant_id=ctx.provider_settings.tenant_id, + client_id=ctx.provider_settings.client_id, + client_secret=ctx.provider_settings.client_secret, + ) + + logger = get_logger( + log_name=f"{self.provider.lower()}_cloud_connector", + level=self.settings.logging_level, + provider=f"{self.provider}_{ctx.subscription_id}", + ) + ctx.logger = logger + ctx.logger.info(f"Scanning {self.provider} - sub:{ctx.subscription_id}") + with Healthcheck( self.settings, - self.provider_settings, - provider={"subscription_id": self.subscription_id}, + ctx.provider_settings, # self.provider_settings, + # provider={"subscription_id": self.subscription_id}, + provider={"subscription_id": ctx.subscription_id}, exception_map={ ClientAuthenticationError: "PERMISSIONS", }, ): - super().scan() + super().scan(**kwargs) def scan_all(self): """Scan all Azure Subscriptions.""" provider_settings: dict[ tuple, AzureSpecificSettings ] = self.settings.providers.get(self.provider, {}) + + self.logger.debug( + f"scanning {self.provider} using {self.settings.scan_concurrency} processes" + ) + + label_prefix = self.get_provider_label_prefix() + pool = Pool(processes=self.settings.scan_concurrency) + for provider_setting in provider_settings.values(): + # this is so confusing - plural settings to setting? self.provider_settings = provider_setting - self.credentials = ClientSecretCredential( - tenant_id=provider_setting.tenant_id, - client_id=provider_setting.client_id, - client_secret=provider_setting.client_secret, - ) + + # self.credentials = ClientSecretCredential( + # credentials = ClientSecretCredential( + # tenant_id=provider_setting.tenant_id, + # client_id=provider_setting.client_id, + # client_secret=provider_setting.client_secret, + # ) + for subscription_id in self.provider_settings.subscription_id: self.logger.info(f"Scanning Azure Subscription {subscription_id}") - self.subscription_id = subscription_id + # self.subscription_id = subscription_id + try: + possible_labels: set[str] if self.scan_all_regions: - self.get_all_labels() + possible_labels = self.get_all_labels( + subscription_id, label_prefix + ) + else: + possible_labels = set() + + scan_context = AzureScanContext( + provider_settings=provider_setting, + subscription_id=subscription_id, + credentials=None, + # credentials=credentials, + possible_labels=possible_labels, # self.possible_labels, + scan_all_regions=self.scan_all_regions, + label_prefix=label_prefix, + ) - self.scan() + # self.scan(**{"scan_context": scan_context}) + # self.scan_seeds(**{"scan_context": scan_context}) + # pool.apply_async( + pool.apply_async( + # self.scan_seeds, + self.scan, + kwds={ + "scan_context": scan_context, + }, + error_callback=lambda e: self.logger.error(f"Async Error: {e}"), + ) + # TODO: figure out how to make this wait until scan is finished: if self.scan_all_regions: - for label_not_found in self.possible_labels: + # for label_not_found in self.possible_labels: + for label_not_found in scan_context.possible_labels: self.delete_seeds_by_label(label_not_found) except Exception as e: self.logger.error( f"Unable to scan Azure Subscription {subscription_id}. 
Error: {e}" ) self.dispatch_event(EventTypeEnum.SCAN_FAILED, exception=e) - self.subscription_id = None - - def format_label(self, asset: AzureModel) -> str: + # self.subscription_id = None + + pool.close() + pool.join() + + # TODO: reorder params (sub_id, asset, label_prefix) + def format_label( + self, + asset: AzureModel, + subscription_id: str, + label_prefix: str, + ) -> str: """Format Azure asset label. Args: asset (AzureModel): Azure asset. + subscription_id (str): Azure subscription ID. + label_prefix (str): Label prefix. Returns: str: Formatted label. @@ -132,116 +223,219 @@ def format_label(self, asset: AzureModel) -> str: asset_location: Optional[str] = getattr(asset, "location", None) if not asset_location: raise ValueError("Asset has no location.") - return f"{self.label_prefix}{self.subscription_id}/{asset_location}" + # return f"{self.label_prefix}{self.subscription_id}/{asset_location}" + return f"{label_prefix}{subscription_id}/{asset_location}" + + def get_ip_addresses(self, **kwargs): + """Get Azure IP addresses. + + Args: + **kwargs: Keyword arguments. + scan_context (AzureScanContext): Azure scan context. + """ + ctx: AzureScanContext = kwargs["scan_context"] + # network_client = NetworkManagementClient(self.credentials, self.subscription_id) + network_client = NetworkManagementClient(ctx.credentials, ctx.subscription_id) - def get_ip_addresses(self): - """Get Azure IP addresses.""" - network_client = NetworkManagementClient(self.credentials, self.subscription_id) try: assets = network_client.public_ip_addresses.list_all() except HttpResponseError as error: - self.logger.error(f"Failed to get Azure IP addresses: {error.message}") + ctx.logger.error(f"Failed to get Azure IP addresses: {error.message}") return for asset in assets: - asset_dict = asset.as_dict() - if ip_address := asset_dict.get("ip_address"): - with SuppressValidationError(): - label = self.format_label(asset) - ip_seed = IpSeed(value=ip_address, label=label) - self.add_seed(ip_seed) - self.possible_labels.discard(label) - - def get_clusters(self): - """Get Azure clusters.""" + try: + seeds = [] + label = self.format_label(asset, ctx.subscription_id, ctx.label_prefix) + asset_dict = asset.as_dict() + + if ip_address := asset_dict.get("ip_address"): + with SuppressValidationError(): + # ip_seed = IpSeed(value=ip_address, label=label) + # self.add_seed(ip_seed) + # self.possible_labels.discard(label) + seed = self.process_seed(IpSeed(value=ip_address, label=label)) + seeds.append(seed) + ctx.possible_labels.discard(label) + + self.submit_seed_payload(label, seeds) + except Exception as e: + ctx.logger.error(f"Failed to process Azure IP addresses: {e}") + + def get_clusters(self, **kwargs): + """Get Azure clusters. + + Args: + **kwargs: Keyword arguments. + scan_context (AzureScanContext): Azure scan context. 
+ """ + ctx: AzureScanContext = kwargs["scan_context"] + # container_client = ContainerInstanceManagementClient(self.credentials, self.subscription_id) container_client = ContainerInstanceManagementClient( - self.credentials, self.subscription_id + ctx.credentials, ctx.subscription_id ) try: assets = container_client.container_groups.list() except HttpResponseError as error: - self.logger.error( + ctx.logger.error( f"Failed to get Azure Container Instances: {error.message}" ) return for asset in assets: - asset_dict = asset.as_dict() - if ( - (ip_address_dict := asset_dict.get("ip_address")) - and (ip_address_dict.get("type") == "Public") - and (ip_address := ip_address_dict.get("ip")) - ): - label = self.format_label(asset) - with SuppressValidationError(): - ip_seed = IpSeed(value=ip_address, label=label) - self.add_seed(ip_seed) - self.possible_labels.discard(label) - if domain := ip_address_dict.get("fqdn"): + try: + asset_dict = asset.as_dict() + if ( + (ip_address_dict := asset_dict.get("ip_address")) + and (ip_address_dict.get("type") == "Public") + and (ip_address := ip_address_dict.get("ip")) + ): + seeds = [] + label = self.format_label( + asset, ctx.subscription_id, ctx.label_prefix + ) + with SuppressValidationError(): - domain_seed = DomainSeed(value=domain, label=label) - self.add_seed(domain_seed) - self.possible_labels.discard(label) + # ip_seed = IpSeed(value=ip_address, label=label) + # self.add_seed(ip_seed) + # self.possible_labels.discard(label) + seed = self.process_seed(IpSeed(value=ip_address, label=label)) + seeds.append(seed) + ctx.possible_labels.discard(label) + + if domain := ip_address_dict.get("fqdn"): + with SuppressValidationError(): + # domain_seed = DomainSeed(value=domain, label=label) + # self.add_seed(domain_seed) + # self.possible_labels.discard(label) + seed = self.process_seed( + DomainSeed(value=domain, label=label) + ) + seeds.append(seed) + ctx.possible_labels.discard(label) + + self.submit_seed_payload(label, seeds) + except Exception as e: + ctx.logger.error(f"Failed to process Azure clusters: {e}") + + def get_sql_servers(self, **kwargs): + """Get Azure SQL servers. + + Args: + **kwargs: Keyword arguments. + scan_context (AzureScanContext): Azure scan context. 
+ """ + ctx: AzureScanContext = kwargs["scan_context"] + # sql_client = SqlManagementClient(self.credentials, self.subscription_id) + sql_client = SqlManagementClient(ctx.credentials, ctx.subscription_id) - def get_sql_servers(self): - """Get Azure SQL servers.""" - sql_client = SqlManagementClient(self.credentials, self.subscription_id) try: assets = sql_client.servers.list() except HttpResponseError as error: - self.logger.error(f"Failed to get Azure SQL servers: {error.message}") + ctx.logger.error(f"Failed to get Azure SQL servers: {error.message}") return for asset in assets: asset_dict = asset.as_dict() - if ( - domain := asset_dict.get("fully_qualified_domain_name") - ) and asset_dict.get("public_network_access") == "Enabled": - with SuppressValidationError(): - label = self.format_label(asset) - domain_seed = DomainSeed(value=domain, label=label) - self.add_seed(domain_seed) - self.possible_labels.discard(label) - - def get_dns_records(self): - """Get Azure DNS records.""" - dns_client = DnsManagementClient(self.credentials, self.subscription_id) + try: + if ( + domain := asset_dict.get("fully_qualified_domain_name") + ) and asset_dict.get("public_network_access") == "Enabled": + with SuppressValidationError(): + label = self.format_label( + asset, ctx.subscription_id, ctx.label_prefix + ) + # domain_seed = DomainSeed(value=domain, label=label) + # self.add_seed(domain_seed) + # self.possible_labels.discard(label) + + # TODO: verify that asset+label->payload is correct; other methods have a seeds array that is appended to and 1 payload is sent + seed = self.process_seed(DomainSeed(value=domain, label=label)) + ctx.possible_labels.discard(label) + self.submit_seed_payload(label, [seed]) + except Exception as e: + ctx.logger.error(f"Failed to process Azure SQL servers: {e}") + + def get_dns_records(self, **kwargs): + """Get Azure DNS records. + + Args: + **kwargs: Keyword arguments. + scan_context (AzureScanContext): Azure scan context. + """ + ctx: AzureScanContext = kwargs["scan_context"] + # dns_client = DnsManagementClient(self.credentials, self.subscription_id) + dns_client = DnsManagementClient(ctx.credentials, ctx.subscription_id) + try: zones = dns_client.zones.list() except HttpResponseError as error: - self.logger.error(f"Failed to get Azure DNS records: {error.message}") + ctx.logger.error(f"Failed to get Azure DNS records: {error.message}") return - for zone in zones: - zone_dict = zone.as_dict() - label = self.format_label(zone) - # TODO: Do we need to check if zone is public? (ie. do we care?) - if zone_dict.get("zone_type") != "Public": # pragma: no cover - continue - zone_resource_group = zone_dict.get("id").split("/")[4] - for asset in dns_client.record_sets.list_all_by_dns_zone( - zone_resource_group, zone_dict.get("name") - ): - asset_dict = asset.as_dict() - if domain_name := asset_dict.get("fqdn"): - with SuppressValidationError(): - domain_seed = DomainSeed(value=domain_name, label=label) - self.add_seed(domain_seed) - self.possible_labels.discard(label) - if cname := asset_dict.get("cname_record", {}).get("cname"): - with SuppressValidationError(): - domain_seed = DomainSeed(value=cname, label=label) - self.add_seed(domain_seed) - self.possible_labels.discard(label) - for a_record in asset_dict.get("a_records", []): - ip_address = a_record.get("ipv4_address") - if not ip_address: - continue + try: + for zone in zones: + zone_dict = zone.as_dict() + # TODO: Do we need to check if zone is public? (ie. do we care?) 
+ if zone_dict.get("zone_type") != "Public": # pragma: no cover + continue - with SuppressValidationError(): - ip_seed = IpSeed(value=ip_address, label=label) - self.add_seed(ip_seed) - self.possible_labels.discard(label) + try: + label = self.format_label( + zone, ctx.subscription_id, ctx.label_prefix + ) + seeds = [] + + zone_resource_group = zone_dict.get("id").split("/")[4] + for asset in dns_client.record_sets.list_all_by_dns_zone( + zone_resource_group, zone_dict.get("name") + ): + asset_dict = asset.as_dict() + + if domain_name := asset_dict.get("fqdn"): + with SuppressValidationError(): + # domain_seed = DomainSeed(value=domain_name, label=label) + # self.add_seed(domain_seed) + # self.possible_labels.discard(label) + seed = self.process_seed( + DomainSeed(value=domain_name, label=label) + ) + seeds.append(seed) + ctx.possible_labels.discard(label) + + if cname := asset_dict.get("cname_record", {}).get("cname"): + with SuppressValidationError(): + # domain_seed = DomainSeed(value=cname, label=label) + # self.add_seed(domain_seed) + # self.possible_labels.discard(label) + seed = self.process_seed( + DomainSeed(value=cname, label=label) + ) + seeds.append(seed) + ctx.possible_labels.discard(label) + + for a_record in asset_dict.get("a_records", []): + ip_address = a_record.get("ipv4_address") + if not ip_address: + continue + + with SuppressValidationError(): + # ip_seed = IpSeed(value=ip_address, label=label) + # self.add_seed(ip_seed) + # self.possible_labels.discard(label) + seed = self.process_seed( + IpSeed(value=ip_address, label=label) + ) + seeds.append(seed) + ctx.possible_labels.discard(label) + + self.submit_seed_payload(label, seeds) + except Exception as e: + ctx.logger.error(f"Failed to process Azure DNS records: {e}") + + except Exception as e: + # TODO: health check should have a way to emit errors yet still proceed to next resource type + ctx.logger.error(f"Failed to list Azure DNS records: {e}") def _list_containers( self, bucket_client: BlobServiceClient, account: StorageAccount @@ -263,46 +457,73 @@ def _list_containers( ) return - def get_storage_containers(self): - """Get Azure containers.""" - storage_client = StorageManagementClient(self.credentials, self.subscription_id) + def get_storage_containers(self, **kwargs): + """Get Azure containers. + + Args: + **kwargs: Keyword arguments. + scan_context (AzureScanContext): Azure scan context. 
+ """ + ctx: AzureScanContext = kwargs["scan_context"] + # storage_client = StorageManagementClient(self.credentials, self.subscription_id) + storage_client = StorageManagementClient(ctx.credentials, ctx.subscription_id) + try: accounts = storage_client.storage_accounts.list() except HttpResponseError as error: - self.logger.error(f"Failed to get Azure storage accounts: {error.message}") + ctx.logger.error(f"Failed to get Azure storage accounts: {error.message}") return for account in accounts: - bucket_client = BlobServiceClient( - f"https://{account.name}.blob.core.windows.net/", self.credentials - ) - label = self.format_label(account) - account_dict = account.as_dict() - if (custom_domain := account_dict.get("custom_domain")) and ( - domain := custom_domain.get("name") - ): - with SuppressValidationError(): - domain_seed = DomainSeed(value=domain, label=label) - self.add_seed(domain_seed) - self.possible_labels.discard(label) - uid = f"{self.subscription_id}/{self.credentials._tenant_id}/{account.name}" - - for container in self._list_containers(bucket_client, account): - try: - container_client = bucket_client.get_container_client(container) - container_url = container_client.url + try: + # bucket_client = BlobServiceClient(f"https://{account.name}.blob.core.windows.net/", self.credentials) + bucket_client = BlobServiceClient( + f"https://{account.name}.blob.core.windows.net/", ctx.credentials + ) + + label = self.format_label( + account, ctx.subscription_id, ctx.label_prefix + ) + account_dict = account.as_dict() + + # create seed from storage container + if (custom_domain := account_dict.get("custom_domain")) and ( + domain := custom_domain.get("name") + ): with SuppressValidationError(): - container_asset = AzureContainerAsset( - value=container_url, - uid=uid, - scan_data={ - "accountNumber": self.subscription_id, - "publicAccess": container.public_access, - "location": account.location, - }, + # domain_seed = DomainSeed(value=domain, label=label) + # self.add_seed(domain_seed) + # self.possible_labels.discard(label) + seed = self.process_seed(DomainSeed(value=domain, label=label)) + ctx.possible_labels.discard(label) + self.submit_seed_payload(label, seed) + + # create cloud asset from storage container + # uid = f"{self.subscription_id}/{self.credentials._tenant_id}/{account.name}" + uid = ( + f"{ctx.subscription_id}/{ctx.credentials._tenant_id}/{account.name}" + ) + + for container in self._list_containers(bucket_client, account): + try: + container_client = bucket_client.get_container_client(container) + container_url = container_client.url + with SuppressValidationError(): + container_asset = AzureContainerAsset( + value=container_url, + uid=uid, + scan_data={ + "accountNumber": ctx.subscription_id, # "accountNumber": self.subscription_id, + "publicAccess": container.public_access, + "location": account.location, + }, + ) + # self.add_cloud_asset(container_asset) + self.submit_cloud_asset_payload(label, [container_asset]) + except ServiceRequestError as error: # pragma: no cover + ctx.logger.error( + f"Failed to get Azure container {container} for {account.name}: {error.message}" ) - self.add_cloud_asset(container_asset) - except ServiceRequestError as error: # pragma: no cover - self.logger.error( - f"Failed to get Azure container {container} for {account.name}: {error.message}" - ) + + except Exception as e: + ctx.logger.error(f"Failed to process Azure storage accounts: {e}") diff --git a/src/censys/cloud_connectors/common/aurora.py 
b/src/censys/cloud_connectors/common/aurora.py
index 3cf8d33..e76e26f 100644
--- a/src/censys/cloud_connectors/common/aurora.py
+++ b/src/censys/cloud_connectors/common/aurora.py
@@ -8,7 +8,7 @@
 class Aurora(Seeds):
     """Aurora API client."""
 
-    base_path = "/api"
+    base_path = "/api/integrations"  # TODO
 
     @_backoff_wrapper
     def enqueue_payload(self, payload: CloudEvent) -> None:
@@ -20,7 +20,7 @@ def enqueue_payload(self, payload: CloudEvent) -> None:
 
         headers, body = to_structured(payload)
         request_kwargs = {"timeout": self.timeout, "data": body, "headers": headers}
-        url = f"{self._api_url}{self.base_path}/payload/enqueue"
+        url = f"{self._api_url}{self.base_path}/v1/payloads/enqueue"
 
         resp = self._call_method(self._session.post, url, request_kwargs)
         if resp.ok:
diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py
index 27b5803..b67f934 100644
--- a/src/censys/cloud_connectors/common/connector.py
+++ b/src/censys/cloud_connectors/common/connector.py
@@ -49,7 +49,7 @@ def __init__(self, settings: Settings):
         """
         if not self.provider:
             raise ValueError("The provider must be set.")
-        self.label_prefix = self.provider.label() + ": "
+        self.label_prefix = self.get_provider_label_prefix()
         self.settings = settings
         self.logger = get_logger(
             log_name=f"{self.provider.lower()}_cloud_connector",
@@ -79,6 +79,14 @@ def __init__(self, settings: Settings):
         self.cloud_assets = defaultdict(set)
         self.current_service = None
 
+    def get_provider_label_prefix(self) -> str:
+        """Get the provider label prefix.
+
+        Returns:
+            str: Provider label prefix.
+        """
+        return self.provider.label() + ": "
+
     def delete_seeds_by_label(self, label: str):
         """Replace seeds for [label] with an empty list.
 
@@ -86,8 +94,14 @@ def delete_seeds_by_label(self, label: str):
             label: Label for seeds to be deleted.
         """
         try:
-            self.logger.debug(f"Deleting any seeds matching label {label}.")
-            self.seeds_api.replace_seeds_by_label(label, [], True)
+            if self.settings.dry_run:
+                self.logger.debug(
+                    f"Dry run: skipping deletion of seeds matching label {label}."
+                )
+            else:
+                self.logger.debug(f"Deleting any seeds matching label {label}.")
+                self.seeds_api.replace_seeds_by_label(label, [], True)
+
         except CensysAsmException as e:
             self.logger.error(f"Error deleting seeds for label {label}: {e}")
         self.logger.info(f"Deleted any seeds for label {label}.")
@@ -171,6 +185,7 @@ def dispatch_event(
 
         context = self.get_event_context(event_type, service)
         CloudConnectorPluginRegistry.dispatch_event(context=context, **kwargs)
 
+    # TODO: remove
     def add_seed(self, seed: Seed, **kwargs):
         """Add a seed.
 
@@ -185,6 +200,7 @@ def add_seed(self, seed: Seed, **kwargs):
         self.logger.debug(f"Found Seed: {seed.to_dict()}")
         self.dispatch_event(EventTypeEnum.SEED_FOUND, seed=seed, **kwargs)
 
+    # TODO: remove
     def add_cloud_asset(self, cloud_asset: CloudAsset, **kwargs):
         """Add a cloud asset.
@@ -201,6 +217,7 @@ def add_cloud_asset(self, cloud_asset: CloudAsset, **kwargs):
             EventTypeEnum.CLOUD_ASSET_FOUND, cloud_asset=cloud_asset, **kwargs
         )
 
+    # TODO: remove
     def submit_seeds(self):
         """Submit the seeds to Censys ASM."""
         # TODO: not compatible with multiprocessing
@@ -216,6 +233,7 @@ def submit_seeds(self):
         self.logger.info(f"Submitted {submitted_seeds} seeds.")
         self.dispatch_event(EventTypeEnum.SEEDS_SUBMITTED, count=submitted_seeds)
 
+    # TODO: remove
     def submit_cloud_assets(self):
         """Submit the cloud assets to Censys ASM."""
         # TODO: not compatible with multiprocessing
@@ -233,6 +251,7 @@ def submit_cloud_assets(self):
             EventTypeEnum.CLOUD_ASSETS_SUBMITTED, count=submitted_assets
         )
 
+    # TODO: rename this method to something like prepare_seed
     def process_seed(self, seed: Seed, **kwargs) -> Seed:
         """Prepare a seed for submission. Also dispatch events.
 
@@ -299,9 +318,17 @@ def enqueue_payload(self, payload: CloudEvent) -> str:
         Returns:
             str: Event ID.
         """
-        result = self.aurora_api.enqueue_payload(payload)
-        event_id = result.get("eventId", "ERROR")
-        return event_id
+        if self.settings.dry_run:
+            self.logger.debug("Dry run: skipping enqueuing payload.")
+            return "dry-run-event-id"
+        else:
+            result = self.aurora_api.enqueue_payload(payload)
+            event_id = result.get("eventId")
+            if not event_id:
+                self.logger.error(
+                    f"Error enqueuing payload {payload}: no event ID returned"
+                )
+            return event_id
 
     def submit_seed_payload(self, label: str, seeds: list[Seeds]) -> str:
         """Submit a seed payload.
@@ -319,7 +346,7 @@ def submit_seed_payload(self, label: str, seeds: list[Seeds]) -> str:
         }
         payload = self.payload(PayloadTypes.PAYLOAD_SEED, data)
         event_id = self.enqueue_payload(payload)
-        self.logger.debug(f"seed payload {payload} event_id:{event_id}")
+        self.logger.debug(f"Payload label:{label} event_id:{event_id}")
         return event_id
 
     def submit_cloud_asset_payload(self, uid: str, cloud_assets: list[CloudAsset]):
@@ -335,7 +362,7 @@ def submit_cloud_asset_payload(self, uid: str, cloud_assets: list[CloudAsset]):
         }
         payload = self.payload(PayloadTypes.PAYLOAD_CLOUD_ASSET, data)
         event_id = self.enqueue_payload(payload)
-        self.logger.debug(f"cloud asset payload {payload} event_id:{event_id}")
+        self.logger.debug(f"Payload uid:{uid} event_id:{event_id}")
         return event_id
 
     def clear(self):
@@ -347,6 +374,7 @@ def clear(self):
         self.logger.debug(f"Clearing {len(self.cloud_assets)} cloud assets")
         self.cloud_assets.clear()
 
+    # TODO: remove
     def submit(self, **kwargs):  # pragma: no cover
         """Submit the seeds and cloud assets to Censys ASM."""
         if self.settings.dry_run:
@@ -358,6 +386,7 @@ def submit(self, **kwargs):  # pragma: no cover
 
         self.clear()
 
+    # TODO: remove
     def submit_seeds_wrapper(self):  # pragma: no cover
         """Submit the seeds to Censys ASM."""
         if self.settings.dry_run:
@@ -381,7 +410,7 @@ def scan_seeds(self, **kwargs):
         self.logger.info("Gathering seeds...")
         self.dispatch_event(EventTypeEnum.SCAN_STARTED)
         self.get_seeds(**kwargs)
-        self.submit_seeds_wrapper()
+        # self.submit_seeds_wrapper()
        self.dispatch_event(EventTypeEnum.SCAN_FINISHED)
 
     def scan_cloud_assets(self, **kwargs):
@@ -399,7 +428,7 @@ def scan(self, **kwargs):
         """Scan the seeds and cloud assets."""
         self.dispatch_event(EventTypeEnum.SCAN_STARTED)
         self.get_seeds(**kwargs)
         self.get_cloud_assets(**kwargs)
-        self.submit()
+        # self.submit()
         self.dispatch_event(EventTypeEnum.SCAN_FINISHED)
 
     @abstractmethod

From e9848fcdfa880953b1e61d1b931661ca0b574c95 Mon Sep 17 00:00:00 2001
From: Eric Butera
Date: Thu, 26 Oct 2023 16:08:16 -0400
Subject: [PATCH 13/13] refactor: azure - credential passing -
 possible labels + delete seeds by label - cloud asset use label prefix -
 healthcheck log errors if dry run enabled
---
 .../azure_connector/connector.py              | 85 ++++++++++++-------
 .../cloud_connectors/common/connector.py      | 10 ++-
 .../cloud_connectors/common/healthcheck.py    |  8 ++
 3 files changed, 68 insertions(+), 35 deletions(-)

diff --git a/src/censys/cloud_connectors/azure_connector/connector.py b/src/censys/cloud_connectors/azure_connector/connector.py
index 3d55927..8ab48fc 100644
--- a/src/censys/cloud_connectors/azure_connector/connector.py
+++ b/src/censys/cloud_connectors/azure_connector/connector.py
@@ -33,6 +33,8 @@
 from .settings import AzureSpecificSettings
 
 
+# TODO: make a ctor that has required params but also optional params
+# the scan() method is currently building all of this out which could lead to misconfiguration and runtime errors
 @dataclass
 class AzureScanContext:
     """Required configuration context for Azure scans."""
@@ -49,10 +51,13 @@ class AzureCloudConnector(CloudConnector):
     """Azure Cloud Connector."""
 
     provider = ProviderEnum.AZURE
-    subscription_id: str
-    credentials: ClientSecretCredential
+    # subscription_id: str
+    # credentials: ClientSecretCredential
+    #
+    # TODO: provider_settings is used in the parent class... figure out how to break out those methods (would fix the parent self.logger calls using root logger instead of ctx.logger)
     provider_settings: AzureSpecificSettings
-    possible_labels: set[str]
+    #
+    # possible_labels: set[str]
     scan_all_regions: bool
 
     def __init__(self, settings: Settings):
@@ -74,9 +79,23 @@ def __init__(self, settings: Settings):
         # self.possible_labels = set()
         self.scan_all_regions = settings.azure_refresh_all_regions
 
-    def get_all_labels(self, subscription_id: str, label_prefix: str) -> set[str]:
-        """Get Azure labels."""
-        subscription_client = SubscriptionClient(self.credentials)
+    def get_all_labels(
+        self,
+        credentials: ClientSecretCredential,
+        subscription_id: str,
+        label_prefix: str,
+    ) -> set[str]:
+        """Get Azure labels.
+
+        Args:
+            credentials (ClientSecretCredential): Azure credentials.
+            subscription_id (str): Azure subscription ID.
+            label_prefix (str): Label prefix.
+
+        Returns:
+            set[str]: Azure labels.
+        """
+        subscription_client = SubscriptionClient(credentials)
 
         locations = subscription_client.subscriptions.list_locations(subscription_id)
 
@@ -107,6 +125,15 @@ def scan(self, **kwargs):
             client_secret=ctx.provider_settings.client_secret,
         )
 
+        # TODO: remove possible_labels
+        ctx.label_prefix = self.get_provider_label_prefix()  # TODO: move to ctor
+        if self.scan_all_regions:
+            ctx.possible_labels = self.get_all_labels(
+                ctx.credentials, ctx.subscription_id, ctx.label_prefix
+            )
+        else:
+            ctx.possible_labels = set()
+
         logger = get_logger(
             log_name=f"{self.provider.lower()}_cloud_connector",
             level=self.settings.logging_level,
@@ -126,6 +153,12 @@ def scan(self, **kwargs):
         ):
             super().scan(**kwargs)
 
+        # TODO: remove possible_labels
+        if self.scan_all_regions:
+            # for label_not_found in self.possible_labels:
+            for label_not_found in ctx.possible_labels:
+                self.delete_seeds_by_label(label_not_found)
+
     def scan_all(self):
         """Scan all Azure Subscriptions."""
         provider_settings: dict[
@@ -136,7 +169,6 @@ def scan_all(self):
             f"scanning {self.provider} using {self.settings.scan_concurrency} processes"
         )
 
-        label_prefix = self.get_provider_label_prefix()
         pool = Pool(processes=self.settings.scan_concurrency)
 
         for provider_setting in provider_settings.values():
@@ -155,26 +187,18 @@
                 # self.subscription_id = subscription_id
 
                 try:
-                    possible_labels: set[str]
-                    if self.scan_all_regions:
-                        possible_labels = self.get_all_labels(
-                            subscription_id, label_prefix
-                        )
-                    else:
-                        possible_labels = set()
-
                     scan_context = AzureScanContext(
                         provider_settings=provider_setting,
                         subscription_id=subscription_id,
                         credentials=None,
                         # credentials=credentials,
-                        possible_labels=possible_labels,  # self.possible_labels,
+                        # possible_labels=possible_labels,  # self.possible_labels,
+                        possible_labels=None,
                         scan_all_regions=self.scan_all_regions,
-                        label_prefix=label_prefix,
+                        label_prefix=None,
+                        # label_prefix=label_prefix,
                     )
 
                     # self.scan(**{"scan_context": scan_context})
-                    # self.scan_seeds(**{"scan_context": scan_context})
                     # pool.apply_async(
                     pool.apply_async(
                         # self.scan_seeds,
                         self.scan,
                         kwds={
                             "scan_context": scan_context,
                         },
                         error_callback=lambda e: self.logger.error(f"Async Error: {e}"),
                     )
-
-                    # TODO: figure out how to make this wait until scan is finished:
-                    if self.scan_all_regions:
-                        # for label_not_found in self.possible_labels:
-                        for label_not_found in scan_context.possible_labels:
-                            self.delete_seeds_by_label(label_not_found)
                 except Exception as e:
                     self.logger.error(
                         f"Unable to scan Azure Subscription {subscription_id}. 
Error: {e}" @@ -509,14 +527,17 @@ def get_storage_containers(self, **kwargs): container_client = bucket_client.get_container_client(container) container_url = container_client.url with SuppressValidationError(): - container_asset = AzureContainerAsset( - value=container_url, - uid=uid, - scan_data={ - "accountNumber": ctx.subscription_id, # "accountNumber": self.subscription_id, - "publicAccess": container.public_access, - "location": account.location, - }, + container_asset = self.process_cloud_asset( + AzureContainerAsset( + value=container_url, + uid=uid, + scan_data={ + "accountNumber": ctx.subscription_id, # "accountNumber": self.subscription_id, + "publicAccess": container.public_access, + "location": account.location, + }, + ), + label_prefix=ctx.label_prefix, ) # self.add_cloud_asset(container_asset) self.submit_cloud_asset_payload(label, [container_asset]) diff --git a/src/censys/cloud_connectors/common/connector.py b/src/censys/cloud_connectors/common/connector.py index b67f934..b65b625 100644 --- a/src/censys/cloud_connectors/common/connector.py +++ b/src/censys/cloud_connectors/common/connector.py @@ -268,15 +268,19 @@ def process_seed(self, seed: Seed, **kwargs) -> Seed: self.dispatch_event(EventTypeEnum.SEED_FOUND, seed=seed, **kwargs) return seed - def process_cloud_asset(self, cloud_asset: CloudAsset, **kwargs) -> CloudAsset: + def process_cloud_asset( + self, cloud_asset: CloudAsset, label_prefix: str, **kwargs + ) -> CloudAsset: """Prepare a cloud asset for submission. Args: cloud_asset (CloudAsset): The cloud asset to add. **kwargs: Additional data for event dispatching. """ - if not cloud_asset.uid.startswith(self.label_prefix): - cloud_asset.uid = self.label_prefix + cloud_asset.uid + # TODO: remove self.label_prefix + label_prefix = label_prefix or self.label_prefix + if not cloud_asset.uid.startswith(label_prefix): + cloud_asset.uid = label_prefix + cloud_asset.uid self.logger.debug(f"Found Cloud Asset: {cloud_asset.to_dict()}") self.dispatch_event( diff --git a/src/censys/cloud_connectors/common/healthcheck.py b/src/censys/cloud_connectors/common/healthcheck.py index 784a421..ef5c862 100644 --- a/src/censys/cloud_connectors/common/healthcheck.py +++ b/src/censys/cloud_connectors/common/healthcheck.py @@ -97,6 +97,10 @@ def __exit__( exc_traceback (Optional[TracebackType]): The traceback. """ if not self.settings.healthcheck_enabled: + if exc_type is not None: + self.logger.debug( + f"Healthcheck not enabled. Errors would have been exc_type:{exc_type} exc_value:{exc_value}" + ) return if exc_type is not None: error_code = self.exception_map.get(exc_type) # type: ignore @@ -125,9 +129,11 @@ def start(self) -> None: """ if not self.provider_payload: raise ValueError("The provider must be set.") + self.run_id = self._session.post( self.start_url, json={"provider": self.provider_payload} ).json()["runId"] + self.logger.debug( f"Starting Run ID: {self.run_id}", extra={"provider": self.provider_payload} ) @@ -143,12 +149,14 @@ def finish(self, metadata: Optional[dict] = None) -> None: """ if not self.run_id: raise ValueError("The run ID must be set.") + body = {} if not self.settings.healthcheck_enabled: self.logger.info( "Healthcheck not enabled. Skipping submission of healthcheck data." ) + self.logger.debug(f"Health check run:{self.run_id} data: {body}") else: if metadata: body["metadata"] = metadata