Skip to content

Commit

Permalink
remove line of sight where possible, cos it is slow now
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Sep 3, 2024
1 parent a34256e commit c08d6a6
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions parabellum/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def __init__(self, scenario: Scenario, **kwargs):
self.max_steps = 200
self._push_units_away = lambda state, firmness=1: state # overwrite push units
self.spawning_sectors = sectors_fn(self.unit_starting_sectors, scenario.terrain_raster.building + scenario.terrain_raster.water)
self.resolution = self.terrain_raster.building.shape[0] + self.terrain_raster.building.shape[1]
self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, self.resolution))


def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
Expand Down Expand Up @@ -203,7 +205,8 @@ def get_features(i, j):
< self.unit_type_sight_ranges[state.unit_types[i]]
)
return jax.lax.cond(
visible & state.unit_alive[i] & state.unit_alive[j_idx] & self.has_line_of_sight(state.unit_positions[j_idx], state.unit_positions[i], self.terrain_raster.building + self.terrain_raster.forest),
visible & state.unit_alive[i] & state.unit_alive[j_idx]
& self.has_line_of_sight(state.unit_positions[j_idx], state.unit_positions[i], self.terrain_raster.building + self.terrain_raster.forest),
lambda: features,
lambda: empty_features,
)
Expand Down Expand Up @@ -239,12 +242,16 @@ def _our_push_units_away(
)
return jnp.where(self.unit_type_pushable[unit_types][:, None], unit_positions, pos)

def has_line_of_sight(self, source, target, raster_input):
resolution = raster_input.shape[0] + raster_input.shape[1]
t = jnp.tile(jnp.linspace(0, 1, resolution), (2, resolution))
cells = jnp.array(source[:, jnp.newaxis] * t + (1-t) * target[:, jnp.newaxis], dtype=jnp.int32)
def has_line_of_sight(self, source, target, raster_input): # this is tooooo slow TODO: make it fast
# we could compute this for units in sight only using a switch

cells = jnp.array(source[:, jnp.newaxis] * self.t + (1-self.t) * target[:, jnp.newaxis], dtype=jnp.int32)

mask = jnp.zeros(raster_input.shape).at[cells[1, :], cells[0, :]].set(1)
return ~jnp.any(jnp.logical_and(mask, raster_input))

flag = ~jnp.any(jnp.logical_and(mask, raster_input))

return flag


@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
Expand Down Expand Up @@ -334,9 +341,6 @@ def update_agent_health(idx, action, key): # TODO: add attack blasts
)
& state.unit_alive[idx]
& state.unit_alive[attacked_idx]
& self.has_line_of_sight(state.unit_positions[idx], state.unit_positions[attacked_idx], self.terrain_raster.building
+ self.terrain_raster.forest
)
)
attack_valid = attack_valid & (idx != attacked_idx)
attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
Expand Down Expand Up @@ -458,7 +462,7 @@ def perform_agent_action(idx, action, key):

n_allies = 10
scenario_kwargs = {"allies_type": 0, "n_allies": n_allies, "enemies_type": 0, "n_enemies": n_allies,
"place": "Vesterbro, Copenhagen, Denmark", "size": 100, "unit_starting_sectors":
"place": "Vesterbro, Copenhagen, Denmark", "size": 256, "unit_starting_sectors":
[([i for i in range(n_allies)], [0.,0.45,0.1,0.1]), ([n_allies+i for i in range(n_allies)], [0.8,0.5,0.1,0.1])]}
scenario = make_scenario(**scenario_kwargs)
env = Environment(scenario)
Expand Down

0 comments on commit c08d6a6

Please sign in to comment.