feat: soft implementation with tests!
Pierrot LeCon committed Jul 14, 2023
1 parent f7f7ccb commit 06f5aaf
Showing 6 changed files with 251 additions and 12 deletions.
10 changes: 5 additions & 5 deletions eternity-doc/todo.norg
@@ -3,19 +3,19 @@ title: todo
 authors: pierrotlc
 categories: notes
 created: 2023-03-08
-updated: 2023-07-13
+updated: 2023-07-14
 @end
 
 * TODO
 - ( ) Add a rotation invariant regularizer loss (if not using a GNN or G-CNN).
 - ( ) Cut the rollout based on the last minimum score (remove the first useless moves).
 - (x) Start with a CNN encoder for local embeddings and finish by a transformer encoder
   for global embeddings.
-- ( ) Go back to a rollout buffer, remove the recurrent model and add the optional
+- (x) Go back to a rollout buffer, remove the recurrent model and add the optional
   timestep encoding.
-- ( ) Do a special rollout by using a batched soft MCTS (see {*Ideas}) to collect better actions.
--- It will be much longer to collect rollouts, hence the use of a rollout buffer.
--- Possibility to add exploration by using exploration sampling methods.
+- (-) Do a special rollout by using a batched soft MCTS (see {*Ideas}) to collect better actions.
+-- (-) It will be much longer to collect rollouts, hence the use of a rollout buffer.
+-- (-) Possibility to add exploration by using exploration sampling methods.
 - ( ) Start some rollouts from previous best boards.
 
 * Ideas
6 changes: 3 additions & 3 deletions justfile
@@ -2,10 +2,10 @@ tests:
     python3 -m pytest --import-mode importlib .
 
 trivial:
-    python3 main.py exp=trivial
+    python3 main.py exp=trivial mode=offline
 
 trivial_B:
-    python3 main.py exp=trivial_B
+    python3 main.py exp=trivial_B mode=offline
 
 normal:
-    python3 main.py exp=normal
+    python3 main.py exp=normal mode=offline
12 changes: 9 additions & 3 deletions src/environment/gym.py
@@ -2,7 +2,7 @@
 All actions are made on the batch.
 """
 from pathlib import Path
-from typing import Any
+from typing import Any, Optional
 
 import gymnasium as gym
 import gymnasium.spaces as spaces
@@ -129,12 +129,18 @@ def __init__(
             low=0, high=1, shape=self.instances.shape[1:], dtype=np.uint8
         )
 
-    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
+    def reset(
+        self, instances: Optional[torch.Tensor] = None
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
         """Reset the environment.
         Scrambles the instances and resets their infos.
         """
-        self.scramble_instances()
+        if instances is not None:
+            self.instances = instances.detach()
+        else:
+            self.scramble_instances()
 
         self.step_id = 0
         self.truncated = False
         self.terminated = torch.zeros(
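The new optional instances argument lets a caller restart the environment from a saved board state instead of a fresh scramble, which is what the soft MCTS below relies on. A minimal sketch of the intended call pattern, assuming an already-constructed EternityEnv (the snapshot via render mirrors what SoftMCTS.__init__ does):

    # Hypothetical usage sketch, not part of this commit.
    saved_boards = env.render().detach()  # snapshot the current batched boards
    states, infos = env.reset(instances=saved_boards)  # restart from the snapshot
    states, infos = env.reset()  # or scramble fresh instances, as before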
146 changes: 146 additions & 0 deletions src/mcts/soft.py
@@ -0,0 +1,146 @@
"""Soft Monte Carlo Tree Search implementation.
Differences with pure MCTS:
- Only simple rollouts to get a crude estimate of the next best actions.
- No exploration/exploitation intermediate states.
- No backpropagation, pure simulation.
This simplications are there to enable better exploitation of the batch parallelism.
"""
from math import prod

import torch
from einops import rearrange

from ..environment import EternityEnv
from ..model import Policy, N_SIDES


class SoftMCTS:
    def __init__(self, env: EternityEnv, model: Policy, device: str | torch.device):
        self.env = env
        self.model = model
        self.device = device
        self.n_simulations = 10
        self.sampling_mode = "sample"

        self.root_instances = env.render().detach()
        self.root_step_id = env.step_id

        # Running sums of episode returns and visit counts for each
        # (instance, root action) pair, averaged at the end of the search.
        self.action_returns = torch.zeros(
            (env.batch_size, env.board_size, env.board_size, N_SIDES, N_SIDES),
            dtype=torch.float32,
            device=device,
        )
        self.action_visits = torch.zeros(
            (env.batch_size, env.board_size, env.board_size, N_SIDES, N_SIDES),
            dtype=torch.long,
            device=device,
        )

    @torch.inference_mode()
    def simulation(self):
        """Do a Monte Carlo simulation from the root instances until
        the episodes are finished.
        """
        states, _ = self.env.reset(self.root_instances)
        root_actions, _, _ = self.model(states, self.sampling_mode)
        states, *_ = self.env.step(root_actions)

        while not self.env.truncated and not torch.all(self.env.terminated):
            actions, _, _ = self.model(states, self.sampling_mode)
            states, *_ = self.env.step(actions)

        # Credit the final episode score to the root actions only.
        returns = self.env.max_matches / self.env.best_matches
        self.action_returns = SoftMCTS.batched_add(
            self.action_returns, root_actions, returns
        )
        self.action_visits = SoftMCTS.batched_add(self.action_visits, root_actions, 1)

    @torch.inference_mode()
    def run(self) -> torch.Tensor:
        """Do the simulations and return the best action found for each instance."""
        for _ in range(self.n_simulations):
            self.simulation()

        # Clamp avoids 0/0 NaNs for actions that were never sampled at the root.
        scores = self.action_returns / self.action_visits.clamp(min=1)
        return SoftMCTS.best_actions(scores)

    @staticmethod
    def batched_add(
        input_tensor: torch.Tensor, actions: torch.Tensor, to_add: torch.Tensor | float
    ) -> torch.Tensor:
        """Add to the input tensor the elements at the given actions indices.
        ---
        Args:
            input_tensor: Tensor to add the elements to.
                Shape of [batch_size, n_actions_1, n_actions_2, n_actions_3, n_actions_4].
            actions: Indices to which we add the elements in the input.
                Shape of [batch_size, 4].
            to_add: Elements to add to the input.
                Shape of [batch_size,].
        ---
        Returns:
            The input tensor with the elements added.
                Shape of [batch_size, n_actions_1, n_actions_2, n_actions_3, n_actions_4].
        """
        (
            batch_size,
            n_actions_1,
            n_actions_2,
            n_actions_3,
            n_actions_4,
        ) = input_tensor.shape
        input_tensor = input_tensor.flatten()
        # Flat index of each action inside its own sample, using the
        # row-major strides of the four trailing dimensions.
        indices = (
            actions[:, 0] * (n_actions_2 * n_actions_3 * n_actions_4)
            + actions[:, 1] * (n_actions_3 * n_actions_4)
            + actions[:, 2] * n_actions_4
            + actions[:, 3]
        )
        # Offset each sample to the start of its block in the flattened tensor.
        offsets = torch.arange(
            start=0,
            end=input_tensor.shape[0],
            step=input_tensor.shape[0] // batch_size,
            device=input_tensor.device,
        )
        indices = indices + offsets

        input_tensor[indices] += to_add
        input_tensor = rearrange(
            input_tensor,
            "(b a1 a2 a3 a4) -> b a1 a2 a3 a4",
            a1=n_actions_1,
            a2=n_actions_2,
            a3=n_actions_3,
            a4=n_actions_4,
        )
        return input_tensor

    @staticmethod
    def best_actions(scores: torch.Tensor) -> torch.Tensor:
        """Return the coordinates maximizing the score for each instance.
        ---
        Args:
            scores: The scores of each pair of (instance, actions).
                Tensor of shape [batch_size, n_actions_1, n_actions_2, n_actions_3, n_actions_4].
        ---
        Returns:
            The coordinates of the best actions for each instance.
                Tensor of shape [batch_size, 4].
        """
        actions_shape = scores.shape[1:]
        n_elements = prod(actions_shape)
        scores = scores.flatten(start_dim=1)
        flat_indices = scores.argmax(dim=1)

        # Decode each flat index back into its four action coordinates,
        # peeling off one dimension (mixed-radix digit) at a time.
        best_actions = []
        for n_actions in actions_shape:
            n_elements //= n_actions
            coord, flat_indices = flat_indices // n_elements, flat_indices % n_elements
            best_actions.append(coord)

        return torch.stack(best_actions, dim=1)
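Taken together, a search step might look like the following sketch. This is a hedged usage example, not code from this commit: the construction of the env and model objects is assumed from the rest of the codebase.

    # Hypothetical glue code showing how SoftMCTS is meant to be driven.
    mcts = SoftMCTS(env, model, device="cuda")
    best_actions = mcts.run()  # LongTensor of shape [batch_size, 4]
    states, rewards, _, _, infos = env.step(best_actions)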
87 changes: 87 additions & 0 deletions src/mcts/test_mcts.py
@@ -0,0 +1,87 @@
from itertools import product
import pytest
import torch

from .soft import SoftMCTS


@pytest.mark.parametrize(
    "input_tensor, actions, to_add",
    [
        (
            torch.randn(4, 4, 4, 2, 2),
            torch.LongTensor(
                [
                    [3, 2, 0, 1],
                    [3, 3, 0, 1],
                    [3, 3, 0, 1],
                    [3, 3, 0, 1],
                ]
            ),
            torch.randn(4),
        ),
        (
            torch.randn(4, 4, 4, 2, 2),
            torch.LongTensor(
                [
                    [3, 2, 0, 1],
                    [3, 3, 0, 1],
                    [3, 3, 0, 1],
                    [3, 3, 0, 1],
                ]
            ),
            1,
        ),
        (
            torch.randn(4, 8, 2, 5, 8),
            torch.LongTensor(
                [
                    [6, 0, 1, 7],
                    [3, 1, 2, 1],
                    [4, 1, 4, 0],
                    [2, 0, 0, 1],
                ]
            ),
            torch.randn(4),
        ),
    ],
)
def test_batched_add(
    input_tensor: torch.Tensor, actions: torch.Tensor, to_add: torch.Tensor | float
):
    # clone() gives an independent copy; detach() would share storage with
    # input_tensor and make the assertion vacuous.
    true_output = input_tensor.clone()
    for sample_id in range(input_tensor.shape[0]):
        action = actions[sample_id]
        el = to_add[sample_id] if isinstance(to_add, torch.Tensor) else to_add
        true_output[sample_id, action[0], action[1], action[2], action[3]] += el

    output = SoftMCTS.batched_add(input_tensor, actions, to_add)
    assert torch.all(output == true_output)


@pytest.mark.parametrize(
    "scores",
    [
        torch.randn(12, 4, 4, 2, 2),
        torch.randn(12, 8, 2, 5, 8),
        torch.randn(2, 4, 4, 2, 2),
    ],
)
def test_best_actions(scores: torch.Tensor):
    output = SoftMCTS.best_actions(scores)
    true_output = []
    for batch_id in range(scores.shape[0]):
        best_actions = None
        best_score = float("-inf")
        # Brute-force search over every action combination as the reference.
        for action_1, action_2, action_3, action_4 in product(
            *[range(scores.shape[i]) for i in range(1, 5)]
        ):
            score = scores[batch_id, action_1, action_2, action_3, action_4]
            if score > best_score:
                best_score = score
                best_actions = (action_1, action_2, action_3, action_4)

        true_output.append(torch.LongTensor(best_actions))

    true_output = torch.stack(true_output, dim=0)
    assert torch.all(true_output == output)
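These tests exercise the two static helpers directly; per the justfile above, the suite runs with python3 -m pytest --import-mode importlib . (the just tests recipe). For intuition, here is a hypothetical smoke check of the batched indexing with hand-picked values, not part of this commit:

    # Hypothetical smoke check: each action lands exactly one increment.
    t = torch.zeros(2, 3, 3, 2, 2)
    acts = torch.LongTensor([[1, 2, 0, 1], [0, 0, 1, 0]])
    out = SoftMCTS.batched_add(t, acts, 1.0)
    assert out[0, 1, 2, 0, 1] == 1.0 and out[1, 0, 0, 1, 0] == 1.0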
2 changes: 1 addition & 1 deletion src/reinforce/reinforce.py
@@ -59,7 +59,7 @@ def do_rollouts(self, sampling_mode: str):
         states, _ = self.env.reset()
 
         while not self.env.truncated and not torch.all(self.env.terminated):
-            actions, logprobs, entropies = self.model(states, sampling_mode)
+            actions, _, _ = self.model(states, sampling_mode)
             new_states, rewards, _, _, infos = self.env.step(actions)
             self.rollout_buffer.store(
                 states,
