Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matplotlib upgrade #270

Open
wants to merge 60 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
3597d9e
feat: Implement predator prey env (#1)
zombie-einstein Nov 4, 2024
c955320
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 4, 2024
6b34657
Merge branch 'main' into main
sash-a Nov 4, 2024
988339b
fix: PR fixes (#2)
zombie-einstein Nov 5, 2024
a0fe7a5
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 5, 2024
b4cce01
style: Run updated pre-commit
zombie-einstein Nov 6, 2024
cb6d88d
refactor: Consolidate predator prey type
zombie-einstein Nov 7, 2024
06de3a0
feat: Implement search and rescue (#3)
zombie-einstein Nov 11, 2024
34beab6
fix: PR fixes (#4)
zombie-einstein Nov 14, 2024
f5fa659
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 15, 2024
072db18
refactor: PR fixes (#5)
zombie-einstein Nov 19, 2024
162a74d
feat: Allow variable environment dimensions (#6)
zombie-einstein Nov 19, 2024
4996869
Merge branch 'main' into main
zombie-einstein Nov 22, 2024
6322f61
fix: Locate targets in single pass (#8)
zombie-einstein Nov 23, 2024
4ba7688
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 28, 2024
9a654b9
feat: training and customisable observations (#7)
zombie-einstein Dec 7, 2024
5021e20
feat: view all targets (#9)
zombie-einstein Dec 9, 2024
c5c7b85
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
13ffb84
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
9e8ac5c
feat: Scaled rewards and target velocities (#10)
zombie-einstein Dec 11, 2024
5c509c7
Pass shape information to timesteps (#11)
zombie-einstein Dec 11, 2024
8acf242
test: extend tests and docs (#12)
zombie-einstein Dec 11, 2024
1792aa6
fix: unpin jax requirement
zombie-einstein Dec 12, 2024
1e66e78
Include agent positions in observation (#13)
zombie-einstein Dec 12, 2024
296e98a
Update animation functions
zombie-einstein Dec 12, 2024
fe8880a
Update rubiks cube viewer for new API
zombie-einstein Dec 13, 2024
407ff79
Upgrade Esquilax and remove unused random keys (#14)
zombie-einstein Dec 27, 2024
b52fefd
Address PR comments
zombie-einstein Jan 9, 2025
04fe710
docs: Review docstrings and docs (#15)
zombie-einstein Jan 12, 2025
ac3f811
fix: Remove enum annotations
zombie-einstein Jan 12, 2025
943a51b
refactor: address pr comments (#16)
zombie-einstein Jan 17, 2025
6a3fdb1
Parameter tweaks
zombie-einstein Jan 17, 2025
ac8838f
refactor: Observation tweaks (#17)
zombie-einstein Jan 20, 2025
05eeedf
refactor: address pr comments (#18)
zombie-einstein Feb 3, 2025
5353ef7
chore: revert to using set colours
sash-a Feb 4, 2025
bc9e252
fix: minor training bug due to refactor
sash-a Feb 4, 2025
eac2f1f
chore: update default parameters to ones tested in mava
sash-a Feb 4, 2025
79a7aa8
chore: add search and rescue to the readme
sash-a Feb 4, 2025
de8e869
Update graph-coloring viewer
zombie-einstein Feb 4, 2025
243ddde
Update mmst
zombie-einstein Feb 4, 2025
211d578
Merge branch 'main' of github.com:zombie-einstein/jumanji into matplo…
zombie-einstein Feb 4, 2025
cbb159d
Merge branch 'instadeepai:main' into main
zombie-einstein Feb 4, 2025
385935d
Merge branch 'main' of github.com:zombie-einstein/jumanji into matplo…
zombie-einstein Feb 4, 2025
ffb9c6f
Update suduko viewer
zombie-einstein Feb 5, 2025
cadb337
Fix tsp viewer
zombie-einstein Feb 5, 2025
3bcb2a5
Refactor graph coloring for multiple episodes
zombie-einstein Feb 10, 2025
5b6ada7
Fix coloring and mmst for multiple episodes
zombie-einstein Feb 12, 2025
759cc45
Fix knapsack image loading
zombie-einstein Feb 12, 2025
5361a75
Replace pkg_resources usage and add animation test script
zombie-einstein Feb 12, 2025
b5613a1
Fix refresh of cvrp on new episode
zombie-einstein Feb 12, 2025
dbc52bb
Refresh tsp at new episode
zombie-einstein Feb 12, 2025
4c3a139
Refresh multi-crvp at new episode
zombie-einstein Feb 12, 2025
e15e70e
Cleanup and refactor graph layout functionality
zombie-einstein Feb 13, 2025
71ca43a
Address PR comments
zombie-einstein Feb 25, 2025
43bbb77
Consolidate viewer functionality
zombie-einstein Feb 28, 2025
7bfd2c8
Consolidate animation creation
zombie-einstein Feb 28, 2025
bec0ac5
Viewer tweaks
zombie-einstein Mar 5, 2025
7feba67
Fix cmap warnings
zombie-einstein Mar 5, 2025
489ff67
Update jupyter backend and animation rendering
zombie-einstein Mar 6, 2025
16d58b7
Revert sudoku animation change
zombie-einstein Mar 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions examples/visualize_random_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import jax
import requests
from hydra import compose, initialize

from jumanji.training.setup_train import setup_agent, setup_env

# Names of the environments that are animated when the script is run without
# positional arguments. Each name is also used to build the URL and local path
# of the matching `configs/env/<name>.yaml` file.
envs = [
"bin_pack",
"cleaner",
"connector",
"cvrp",
"flat_pack",
"game_2048",
"graph_coloring",
"job_shop",
"knapsack",
"lbf",
"maze",
"minesweeper",
"mmst",
"multi_cvrp",
"pac_man",
"robot_warehouse",
"rubiks_cube",
"search_and_rescue",
"sliding_tile_puzzle",
"snake",
"sokoban",
"sudoku",
"tetris",
"tsp",
]


def download_file(url: str, file_path: str) -> None:
    """Download ``url`` and write the response body to ``file_path``.

    This is a best-effort helper: when the server answers with anything other
    than status code 200, nothing is written and a message is printed instead
    of raising.

    Args:
        url: Address of the file to fetch with an HTTP GET request.
        file_path: Local path the response content is written to (binary mode).
    """
    # Bound the request so that an unresponsive server cannot stall the whole
    # script indefinitely (requests applies no timeout by default).
    response = requests.get(url, timeout=30)

    if response.status_code != 200:
        # Name the URL and status so that a failed download can be diagnosed.
        print(f"Failed to download {url} (status code {response.status_code}).")
        return

    with open(file_path, "wb") as f:
        f.write(response.content)


def create_animation(env_name: str, agent: str = "random", num_episodes: int = 2) -> None:
    """Roll out a policy in one environment and save the visited states as a GIF.

    The training config and the environment config are downloaded into a local
    ``configs`` directory, composed with hydra, and used to build the
    environment and agent. The resulting animation is written to
    ``animations/<env_name>_animation.gif``.

    Args:
        env_name: Name of the environment; selects the env config file and the
            hydra ``env`` override.
        agent: Value of the hydra ``agent`` override.
        num_episodes: Number of episodes whose states are collected.
    """
    print(f"Animating {env_name}")

    # Fetch the configs that hydra composes below.
    os.makedirs("configs", exist_ok=True)
    config_url = "https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml"
    download_file(config_url, "configs/config.yaml")
    env_url = f"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env_name}.yaml"
    os.makedirs("configs/env", exist_ok=True)
    download_file(env_url, f"configs/env/{env_name}.yaml")
    os.makedirs("animations", exist_ok=True)

    overrides = [f"env={env_name}", f"agent={agent}"]
    with initialize(version_base=None, config_path="configs"):
        cfg = compose(config_name="config.yaml", overrides=overrides)

    env = setup_env(cfg).unwrapped
    # Kept under its own name so the `agent` string argument is not rebound.
    agent_instance = setup_agent(cfg, env)
    select_action = jax.jit(agent_instance.make_policy(stochastic=False))

    reset_fn = jax.jit(env.reset)
    step_fn = jax.jit(env.step)

    def add_batch_dim(leaf):  # type: ignore
        # The policy is called on observations with a leading axis of size 1.
        return leaf[None]

    trajectory = []
    rng = jax.random.PRNGKey(cfg.seed)

    for _ in range(num_episodes):
        rng, reset_rng = jax.random.split(rng)
        state, timestep = reset_fn(reset_rng)
        trajectory.append(state)

        while not timestep.last():
            rng, action_rng = jax.random.split(rng)
            batched_observation = jax.tree_util.tree_map(add_batch_dim, timestep.observation)
            action = select_action(batched_observation, action_rng)
            # Drop the leading axis again before stepping the environment.
            state, timestep = step_fn(state, action.squeeze(axis=0))
            trajectory.append(state)

    # NOTE(review): the second argument is presumably the frame interval - confirm.
    env.animate(trajectory, 100, f"animations/{env_name}_animation.gif")


if __name__ == "__main__":
    # Usage: pass zero or more environment names as positional arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("envs", nargs="*", type=str, default=None)
    requested = parser.parse_args().envs

    # Without any names on the command line, animate every known environment.
    if len(requested) == 0:
        selected = envs
    else:
        selected = requested

    for env_name in selected:
        # Best effort: a failure in one environment must not stop the others.
        try:
            create_animation(env_name)
        except Exception as err:
            print(f"{env_name} failed", err)
2 changes: 1 addition & 1 deletion jumanji/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def is_notebook() -> bool:
if is_colab():
backend = "inline"
elif is_notebook():
backend = "notebook"
backend = "ipympl"
Copy link
Contributor Author

@zombie-einstein zombie-einstein Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is now the correct way to update the render in a Jupyter notebook (as mentioned in #272). I've tested this locally with what I guess is a fairly recent version of Jupyter. The only potential issue is whether this breaks the behaviour of older notebook versions.

Without this change I just get the message `Javascript Error: IPython is not defined` when trying to call the render method in a notebook.

else:
backend = ""
IPython.get_ipython().run_line_magic("matplotlib", backend)
Expand Down
99 changes: 99 additions & 0 deletions jumanji/environments/commons/graph_view_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import chex
import numpy as np


def _compute_repulsive_forces(
    repulsive_forces: np.ndarray, pos: np.ndarray, k: float, num_nodes: int
) -> np.ndarray:
    """Accumulate the pairwise repulsion between all nodes.

    Every unordered pair of nodes pushes each other apart along the line
    joining them, with magnitude ``k * k / distance``.

    Args:
        repulsive_forces: Accumulator with one force vector per node. It is
            updated in place and also returned.
        pos: Current node positions, indexed by node.
        k: Layout distance constant.
        num_nodes: Number of nodes to consider.

    Returns:
        The ``repulsive_forces`` array passed in, with the repulsion added.
    """
    k_squared = k * k

    for first in range(num_nodes):
        for second in range(first + 1, num_nodes):
            separation = pos[first] - pos[second]
            # The small offset guards against division by zero for coincident nodes.
            softened = np.linalg.norm(separation) + 1e-6
            push = (separation / softened) * (k_squared / softened)
            # Equal and opposite contributions for the two nodes of the pair.
            repulsive_forces[first] += push
            repulsive_forces[second] -= push

    return repulsive_forces


def _compute_attractive_forces(
    graph: chex.Array,
    attractive_forces: np.ndarray,
    pos: np.ndarray,
    k: float,
    num_nodes: int,
) -> np.ndarray:
    """Accumulate the attraction along the edges of the graph.

    For every ordered pair ``(i, j)`` with a truthy ``graph[i, j]`` the two
    nodes are pulled towards each other with magnitude ``distance ** 2 / k``.

    Args:
        graph: Matrix indexed as ``graph[i, j]``; a truthy entry marks an edge.
        attractive_forces: Accumulator with one force vector per node. It is
            updated in place and also returned.
        pos: Current node positions, indexed by node.
        k: Layout distance constant.
        num_nodes: Number of nodes to consider.

    Returns:
        The ``attractive_forces`` array passed in, with the attraction added.
    """
    # NOTE(review): all ordered pairs are visited, so a symmetric adjacency
    # matrix contributes each edge twice - confirm this is intended.
    for source in range(num_nodes):
        for target in range(num_nodes):
            if not graph[source, target]:
                continue

            separation = pos[source] - pos[target]
            gap = np.linalg.norm(separation)
            # The small offset guards against division by zero for coincident nodes.
            pull = (separation / (gap + 1e-6)) * (gap * gap / k)
            # Each endpoint moves towards the other one.
            attractive_forces[source] -= pull
            attractive_forces[target] += pull

    return attractive_forces


def spring_layout(
    graph: chex.Array, num_nodes: int, seed: int = 42, iterations: int = 100
) -> List[Tuple[float, float]]:
    """
    Compute a 2D spring layout for the given graph using
    the Fruchterman-Reingold force-directed algorithm.

    The algorithm computes a layout by simulating the graph as a physical system,
    where nodes are repelling each other and edges are attracting connected nodes.
    Positions are refined over several iterations with a decreasing step size and
    finally rescaled so that the layout fills the figure.

    Args:
        graph: Adjacency matrix of the graph; a truthy ``graph[i, j]`` marks an edge.
        num_nodes: Number of graph nodes.
        seed: An integer used to seed the random number generator for reproducibility.
        iterations: Number of layout refining iterations.

    Returns:
        A list of tuples representing the 2D positions of nodes in the graph.
        The list is empty when ``num_nodes`` is not positive. An axis along which
        all nodes coincide (e.g. a single-node graph) is placed at 0.5.
    """
    # Nothing to lay out; this also avoids dividing by zero when computing `k`.
    if num_nodes <= 0:
        return []

    rng = np.random.default_rng(seed)
    pos = rng.random((num_nodes, 2)) * 2 - 1

    k = np.sqrt(5 / num_nodes)
    # Step size applied to the net force; it decays every iteration.
    temperature = 2.0

    # A lone node experiences no forces, so refinement only matters for 2+ nodes.
    if num_nodes > 1:
        for _ in range(iterations):
            repulsive_forces = _compute_repulsive_forces(
                np.zeros((num_nodes, 2)), pos, k, num_nodes
            )
            attractive_forces = _compute_attractive_forces(
                graph, np.zeros((num_nodes, 2)), pos, k, num_nodes
            )

            pos += (repulsive_forces + attractive_forces) * temperature
            # Reduce the temperature (cooling factor) to refine the layout.
            temperature *= 0.9

            pos = np.clip(pos, -1, 1)  # Keep positions within the [-1, 1] range

    # Scale positions to fill figure
    pos_max = np.max(pos, axis=0)
    pos_min = np.min(pos, axis=0)
    extent = pos_max - pos_min
    # If all nodes share a coordinate the extent is zero and the rescaling would
    # produce NaN; divide by 1 instead and centre that axis afterwards.
    degenerate = extent == 0
    safe_extent = np.where(degenerate, 1.0, extent)
    pos = 0.05 + (pos - pos_min) / (1.1 * safe_extent)
    pos[:, degenerate] = 0.5

    return [(float(p[0]), float(p[1])) for p in pos]
67 changes: 10 additions & 57 deletions jumanji/environments/commons/maze_utils/maze_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
from typing import ClassVar, Dict, List, Optional, Sequence, Tuple

import chex
import matplotlib.animation
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import image
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from numpy.typing import NDArray

import jumanji.environments
from jumanji.environments.commons.maze_utils.maze_generation import EMPTY, WALL
from jumanji.viewer import Viewer
from jumanji.viewer import MatplotlibViewer


class MazeViewer(Viewer):
class MazeViewer(MatplotlibViewer):
FONT_STYLE = "monospace"
FIGURE_SIZE = (10.0, 10.0)
# EMPTY is white, WALL is black
COLORS: ClassVar[Dict[int, List[int]]] = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]}

Expand All @@ -43,18 +42,7 @@ def __init__(self, name: str, render_mode: str = "human") -> None:
- "human": render the environment on screen.
- "rgb_array": return a numpy array frame representing the environment.
"""
self._name = name
# The animation must be stored in a variable that lives as long as the
# animation should run. Otherwise, the animation will get garbage-collected.
self._animation: Optional[matplotlib.animation.Animation] = None

self._display: Callable[[plt.Figure], Optional[NDArray]]
if render_mode == "rgb_array":
self._display = self._display_rgb_array
elif render_mode == "human":
self._display = self._display_human
else:
raise ValueError(f"Invalid render mode: {render_mode}")
super().__init__(name, render_mode)

def render(self, maze: chex.Array) -> Optional[NDArray]:
"""
Expand Down Expand Up @@ -89,19 +77,19 @@ def animate(
Returns:
Animation that can be saved as a GIF, MP4, or rendered with HTML.
"""
fig, ax = plt.subplots(num=f"{self._name}Animation", figsize=self.FIGURE_SIZE)
plt.close(fig)
fig, ax = self._get_fig_ax(name_suffix="_animation", show=False)
plt.close(fig=fig)

def make_frame(maze_index: int) -> None:
def make_frame(maze: chex.Array) -> Tuple[Artist]:
ax.clear()
maze = mazes[maze_index]
self._add_grid_image(maze, ax)
return (ax,)

# Create the animation object.
self._animation = matplotlib.animation.FuncAnimation(
fig,
make_frame,
frames=len(mazes),
frames=mazes,
interval=interval,
)

Expand All @@ -111,20 +99,6 @@ def make_frame(maze_index: int) -> None:

return self._animation

def close(self) -> None:
plt.close(self._name)

def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
recreate = not plt.fignum_exists(self._name)
fig = plt.figure(self._name, figsize=self.FIGURE_SIZE)
if recreate:
if not plt.isinteractive():
fig.show()
ax = fig.add_subplot()
else:
ax = fig.get_axes()[0]
return fig, ax

def _add_grid_image(self, maze: chex.Array, ax: Axes) -> image.AxesImage:
img = self._create_grid_image(maze)
ax.set_axis_off()
Expand All @@ -137,24 +111,3 @@ def _create_grid_image(self, maze: chex.Array) -> NDArray:
# Draw black frame around maze by padding axis 0 and 1
img = np.pad(img, ((1, 1), (1, 1), (0, 0))) # type: ignore
return img

def _display_human(self, fig: plt.Figure) -> None:
if plt.isinteractive():
# Required to update render when using Jupyter Notebook.
fig.canvas.draw()
if jumanji.environments.is_colab():
plt.show(self._name)
else:
# Required to update render when not using Jupyter Notebook.
fig.canvas.draw_idle()
fig.canvas.flush_events()

def _display_rgb_array(self, fig: plt.Figure) -> NDArray:
fig.canvas.draw()
return np.asarray(fig.canvas.buffer_rgba())

def _clear_display(self) -> None:
if jumanji.environments.is_colab():
import IPython.display

IPython.display.clear_output(True)
Loading