Skip to content

Commit

Permalink
Merge pull request #199 from stratosphereips/ondra-fix-randnomization…
Browse files Browse the repository at this point in the history
…-upon-reset

Fixed episodic randomization upon game reset
  • Loading branch information
ondrej-lukas authored Apr 9, 2024
2 parents 78d631a + 6a19624 commit 5262033
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions env/network_security_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,11 @@ def __init__(self, task_config_file) -> None:

# Process episodic randomization of goal position
if not self._randomize_goal_every_episode:
# episodic randomization is not required, randomize once now
# REPLACE 'random' keyword once
logger.info("Episodic randomization disabled, generating static goal_conditions")
self._goal_conditions = self._generate_win_conditions(self._goal_conditions)
else:
self._orig_goal_conditions = copy.deepcopy(self._goal_conditions)
logger.info("Episodic randomization enabled")
logger.info("Episodic randomization enabled, keeping 'random' keyword in the goal description.")

# At this point all 'random' values should be assigned to something
# Check if dynamic network and ip adddresses are required
Expand Down Expand Up @@ -646,6 +645,7 @@ def _create_new_network_mapping(self):
exit(-1)
# Invalid IP address boundary
logger.info(f"New network mapping:{mapping_nets}")

# genereate mapping for ips:
for net,ips in self._networks.items():
ip_list = list(netaddr.IPNetwork(str(mapping_nets[net])))[1:]
Expand All @@ -655,6 +655,7 @@ def _create_new_network_mapping(self):
mapping_ips[ip] = components.IP(str(ip_list[i]))
# Always add random, in case random is selected for ips
mapping_ips['random'] = 'random'
logger.info(f"Mapping IPs done:{mapping_ips}")

# update ALL data structure in the environment with the new mappings
# self._networks
Expand Down Expand Up @@ -685,14 +686,16 @@ def _create_new_network_mapping(self):
new_attacker_start["known_services"] = {mapping_ips[ip]:service for ip,service in self._attacker_start_position["known_services"].items()}
new_attacker_start["known_data"] = {mapping_ips[ip]:data for ip,data in self._attacker_start_position["known_data"].items()}
self._attacker_start_position = new_attacker_start
logger.info(f"Starting position mapping: {new_attacker_start}")
# goal definition
new_goal = {}
new_goal["known_networks"] = {mapping_nets[net] for net in self._goal_conditions["known_networks"]}
new_goal["known_hosts"] = {mapping_ips[ip] for ip in self._goal_conditions["known_hosts"]}
new_goal["controlled_hosts"] = {mapping_ips[ip] for ip in self._goal_conditions["controlled_hosts"]}
new_goal["known_services"] = {mapping_ips[ip]:service for ip,service in self._goal_conditions["known_services"].items()}
new_goal["known_data"] = {mapping_ips[ip]:data for ip,data in self._goal_conditions["known_data"].items()}

logger.info(f"Goal mapping: {new_goal}")
# update goal mapping
for old_ip in mapping_ips.keys():
if str(old_ip) in self._goal_description:
self._goal_description = self._goal_description.replace(str(old_ip), str(mapping_ips[old_ip]))
Expand Down Expand Up @@ -1011,6 +1014,7 @@ def reset(self, trajectory_filename=None)->components.Observation:
and prepare for a new episode
"""
# write all steps in the episode replay buffer in the file
logger.info("Initiating reset")
if self._episode_replay_buffer is not None:
steps = []
for state,action,reward,next_state in self._episode_replay_buffer:
Expand Down Expand Up @@ -1045,21 +1049,15 @@ def reset(self, trajectory_filename=None)->components.Observation:
self._defender.reset()

if self.task_config.get_use_dynamic_addresses():
logger.info("Changes IPs dyamically")
self._create_new_network_mapping()

logger.info("IPs changed successfully")
#reset self._data to orignal state
self._data = copy.deepcopy(self._data_original)

#create starting state (randomized if needed)
self._current_state = self._create_starting_state()

logger.info("New starting state created")
#create win conditions for this episode (randomize if needed)
if self._randomize_goal_every_episode:
self._goal_conditions = copy.deepcopy(self._generate_win_conditions(self._orig_goal_conditions))
else:
self._goal_conditions = copy.deepcopy(self._goal_conditions)

self._goal_conditions = copy.deepcopy(self._generate_win_conditions(self._goal_conditions))
logger.info(f'Current state: {self._current_state}')
initial_reward = 0
info = {}
Expand Down

0 comments on commit 5262033

Please sign in to comment.