Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix action comparion global defender #276

Merged
merged 5 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AIDojoCoordinator/global_defender.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def stochastic_with_threshold(self, action: Action, episode_actions:list, tw_siz
temp_episode_actions.append(action.as_dict)
if len(temp_episode_actions) >= tw_size:
last_n_actions = temp_episode_actions[-tw_size:]
last_n_action_types = [action['type'] for action in last_n_actions]
last_n_action_types = [action['action_type'] for action in last_n_actions]
# compute ratio of action type in the TW
tw_ratio = last_n_action_types.count(str(action.type))/tw_size
# Count how many times this exact (parametrized) action was played in episode
Expand Down
2 changes: 1 addition & 1 deletion tests/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace
python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace
python3 -m pytest tests/test_game_coordinator.py -p no:warnings -vvvv -s --full-trace
# Coordinator tesst
python3 -m pytest tests/test_global_defender.py -p no:warnings -vvvv -s --full-trace
#python3 -m pytest tests/test_coordinator.py -p no:warnings -vvvv -s --full-trace

# run ruff check as well
Expand Down
62 changes: 62 additions & 0 deletions tests/test_global_defender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from AIDojoCoordinator.game_components import ActionType, Action
from AIDojoCoordinator.global_defender import GlobalDefender
from unittest.mock import patch

@pytest.fixture
def defender():
return GlobalDefender()

@pytest.fixture
def episode_actions():
"""Mock episode actions list."""
return [
Action(ActionType.ScanNetwork, {}).as_dict,
Action(ActionType.FindServices, {}).as_dict,
Action(ActionType.ScanNetwork, {}).as_dict,
Action(ActionType.FindServices, {}).as_dict,
]

def test_short_episode_does_not_detect(defender, episode_actions):
"""Test when the episode action list is too short to make a decision."""
action = Action(ActionType.ScanNetwork, {})
assert not defender.stochastic_with_threshold(action, episode_actions[:2], tw_size=5)

def test_below_threshold_does_not_trigger_detection(defender, episode_actions):
"""Test when action thresholds are NOT exceeded (should return False)."""
action = Action(ActionType.ScanNetwork, {})
assert not defender.stochastic_with_threshold(action, episode_actions, tw_size=5)

def test_exceeding_threshold_triggers_stochastic(defender, episode_actions):
"""Test when thresholds are exceeded and stochastic is triggered."""
action = Action(ActionType.ScanNetwork, {})
episode_actions += [action.as_dict] * 3 # Exceed threshold

with patch.object(defender, "stochastic", return_value=True) as mock_stochastic:
result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5)
mock_stochastic.assert_called_once_with("ScanNetwork") # Ensure stochastic was called
assert result # Expecting True since stochastic is triggered

def test_repeated_episode_action_threshold(defender, episode_actions):
"""Test when an action exceeds the episode repeated action threshold."""
action = Action(ActionType.FindData, {})
episode_actions += [action.as_dict] * 3 # Exceed repeat threshold

with patch.object(defender, "stochastic", return_value=True) as mock_stochastic:
result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5)
mock_stochastic.assert_called_once_with(ActionType.FindData) # Ensure stochastic was called
assert result # Expecting True since stochastic is triggered

def test_other_actions_never_detected(defender, episode_actions):
"""Test that actions not in any threshold lists always return False."""
action = Action(ActionType.JoinGame, {})
assert not defender.stochastic_with_threshold(action, episode_actions, tw_size=5)

def test_mock_stochastic_probabilities(defender, episode_actions):
"""Test stochastic function is only called when thresholds are crossed."""
action = Action(ActionType.ScanNetwork, {})
episode_actions += [{"action_type": str(ActionType.ScanNetwork)}] * 4 # Exceed threshold

with patch("AIDojoCoordinator.global_defender.random", return_value=0.01): # Force detection probability
result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5)
assert result # Should be True since we forced a low probability value