Skip to content

Commit

Permalink
fix: scan unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
AdriiiPRodri committed Jan 31, 2025
1 parent 63f8186 commit 0995c7a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 33 deletions.
8 changes: 2 additions & 6 deletions prowler/lib/check/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,9 @@ def remove_custom_checks_module(input_folder: str, provider: str):
def list_services(provider: str) -> set:
available_services = set()
checks_tuple = recover_checks_from_provider(provider)
split_character = "\\" if os.name == "nt" else "/"
for _, check_path in checks_tuple:
# Format: /absolute_path/prowler/providers/{provider}/services/{service_name}/{check_name}
if os.name == "nt":
service_name = check_path.split("\\")[-2]
else:
service_name = check_path.split("/")[-2]
available_services.add(service_name)
available_services.add(check_path.split(split_character)[-2])
return sorted(available_services)


Expand Down
10 changes: 5 additions & 5 deletions prowler/lib/scan/scan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Dict, Generator, List, Optional, Set, Tuple
from typing import Dict, Generator, List, Optional, Set

from prowler.lib.check.check import (
execute,
Expand Down Expand Up @@ -255,7 +255,7 @@ def remaining_services(self) -> Dict[str, Set[str]]:
if checks
}

def scan(self) -> Generator[Tuple[float, List[Finding], dict], None, None]:
def scan(self, custom_checks_metadata: dict = {}) -> Generator[float, List[Finding], dict]:
"""
Executes the scan by iterating over the checks to execute and executing each check.
Yields the progress and findings for each check.
Expand Down Expand Up @@ -283,7 +283,7 @@ def scan(self) -> Generator[Tuple[float, List[Finding], dict], None, None]:
service = get_service_from_check(check_name)
try:
check_module = self._import_check_module(check_name, service)
findings = self._execute_check(check_module, check_name)
findings = self._execute_check(check_module, check_name, custom_checks_metadata)
filtered_findings = self._filter_findings_by_status(findings)
except Exception as error:
logger.error(f"{check_name} failed: {error}")
Expand Down Expand Up @@ -321,10 +321,10 @@ def _import_check_module(self, check_name: str, service: str):
)
raise

def _execute_check(self, check_module, check_name: str) -> List[Finding]:
def _execute_check(self, check_module, check_name: str, custom_checks_metadata: dict = {}) -> List[Finding]:
"""Execute a single check and return its findings."""
check_func = getattr(check_module, check_name)
return execute(check_func(), self._provider, {}, None)
return execute(check_func(), self._provider, custom_checks_metadata)

def _filter_findings_by_status(self, findings: List[Finding]) -> List[Finding]:
"""Filter findings based on configured status filters."""
Expand Down
68 changes: 46 additions & 22 deletions tests/lib/scan/scan_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from importlib.machinery import FileFinder
from pkgutil import ModuleInfo
from unittest import mock

import pytest
from mock import MagicMock, patch

Expand Down Expand Up @@ -196,9 +195,9 @@ def test_init(mock_provider):
mock_provider.type = "aws"
scan = Scan(mock_provider, checks=checks_to_execute)

assert scan.provider == mock_provider
assert scan._provider == mock_provider
# Check that the checks to execute are sorted and without duplicates
assert scan.checks_to_execute == [
assert scan._checks_to_execute == [
"accessanalyzer_enabled",
"accessanalyzer_enabled_without_findings",
"account_maintain_current_contact_details",
Expand Down Expand Up @@ -258,14 +257,16 @@ def test_init(mock_provider):
"config_recorder_all_regions_enabled",
"workspaces_vpc_2private_1public_subnets_nat",
]
assert scan.service_checks_to_execute == get_service_checks_mapping(
assert scan._service_checks_map == get_service_checks_mapping(
checks_to_execute
)
assert scan.service_checks_completed == {}
assert scan.progress == 0
assert scan.duration == 0
assert scan.get_completed_services() == set()
assert scan.get_completed_checks() == set()
assert scan._completed_checks == set()
assert scan._progress == 0
assert scan._duration == 0
all_values = set().union(*scan.remaining_services.values())
for check in scan._checks_to_execute:
assert check in all_values
assert scan.completed_checks == 0

def test_init_with_no_checks(
mock_provider,
Expand All @@ -281,18 +282,24 @@ def test_init_with_no_checks(
mock_load_checks_to_execute.assert_called_once()
mock_recover_checks_from_provider.assert_called_once_with("aws")

assert scan.provider == mock_provider
assert scan.checks_to_execute == ["accessanalyzer_enabled"]
assert scan.service_checks_to_execute == get_service_checks_mapping(
assert scan._provider == mock_provider
assert scan._checks_to_execute == ["accessanalyzer_enabled"]
assert scan._service_checks_map == get_service_checks_mapping(
["accessanalyzer_enabled"]
)
assert scan.service_checks_completed == {}
assert scan.progress == 0
assert scan.get_completed_services() == set()
assert scan.get_completed_checks() == set()
assert scan._completed_checks == set()
assert scan._progress == 0
all_values = set().union(*scan.remaining_services.values())
for check in scan._checks_to_execute:
assert check in all_values
assert scan.completed_checks == 0

@patch("importlib.import_module")
@patch("prowler.lib.scan.scan.list_services")
@patch("prowler.lib.scan.scan.extract_findings_statistics")
def test_scan(
mock_extract_findings_statistics,
mock_list_services,
mock_import_module,
mock_global_provider,
mock_execute,
Expand Down Expand Up @@ -328,11 +335,9 @@ def test_scan(
assert results[0][0] == 100.0
assert scan.progress == 100.0
# Since the scan is mocked, the duration will always be 0 for now
assert scan.duration == 0
assert scan._number_of_checks_completed == 1
assert scan.service_checks_completed == {
"accessanalyzer": {"accessanalyzer_enabled"},
}
assert scan._duration == 0
assert scan.completed_checks == 1
assert scan._completed_checks == {"accessanalyzer_enabled"}
mock_logger.error.assert_not_called()

def test_init_invalid_severity(
Expand Down Expand Up @@ -396,7 +401,9 @@ def test_init_invalid_status(
Scan(mock_provider, checks=checks_to_execute, status=["invalid_status"])

@patch("importlib.import_module")
@patch("prowler.lib.scan.scan.list_services")
def test_scan_filter_status(
mock_list_services,
mock_import_module,
mock_global_provider,
mock_recover_checks_from_provider,
Expand All @@ -422,4 +429,21 @@ def test_scan_filter_status(
mock_recover_checks_from_provider.assert_called_once_with("aws")
results = list(scan.scan(custom_checks_metadata))

assert results[0] == (100.0, [])
assert results[0] == (100.0, [], {
"all_fails_are_muted": True,
"findings_count": 0,
"resources_count": 0,
"total_critical_severity_fail": 0,
"total_critical_severity_pass": 0,
"total_fail": 0,
"total_high_severity_fail": 0,
"total_high_severity_pass": 0,
"total_low_severity_fail": 0,
"total_low_severity_pass": 0,
"total_medium_severity_fail": 0,
"total_medium_severity_pass": 0,
"total_muted_fail": 0,
"total_muted_pass": 0,
"total_pass": 0,
},
)

0 comments on commit 0995c7a

Please sign in to comment.