From f58ccffac75fab17a2085846c2147289b2d1ab6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Jes=C3=BAs=20Pe=C3=B1a=20Rodr=C3=ADguez?= Date: Thu, 30 Jan 2025 13:34:20 +0100 Subject: [PATCH] fix: scan unittests --- prowler/lib/check/check.py | 8 ++--- prowler/lib/scan/scan.py | 10 +++--- tests/lib/scan/scan_test.py | 69 +++++++++++++++++++++++++------------ 3 files changed, 54 insertions(+), 33 deletions(-) diff --git a/prowler/lib/check/check.py b/prowler/lib/check/check.py index d760702fa82..539d63d35db 100644 --- a/prowler/lib/check/check.py +++ b/prowler/lib/check/check.py @@ -139,13 +139,9 @@ def remove_custom_checks_module(input_folder: str, provider: str): def list_services(provider: str) -> set: available_services = set() checks_tuple = recover_checks_from_provider(provider) + split_character = "\\" if os.name == "nt" else "/" for _, check_path in checks_tuple: - # Format: /absolute_path/prowler/providers/{provider}/services/{service_name}/{check_name} - if os.name == "nt": - service_name = check_path.split("\\")[-2] - else: - service_name = check_path.split("/")[-2] - available_services.add(service_name) + available_services.add(check_path.split(split_character)[-2]) return sorted(available_services) diff --git a/prowler/lib/scan/scan.py b/prowler/lib/scan/scan.py index cba4cd41628..15a552643f9 100644 --- a/prowler/lib/scan/scan.py +++ b/prowler/lib/scan/scan.py @@ -1,5 +1,5 @@ import datetime -from typing import Dict, Generator, List, Optional, Set, Tuple +from typing import Dict, Generator, List, Optional, Set from prowler.lib.check.check import ( execute, @@ -255,7 +255,7 @@ def remaining_services(self) -> Dict[str, Set[str]]: if checks } - def scan(self) -> Generator[Tuple[float, List[Finding], dict], None, None]: + def scan(self, custom_checks_metadata: dict = {}) -> Generator[float, List[Finding], dict]: """ Executes the scan by iterating over the checks to execute and executing each check. Yields the progress and findings for each check. @@ -283,7 +283,7 @@ def scan(self) -> Generator[Tuple[float, List[Finding], dict], None, None]: service = get_service_from_check(check_name) try: check_module = self._import_check_module(check_name, service) - findings = self._execute_check(check_module, check_name) + findings = self._execute_check(check_module, check_name, custom_checks_metadata) filtered_findings = self._filter_findings_by_status(findings) except Exception as error: logger.error(f"{check_name} failed: {error}") @@ -321,10 +321,10 @@ def _import_check_module(self, check_name: str, service: str): ) raise - def _execute_check(self, check_module, check_name: str) -> List[Finding]: + def _execute_check(self, check_module, check_name: str, custom_checks_metadata: dict = {}) -> List[Finding]: """Execute a single check and return its findings.""" check_func = getattr(check_module, check_name) - return execute(check_func(), self._provider, {}, None) + return execute(check_func(), self._provider, custom_checks_metadata) def _filter_findings_by_status(self, findings: List[Finding]) -> List[Finding]: """Filter findings based on configured status filters.""" diff --git a/tests/lib/scan/scan_test.py b/tests/lib/scan/scan_test.py index 197a502faca..6c6c4d3aa72 100644 --- a/tests/lib/scan/scan_test.py +++ b/tests/lib/scan/scan_test.py @@ -1,7 +1,6 @@ from importlib.machinery import FileFinder from pkgutil import ModuleInfo from unittest import mock - import pytest from mock import MagicMock, patch @@ -196,9 +195,9 @@ def test_init(mock_provider): mock_provider.type = "aws" scan = Scan(mock_provider, checks=checks_to_execute) - assert scan.provider == mock_provider + assert scan._provider == mock_provider # Check that the checks to execute are sorted and without duplicates - assert scan.checks_to_execute == [ + assert scan._checks_to_execute == [ "accessanalyzer_enabled", "accessanalyzer_enabled_without_findings", "account_maintain_current_contact_details", @@ -258,14 +257,16 @@ def test_init(mock_provider): "config_recorder_all_regions_enabled", "workspaces_vpc_2private_1public_subnets_nat", ] - assert scan.service_checks_to_execute == get_service_checks_mapping( + assert scan._service_checks_map == get_service_checks_mapping( checks_to_execute ) - assert scan.service_checks_completed == {} - assert scan.progress == 0 - assert scan.duration == 0 - assert scan.get_completed_services() == set() - assert scan.get_completed_checks() == set() + assert scan._completed_checks == set() + assert scan._progress == 0 + assert scan._duration == 0 + all_values = set().union(*scan.remaining_services.values()) + for check in scan._checks_to_execute: + assert check in all_values + assert scan.completed_checks == 0 def test_init_with_no_checks( mock_provider, @@ -281,18 +282,24 @@ def test_init_with_no_checks( mock_load_checks_to_execute.assert_called_once() mock_recover_checks_from_provider.assert_called_once_with("aws") - assert scan.provider == mock_provider - assert scan.checks_to_execute == ["accessanalyzer_enabled"] - assert scan.service_checks_to_execute == get_service_checks_mapping( + assert scan._provider == mock_provider + assert scan._checks_to_execute == ["accessanalyzer_enabled"] + assert scan._service_checks_map == get_service_checks_mapping( ["accessanalyzer_enabled"] ) - assert scan.service_checks_completed == {} - assert scan.progress == 0 - assert scan.get_completed_services() == set() - assert scan.get_completed_checks() == set() + assert scan._completed_checks == set() + assert scan._progress == 0 + all_values = set().union(*scan.remaining_services.values()) + for check in scan._checks_to_execute: + assert check in all_values + assert scan.completed_checks == 0 @patch("importlib.import_module") + @patch("prowler.lib.scan.scan.list_services") + @patch("prowler.lib.scan.scan.extract_findings_statistics") def test_scan( + mock_extract_findings_statistics, + mock_list_services, mock_import_module, mock_global_provider, mock_execute, @@ -328,11 +335,10 @@ def test_scan( assert results[0][0] == 100.0 assert scan.progress == 100.0 # Since the scan is mocked, the duration will always be 0 for now - assert scan.duration == 0 - assert scan._number_of_checks_completed == 1 - assert scan.service_checks_completed == { - "accessanalyzer": {"accessanalyzer_enabled"}, - } + assert scan._duration == 0 + assert scan.completed_checks == 1 + print(scan._completed_checks) + assert scan._completed_checks == {"accessanalyzer_enabled"} mock_logger.error.assert_not_called() def test_init_invalid_severity( @@ -396,7 +402,9 @@ def test_init_invalid_status( Scan(mock_provider, checks=checks_to_execute, status=["invalid_status"]) @patch("importlib.import_module") + @patch("prowler.lib.scan.scan.list_services") def test_scan_filter_status( + mock_list_services, mock_import_module, mock_global_provider, mock_recover_checks_from_provider, @@ -422,4 +430,21 @@ def test_scan_filter_status( mock_recover_checks_from_provider.assert_called_once_with("aws") results = list(scan.scan(custom_checks_metadata)) - assert results[0] == (100.0, []) + assert results[0] == (100.0, [], { + "all_fails_are_muted": True, + "findings_count": 0, + "resources_count": 0, + "total_critical_severity_fail": 0, + "total_critical_severity_pass": 0, + "total_fail": 0, + "total_high_severity_fail": 0, + "total_high_severity_pass": 0, + "total_low_severity_fail": 0, + "total_low_severity_pass": 0, + "total_medium_severity_fail": 0, + "total_medium_severity_pass": 0, + "total_muted_fail": 0, + "total_muted_pass": 0, + "total_pass": 0, + }, + )