-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Matplotlib upgrade #270
Open
zombie-einstein
wants to merge
60
commits into
instadeepai:main
Choose a base branch
from
zombie-einstein:matplotlib_upgrade
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Matplotlib upgrade #270
Changes from all commits
Commits
Show all changes
60 commits
Select commit
Hold shift + click to select a range
3597d9e
feat: Implement predator prey env (#1)
zombie-einstein c955320
Merge branch 'instadeepai:main' into main
zombie-einstein 6b34657
Merge branch 'main' into main
sash-a 988339b
fix: PR fixes (#2)
zombie-einstein a0fe7a5
Merge branch 'instadeepai:main' into main
zombie-einstein b4cce01
style: Run updated pre-commit
zombie-einstein cb6d88d
refactor: Consolidate predator prey type
zombie-einstein 06de3a0
feat: Implement search and rescue (#3)
zombie-einstein 34beab6
fix: PR fixes (#4)
zombie-einstein f5fa659
Merge branch 'instadeepai:main' into main
zombie-einstein 072db18
refactor: PR fixes (#5)
zombie-einstein 162a74d
feat: Allow variable environment dimensions (#6)
zombie-einstein 4996869
Merge branch 'main' into main
zombie-einstein 6322f61
fix: Locate targets in single pass (#8)
zombie-einstein 4ba7688
Merge branch 'instadeepai:main' into main
zombie-einstein 9a654b9
feat: training and customisable observations (#7)
zombie-einstein 5021e20
feat: view all targets (#9)
zombie-einstein c5c7b85
Merge branch 'instadeepai:main' into main
zombie-einstein 13ffb84
Merge branch 'instadeepai:main' into main
zombie-einstein 9e8ac5c
feat: Scaled rewards and target velocities (#10)
zombie-einstein 5c509c7
Pass shape information to timesteps (#11)
zombie-einstein 8acf242
test: extend tests and docs (#12)
zombie-einstein 1792aa6
fix: unpin jax requirement
zombie-einstein 1e66e78
Include agent positions in observation (#13)
zombie-einstein 296e98a
Update animation functions
zombie-einstein fe8880a
Update rubiks cube viewer for new API
zombie-einstein 407ff79
Upgrade Esquilax and remove unused random keys (#14)
zombie-einstein b52fefd
Address PR comments
zombie-einstein 04fe710
docs: Review docstrings and docs (#15)
zombie-einstein ac3f811
fix: Remove enum annotations
zombie-einstein 943a51b
refactor: address pr comments (#16)
zombie-einstein 6a3fdb1
Parameter tweaks
zombie-einstein ac8838f
refactor: Observation tweaks (#17)
zombie-einstein 05eeedf
refactor: address pr comments (#18)
zombie-einstein 5353ef7
chore: revert to using set colours
sash-a bc9e252
fix: minor training bug due to refactor
sash-a eac2f1f
chore: update default parameters to ones tested in mava
sash-a 79a7aa8
chore: add search and rescue to the readme
sash-a de8e869
Update graph-coloring viewer
zombie-einstein 243ddde
Update mmst
zombie-einstein 211d578
Merge branch 'main' of github.com:zombie-einstein/jumanji into matplo…
zombie-einstein cbb159d
Merge branch 'instadeepai:main' into main
zombie-einstein 385935d
Merge branch 'main' of github.com:zombie-einstein/jumanji into matplo…
zombie-einstein ffb9c6f
Update suduko viewer
zombie-einstein cadb337
Fix tsp viewer
zombie-einstein 3bcb2a5
Refactor graph coloring for multiple episodes
zombie-einstein 5b6ada7
Fix coloring and mmst for multiple episodes
zombie-einstein 759cc45
Fix knapsack image loading
zombie-einstein 5361a75
Replace pkg_resources usage and add animation test script
zombie-einstein b5613a1
Fix refresh of cvrp on new episode
zombie-einstein dbc52bb
Refresh tsp at new episode
zombie-einstein 4c3a139
Refresh multi-crvp at new episode
zombie-einstein e15e70e
Cleanup and refactor graph layout functionality
zombie-einstein 71ca43a
Address PR comments
zombie-einstein 43bbb77
Consolidate viewer functionality
zombie-einstein 7bfd2c8
Consolidate animation creation
zombie-einstein bec0ac5
Viewer tweaks
zombie-einstein 7feba67
Fix cmap warnings
zombie-einstein 489ff67
Update jupyter backend and animation rendering
zombie-einstein 16d58b7
Revert sudoku animation change
zombie-einstein File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Copyright 2022 InstaDeep Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
|
||
import jax | ||
import requests | ||
from hydra import compose, initialize | ||
|
||
from jumanji.training.setup_train import setup_agent, setup_env | ||
|
||
envs = [ | ||
"bin_pack", | ||
"cleaner", | ||
"connector", | ||
"cvrp", | ||
"flat_pack", | ||
"game_2048", | ||
"graph_coloring", | ||
"job_shop", | ||
"knapsack", | ||
"lbf", | ||
"maze", | ||
"minesweeper", | ||
"mmst", | ||
"multi_cvrp", | ||
"pac_man", | ||
"robot_warehouse", | ||
"rubiks_cube", | ||
"search_and_rescue", | ||
"sliding_tile_puzzle", | ||
"snake", | ||
"sokoban", | ||
"sudoku", | ||
"tetris", | ||
"tsp", | ||
] | ||
|
||
|
||
def download_file(url: str, file_path: str) -> None: | ||
# Send an HTTP GET request to the URL | ||
response = requests.get(url) | ||
# Check if the request was successful (status code 200) | ||
if response.status_code == 200: | ||
with open(file_path, "wb") as f: | ||
f.write(response.content) | ||
else: | ||
print("Failed to download the file.") | ||
|
||
|
||
def create_animation(env_name: str, agent: str = "random", num_episodes: int = 2) -> None: | ||
print(f"Animating {env_name}") | ||
|
||
os.makedirs("configs", exist_ok=True) | ||
config_url = "https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml" | ||
download_file(config_url, "configs/config.yaml") | ||
env_url = f"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env_name}.yaml" | ||
os.makedirs("configs/env", exist_ok=True) | ||
download_file(env_url, f"configs/env/{env_name}.yaml") | ||
os.makedirs("animations", exist_ok=True) | ||
|
||
with initialize(version_base=None, config_path="configs"): | ||
cfg = compose(config_name="config.yaml", overrides=[f"env={env_name}", f"agent={agent}"]) | ||
|
||
env = setup_env(cfg).unwrapped | ||
agent = setup_agent(cfg, env) | ||
policy = jax.jit(agent.make_policy(stochastic=False)) | ||
|
||
reset_fn = jax.jit(env.reset) | ||
step_fn = jax.jit(env.step) | ||
states = [] | ||
key = jax.random.PRNGKey(cfg.seed) | ||
|
||
for _ in range(num_episodes): | ||
key, reset_key = jax.random.split(key) | ||
state, timestep = reset_fn(reset_key) | ||
states.append(state) | ||
|
||
while not timestep.last(): | ||
key, action_key = jax.random.split(key) | ||
observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation) | ||
action = policy(observation, action_key) | ||
state, timestep = step_fn(state, action.squeeze(axis=0)) | ||
states.append(state) | ||
|
||
env.animate(states, 100, f"animations/{env_name}_animation.gif") | ||
|
||
|
||
if __name__ == "__main__": | ||
cli = argparse.ArgumentParser() | ||
cli.add_argument( | ||
"envs", | ||
nargs="*", | ||
type=str, | ||
default=None, | ||
) | ||
|
||
args = cli.parse_args() | ||
arg_envs = args.envs | ||
env_list = envs if len(arg_envs) == 0 else args.envs | ||
|
||
for env in env_list: | ||
try: | ||
create_animation(env) | ||
except Exception as e: | ||
print(f"{env} failed", e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2022 InstaDeep Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List, Tuple | ||
|
||
import chex | ||
import numpy as np | ||
|
||
|
||
def _compute_repulsive_forces( | ||
repulsive_forces: np.ndarray, pos: np.ndarray, k: float, num_nodes: int | ||
) -> np.ndarray: | ||
for i in range(num_nodes): | ||
for j in range(i + 1, num_nodes): | ||
delta = pos[i] - pos[j] | ||
distance = np.linalg.norm(delta) | ||
direction = delta / (distance + 1e-6) | ||
force = k * k / (distance + 1e-6) | ||
repulsive_forces[i] += direction * force | ||
repulsive_forces[j] -= direction * force | ||
|
||
return repulsive_forces | ||
|
||
|
||
def _compute_attractive_forces( | ||
graph: chex.Array, | ||
attractive_forces: np.ndarray, | ||
pos: np.ndarray, | ||
k: float, | ||
num_nodes: int, | ||
) -> np.ndarray: | ||
for i in range(num_nodes): | ||
for j in range(num_nodes): | ||
if graph[i, j]: | ||
delta = pos[i] - pos[j] | ||
distance = np.linalg.norm(delta) | ||
direction = delta / (distance + 1e-6) | ||
force = distance * distance / k | ||
attractive_forces[i] -= direction * force | ||
attractive_forces[j] += direction * force | ||
|
||
return attractive_forces | ||
|
||
|
||
def spring_layout( | ||
graph: chex.Array, num_nodes: int, seed: int = 42, iterations: int = 100 | ||
) -> List[Tuple[float, float]]: | ||
""" | ||
Compute a 2D spring layout for the given graph using | ||
the Fruchterman-Reingold force-directed algorithm. | ||
|
||
The algorithm computes a layout by simulating the graph as a physical system, | ||
where nodes are repelling each other and edges are attracting connected nodes. | ||
The method minimizes the energy of the system over several iterations. | ||
|
||
Args: | ||
graph: A Graph object representing the adjacency matrix of the graph. | ||
num_nodes: Number of graph nodes. | ||
seed: An integer used to seed the random number generator for reproducibility. | ||
iterations: Number of layout refining iterations. | ||
|
||
Returns: | ||
A list of tuples representing the 2D positions of nodes in the graph. | ||
""" | ||
rng = np.random.default_rng(seed) | ||
pos = rng.random((num_nodes, 2)) * 2 - 1 | ||
|
||
k = np.sqrt(5 / num_nodes) | ||
temperature = 2.0 # Added a temperature variable | ||
|
||
for _ in range(iterations): | ||
repulsive_forces = _compute_repulsive_forces(np.zeros((num_nodes, 2)), pos, k, num_nodes) | ||
attractive_forces = _compute_attractive_forces( | ||
graph, np.zeros((num_nodes, 2)), pos, k, num_nodes | ||
) | ||
|
||
pos += (repulsive_forces + attractive_forces) * temperature | ||
# Reduce the temperature (cooling factor) to refine the layout. | ||
temperature *= 0.9 | ||
|
||
pos = np.clip(pos, -1, 1) # Keep positions within the [-1, 1] range | ||
|
||
# Scale positions to fill figure | ||
pos_max = np.max(pos, axis=0) | ||
pos_min = np.min(pos, axis=0) | ||
pos = 0.05 + (pos - pos_min) / (1.1 * (pos_max - pos_min)) | ||
|
||
return [(float(p[0]), float(p[1])) for p in pos] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is now the correct way to update the render in a jupyter notebook (as mentioned in #272). I've tested this locally with what I guess is fairly recent version of jupyter. Only potential issue is if this breaks behaviour of older versions of notebooks?
Without this change I just get the message
Javascript Error: IPython is not defined
when trying to call the render method in a notebook.