Skip to content

Commit

Permalink
feat(trainer): saving game sample as a gif
Browse files Browse the repository at this point in the history
  • Loading branch information
Pierre Pereira committed Sep 11, 2023
1 parent 6a15735 commit b282514
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 17 deletions.
13 changes: 7 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ authors = [
{name = "Pierrot LC", email = "[email protected]"},
]
dependencies = [
"torch>=2.0.1",
"einops>=0.6.1",
"gymnasium>=0.28.1",
"wandb>=0.15.3",
"matplotlib>=3.7.1",
"hydra-core>=1.3.2",
"torchinfo>=1.8.0",
"tqdm>=4.65.0",
"imageio>=2.31.3",
"matplotlib>=3.7.1",
"positional-encodings[pytorch]>=6.0.1",
"torchrl>=0.1.1",
"pytorch-optimizer>=2.11.1",
"torch>=2.0.1",
"torchinfo>=1.8.0",
"torchrl>=0.1.1",
"tqdm>=4.65.0",
"wandb>=0.15.3",
]
requires-python = ">=3.10"
readme = "README.md"
Expand Down
18 changes: 16 additions & 2 deletions src/environment/draw.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from itertools import product
from pathlib import Path
from typing import Optional, Union
from typing import Optional

import imageio
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches, path
Expand Down Expand Up @@ -103,7 +104,7 @@ def draw_triangles(ax: plt.Axes, x: int, y: int, tile: np.ndarray):
def draw_instance(
instance: np.ndarray,
score: float,
filename: Optional[Union[Path, str]] = None,
filename: Optional[Path | str] = None,
) -> np.ndarray:
_, height, width = instance.shape

Expand Down Expand Up @@ -137,3 +138,16 @@ def draw_instance(
plt.close(fig)

return image


def draw_gif(
instances: np.ndarray,
scores: np.ndarray,
filename: Path | str,
):
images = [
draw_instance(instance, score, filename=None)
for instance, score in zip(instances, scores)
]

imageio.mimsave(filename, images, duration=500)
49 changes: 46 additions & 3 deletions src/environment/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
SOUTH,
WEST,
)
from .draw import draw_instance
from .draw import draw_gif, draw_instance
from .generate import random_perfect_instances

# Defines convs that will compute vertical and horizontal matches.
Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(
instances: torch.Tensor,
device: str,
seed: int = 0,
sample_size: int = 40,
):
"""Initialize the environment.
Expand Down Expand Up @@ -99,6 +100,13 @@ def __init__(
)
self.best_matches_found = 0
self.total_won = 0
self.current_sample_step = 0
self.sample_size = sample_size
self.game_sample = torch.zeros(
(sample_size, N_SIDES, self.board_size, self.board_size),
dtype=torch.long,
device="cpu",
)

# Spaces.
# Those spaces do not take into account that
Expand Down Expand Up @@ -189,15 +197,17 @@ def step(

matches = self.matches

# Update envs infos.
self.update_best_env()
# rewards = matches / self.best_matches_possible - 1
self.game_sample[self.current_sample_step] = self.instances[0].cpu()
self.current_sample_step = (self.current_sample_step + 1) % self.sample_size

rewards = (matches - previous_matches) / self.best_matches_possible
max_matches = torch.stack((self.max_matches, matches), dim=1)
self.max_matches = torch.max(max_matches, dim=1)[0]
self.terminated |= matches == self.best_matches_possible
infos["just-won"] = self.terminated & ~previously_terminated
self.total_won += infos["just-won"].sum().cpu().item()

return self.render(), rewards, self.terminated, False, infos

def roll_tiles(self, tile_ids: torch.Tensor, shifts: torch.Tensor):
Expand Down Expand Up @@ -345,6 +355,7 @@ def render(self, mode: str = "computer") -> torch.Tensor:
---
Args:
mode: The rendering type.
Only "computer" is accepted.
---
Returns:
Expand All @@ -371,6 +382,38 @@ def save_best_env(self, filepath: Path | str):
"""Render the best environment and save it on disk."""
draw_instance(self.best_board.numpy(), self.best_matches_found, filepath)

def save_sample(self, filepath: Path | str):
"""Render the current game sample and save it as a GIF on disk."""
scores = EternityEnv.count_matches(self.game_sample)
draw_gif(self.game_sample.numpy(), scores.numpy(), filepath)

@staticmethod
def count_matches(instances: torch.Tensor) -> torch.Tensor:
"""Return the number of matches of the given instances.
---
Args:
instances: The instances of this environment.
Long tensor of shape of [batch_size, N_SIDES, size, size].
---
Returns:
The matches.
Long tensor of shape [batch_size,].
"""
n_matches = torch.zeros(instances.shape[0], device=instances.device)

for conv in [HORIZONTAL_CONV, VERTICAL_CONV]:
res = torch.conv2d(instances.float(), conv.to(instances.device))
n_matches += (res == 0).float().flatten(start_dim=1).sum(dim=1)

# Remove the 0-0 matches from the count.
for conv in [HORIZONTAL_ZERO_CONV, VERTICAL_ZERO_CONV]:
res = torch.conv2d(instances.float(), conv.to(instances.device))
n_matches -= (res == 0).float().flatten(start_dim=1).sum(dim=1)

return n_matches.long()

@staticmethod
def batched_roll(input_tensor: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
"""Batched version of `torch.roll`.
Expand Down
15 changes: 9 additions & 6 deletions src/policy_gradient/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def launch_training(
group: str,
config: dict[str, Any],
mode: str = "online",
eval_every: int = 1,
save_every: int = 100,
):
"""Launches the training loop.
Expand Down Expand Up @@ -155,12 +155,12 @@ def launch_training(

self.scheduler.step()

if i % eval_every == 0 and not disable_logs:
metrics = self.evaluate()
metrics = self.evaluate()
run.log(metrics)

if i % save_every == 0 and not disable_logs:
self.save_model("model.pt")
self.env.save_best_env("board.png")
metrics["best-board"] = wandb.Image("board.png")
run.log(metrics)
self.env.save_sample("sample.gif")

def evaluate(self) -> dict[str, Any]:
"""Evaluates the model and returns some computed metrics."""
Expand Down Expand Up @@ -202,6 +202,9 @@ def evaluate(self) -> dict[str, Any]:
if isinstance(value, torch.Tensor):
metrics[name] = value.cpu().item()

self.env.save_best_env("board.png")
metrics["best-board"] = wandb.Image("board.png")

return metrics

def save_model(self, filepath: Path | str):
Expand Down

0 comments on commit b282514

Please sign in to comment.