Skip to content

Commit

Permalink
tim
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Oct 30, 2024
2 parents 0d26c3c + aeb1303 commit 6bcd433
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 18 deletions.
9 changes: 9 additions & 0 deletions parabellum/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,20 @@ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # ty
# def step_env(self, rng, state: State, action: Array): # type: ignore
# obs, state, rewards, dones, infos = super().step_env(rng, state, action)
# delete world_state from obs
<<<<<<< HEAD
# obs.pop("world_state")
# if not self.reset_when_done:
# for key in dones.keys():
# dones[key] = False
# return obs, state, rewards, dones, infos
=======
obs.pop("world_state")
if not self.reset_when_done:
for key in dones.keys():
infos[key] = dones[key]
dones[key] = False
return obs, state, rewards, dones, infos
>>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32

def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
"""Applies observation function to state."""
Expand Down
80 changes: 75 additions & 5 deletions parabellum/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def geography_fn(place, buffer=400):
# 0: building", 1: "water", 2: "highway", 3: "forest", 4: "garden"
kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
trans = lambda x: jnp.rot90(x, 3)
# <<<<<<< HEAD
terrain = tps.Terrain(
building=trans(raster[0]),
water=trans(
Expand All @@ -102,6 +103,15 @@ def geography_fn(place, buffer=400):
return terrain


# =======
# terrain = tps.Terrain(building=trans(raster[0] - convolve(raster[0]*raster[2], kernel, mode='same')>0),
# water=trans(raster[1] - convolve(raster[1]*raster[2], kernel, mode='same')>0),
# forest=trans(jnp.logical_or(raster[3], raster[4])),
# basemap=basemap)
# return terrain, gdf
# >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32


def raster_fn(gdf, shape) -> Array:
bbox = gdf.total_bounds
t = transform.from_bounds(*bbox, *shape) # type: ignore
Expand All @@ -117,14 +127,74 @@ def feature_fn(t, feature, gdf, shape):
return raster


# %%
def normalize(x):
return (np.array(x) - m) / (M - m)


def get_bridges(gdf):
xmin, ymin, xmax, ymax = gdf.total_bounds
m = np.array([xmin, ymin])
M = np.array([xmax, ymax])

bridges = {}
for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
if type(bridge["name"]) == str:
bridges[idx[1]] = {
"name": bridge["name"],
"coords": normalize(
[bridge.geometry.centroid.x, bridge.geometry.centroid.y]
),
}
return bridges


"""
# %%
if __name__ == "__main__":
place = "Thun, Switzerland"
<<<<<<< HEAD
terrain = geography_fn(place, 300)
=======
terrain, gdf = geography_fn(place, 300)
>>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
fig, axes = plt.subplots(1, 5, figsize=(20, 20))
axes[0].imshow(terrain.building, cmap="gray")
axes[1].imshow(terrain.water, cmap="gray")
axes[2].imshow(terrain.forest, cmap="gray")
axes[3].imshow(terrain.building + terrain.water + terrain.forest)
axes[4].imshow(terrain.basemap)
axes[0].imshow(jnp.rot90(terrain.building), cmap="gray")
axes[1].imshow(jnp.rot90(terrain.water), cmap="gray")
axes[2].imshow(jnp.rot90(terrain.forest), cmap="gray")
axes[3].imshow(jnp.rot90(terrain.building + terrain.water + terrain.forest))
axes[4].imshow(jnp.rot90(terrain.basemap))
# %%
W, H, _ = terrain.basemap.shape
bridges = get_bridges(gdf)
# %%
print("Bridges:")
for bridge in bridges.values():
x, y = int(bridge["coords"][0]*300), int(bridge["coords"][1]*300)
print(bridge["name"], f"at ({x}, {y})")
# %%
plt.subplots(figsize=(7,7))
plt.imshow(jnp.rot90(terrain.basemap))
X = [b["coords"][0]*W for b in bridges.values()]
Y = [(1-b["coords"][1])*H for b in bridges.values()]
plt.scatter(X, Y)
for i in range(len(X)):
x,y = int(X[i]), int(Y[i])
plt.text(x, y, str((int(x/W*300), int((1-(y/H))*300))))
# %%
# %% [raw]
# fig, ax = plt.subplots(figsize=(10, 10))
# gdf.plot(ax=ax, color='lightgray') # Plot all features
# bridges.plot(ax=ax, color='red') # Highlight bridges in red
# plt.show()
# %%
"""
43 changes: 30 additions & 13 deletions parabellum/terrain_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def make_terrain(terrain_args, size):
{"line":[0.75, 0.25, 0., 0.2]}, {"line":[0.75, 0.55, 0., 0.19]},
{"line":[0.6, 0.25, 0.15, 0.]}], 'water': None, 'forest': None},
"playground": {'building': [{"line":[0.5, 0.5, 0.5, 0.]}], 'water': None, 'forest': None},
"water_park": {
'building': [{"line":[0.5, 0.5, 0.5, 0.]}],
"water": [{"rect":[0., 0.8, 0.1, 0.05]}, {"rect": [0.2, 0.8, 0.8, 0.05]}],
"playground2": {
'building': [],
"water": [{"rect":[0., 0.8, 0.1, 0.1]}, {"rect": [0.2, 0.8, 0.8, 0.1]}],
"forest": [{"rect": [0., 0., 1., 0.2]}]
},
"triangle": {'building': [{"line": [0.33, 0., 0., 1.]}, {"line": [0.66, 0., 0., 1.]}], 'water': None, 'forest': None},
Expand All @@ -81,23 +81,38 @@ def make_terrain(terrain_args, size):
"water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
"forest": []
},
"bridges": {
'building': [],
"water": [{"rect": [0.475, 0., 0.05, 0.1]}, {"rect": [0.475, 0.15, 0.05, 0.575]}, {"rect": [0.475, 0.775, 0.05, 1.]},
{"rect": [0., 0.475, 0.225, 0.05]}, {"rect": [0.275, 0.475, 0.45, 0.05]}, {"rect": [0.775, 0.475, 0.23, 0.05]}],
"forest": [{"rect": [0.1, 0.625, 0.275, 0.275]}, {"rect": [0.725, 0., 0.3, 0.275]}, ]
}
}

# %% [raw]
# import matplotlib.pyplot as plt
# size = 50
# size = 100
# raster = np.zeros((size, size))
# rect = [0.2, 0.3, 0.05, 0.4]
# rect = [0.475, 0., 0.05, 0.1]
# raster = map_raster_from_rect(raster, rect, size)
# rect = [0.475, 0.15, 0.05, 0.575]
# raster = map_raster_from_rect(raster, rect, size)
# rect = [0.475, 0.775, 0.05, 1.]
# raster = map_raster_from_rect(raster, rect, size)
# rect = [0.4, 0.3, 0.05, 0.4]
#
# rect = [0., 0.475, 0.225, 0.05]
# raster = map_raster_from_rect(raster, rect, size)
# rect = [0.2, 0.3, 0.25, 0.05]
# rect = [0.275, 0.475, 0.45, 0.05]
# raster = map_raster_from_rect(raster, rect, size)
# rect = [0.2, 0.7, 0.25, 0.05]
# rect = [0.775, 0.475, 0.23, 0.05]
# raster = map_raster_from_rect(raster, rect, size)
# rect = [0.6, 0.3, 0.4, 0.45]
#
# rect = [0.1, 0.625, 0.275, 0.275]
# raster = map_raster_from_rect(raster, rect, size)
# plt.imshow(jnp.rot90(raster))
# rect = [0.725, 0., 0.3, 0.275]
# raster = map_raster_from_rect(raster, rect, size)
#
# plt.imshow(raster[::-1, :])

# %% [markdown]
# # Main
Expand All @@ -107,11 +122,13 @@ def make_terrain(terrain_args, size):
import matplotlib.pyplot as plt

# %%
terrain = make_terrain(db["u_shape"], size=50)
terrain = make_terrain(db["bridges"], size=100)

# %%
plt.imshow(jnp.rot90(terrain.basemap))

# %%
bl = (39.5, 5)
tr = (44.5, 10)
plt.scatter(bl[0], 49-bl[1])
plt.scatter(tr[0], 49-tr[1], marker="+")

# %%

0 comments on commit 6bcd433

Please sign in to comment.