Merge pull request #6 from pierrot-lc/transformer-only
Pierrot authored Sep 11, 2023
2 parents 33c86b8 + c4d2d77 commit 6a15735
Showing 9 changed files with 53 additions and 93 deletions.
15 changes: 7 additions & 8 deletions configs/exp/hard.yaml
@@ -6,12 +6,10 @@ env:
   batch_size: 10000
 
 model:
-  n_channels: 16
-  embedding_dim: 12
-  n_heads: 1
-  backbone_cnn_layers: 6
-  backbone_transformer_layers: 3
-  decoder_layers: 3
+  embedding_dim: 64
+  n_heads: 2
+  backbone_transformer_layers: 6
+  decoder_layers: 2
   dropout: 0.05
 
 optimizer:
@@ -21,7 +19,8 @@ optimizer:
 
 scheduler:
   warmup_steps: 0
-  cosine_tmax: 0
+  cosine_t0: 0
+  cosine_tmult: 0
 
 loss:
   gamma: 0.99
@@ -33,7 +32,7 @@ loss:
 
 trainer:
   clip_value: 1.0
-  scramble_size: 0.20
+  scramble_size: 0.50
 
 iterations:
   rollouts: 100
21 changes: 10 additions & 11 deletions configs/exp/normal.yaml
@@ -3,42 +3,41 @@ group: Normal
 
 env:
   path: ./instances/eternity_A.txt
-  batch_size: 5000
+  batch_size: 10000
 
 model:
-  n_channels: 16
-  embedding_dim: 12
-  n_heads: 1
-  backbone_cnn_layers: 5
-  backbone_transformer_layers: 1
+  embedding_dim: 64
+  n_heads: 4
+  backbone_transformer_layers: 6
   decoder_layers: 2
   dropout: 0.05
 
 optimizer:
   optimizer: adamw
-  learning_rate: 1.0e-3
+  learning_rate: 1.0e-4
   weight_decay: 1.0e-3
 
 scheduler:
   warmup_steps: 0
-  cosine_tmax: 0
+  cosine_t0: 0
+  cosine_tmult: 0
 
 loss:
   gamma: 0.99
   gae_lambda: 0.95
   ppo_clip_ac: 0.30
   ppo_clip_vf: 0.30
   value_weight: 1.0e-3
-  entropy_weight: 5.0e-4
+  entropy_weight: 1.0e-3
 
 trainer:
   clip_value: 1.0
-  scramble_size: 0.20
+  scramble_size: 0.50
 
 iterations:
   rollouts: 60
   epochs: -1
   batches: 60
-  batch_size: 2500
+  batch_size: 5000
 
 checkpoint:
15 changes: 7 additions & 8 deletions configs/exp/trivial.yaml
@@ -6,22 +6,21 @@ env:
   batch_size: 5000
 
 model:
-  n_channels: 6
-  embedding_dim: 12
-  n_heads: 1
-  backbone_cnn_layers: 4
-  backbone_transformer_layers: 1
+  embedding_dim: 24
+  n_heads: 2
+  backbone_transformer_layers: 4
   decoder_layers: 1
   dropout: 0.05
 
 optimizer:
-  optimizer: lion
+  optimizer: adamw
   learning_rate: 1.0e-3
   weight_decay: 1.0e-4
 
 scheduler:
   warmup_steps: 0
-  cosine_tmax: 0
+  cosine_t0: 0
+  cosine_tmult: 0
 
 loss:
   gamma: 0.99
@@ -33,7 +32,7 @@ loss:
 
 trainer:
   clip_value: 1.0
-  scramble_size: 0.25
+  scramble_size: 0.50
 
 iterations:
   rollouts: 10
17 changes: 8 additions & 9 deletions configs/exp/trivial_B.yaml
@@ -3,25 +3,24 @@ group: Trivial B
 
 env:
   path: ./instances/eternity_trivial_B.txt
-  batch_size: 5000
+  batch_size: 10000
 
 model:
-  n_channels: 9
-  embedding_dim: 16
+  embedding_dim: 24
   n_heads: 2
-  backbone_cnn_layers: 4
-  backbone_transformer_layers: 1
+  backbone_transformer_layers: 4
   decoder_layers: 1
   dropout: 0.05
 
 optimizer:
   learning_rate: 1.0e-3
-  optimizer: lion
+  optimizer: adamw
   weight_decay: 1.0e-4
 
 scheduler:
   warmup_steps: 0
-  cosine_tmax: 0
+  cosine_t0: 0
+  cosine_tmult: 0
 
 loss:
   gamma: 0.99
@@ -33,12 +32,12 @@ loss:
 
 trainer:
   clip_value: 1.0
-  scramble_size: 0.10
+  scramble_size: 0.50
 
 iterations:
   rollouts: 40
   epochs: 100
   batches: 40
-  batch_size: 2500
+  batch_size: 5000
 
 checkpoint:
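
These four configs are plain YAML read into the DictConfig that main.py (below) consumes. A minimal sketch of loading one of them and reading the renamed scheduler fields, assuming OmegaConf is what builds that DictConfig (the printed values follow the normal.yaml diff above):

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/exp/normal.yaml")

print(cfg.model.embedding_dim)     # 64
print(cfg.scheduler.cosine_t0)     # 0 -> the warm-restart scheduler stays disabled
print(cfg.scheduler.cosine_tmult)  # 0

Both cosine_t0 and cosine_tmult stay at 0 in every config of this commit, so the warm-restart branch added to init_scheduler below remains disabled by default.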
11 changes: 5 additions & 6 deletions main.py
@@ -47,13 +47,10 @@ def init_model(config: DictConfig, env: EternityEnv) -> Policy:
     """Initialize the model."""
     model = config.exp.model
     return Policy(
-        n_classes=env.n_classes,
         board_width=env.board_size,
         board_height=env.board_size,
-        n_channels=model.n_channels,
         embedding_dim=model.embedding_dim,
         n_heads=model.n_heads,
-        backbone_cnn_layers=model.backbone_cnn_layers,
         backbone_transformer_layers=model.backbone_transformer_layers,
         decoder_layers=model.decoder_layers,
         dropout=model.dropout,
@@ -107,10 +104,11 @@ def init_scheduler(
         )
         schedulers.append(warmup_scheduler)
 
-    if scheduler.cosine_tmax > 0:
-        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
+    if scheduler.cosine_t0 > 0:
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
             optimizer=optimizer,
-            T_max=scheduler.cosine_tmax,
+            T_0=scheduler.cosine_t0,
+            T_mult=scheduler.cosine_tmult,
         )
         schedulers.append(lr_scheduler)
 
@@ -173,6 +171,7 @@ def reload_checkpoint(config: DictConfig, trainer: Trainer):
     trainer.model.load_state_dict(state_dict["model"])
     # HACK: The training seems to not be stable when loading the optimizer state.
     # trainer.optimizer.load_state_dict(state_dict["optimizer"])
+    trainer.scheduler.load_state_dict(state_dict["scheduler"])
     print(f"Checkpoint from {checkpoint_path} loaded.")
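
init_scheduler now builds CosineAnnealingWarmRestarts instead of CosineAnnealingLR: rather than one cosine decay over cosine_tmax steps, the learning rate follows a cosine cycle of T_0 steps that restarts and grows by a factor of T_mult after each restart. A standalone sketch, not part of the commit, with a dummy model and illustrative cycle lengths:

import torch.nn as nn
import torch.optim as optim

# Dummy model and optimizer, just to attach a scheduler to.
model = nn.Linear(4, 4)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=100,   # length of the first cosine cycle (cosine_t0 in the configs)
    T_mult=2,  # each restart makes the next cycle twice as long (cosine_tmult)
)

for _ in range(300):
    optimizer.step()
    scheduler.step()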
53 changes: 11 additions & 42 deletions src/model/backbone.py
@@ -2,9 +2,10 @@
 import torch.nn as nn
 from einops import rearrange
 from einops.layers.torch import Rearrange
-from positional_encodings.torch_encodings import PositionalEncoding2D
+from positional_encodings.torch_encodings import PositionalEncoding2D, Summer
 
 from ..environment import N_SIDES
+from .class_encoding import ClassEncoding
 
 
 class Backbone(nn.Module):
@@ -24,47 +25,25 @@ class Backbone(nn.Module):
 
     def __init__(
         self,
-        n_classes: int,
-        n_channels: int,
         embedding_dim: int,
         n_heads: int,
-        cnn_layers: int,
         transformer_layers: int,
         dropout: float,
     ):
         super().__init__()
 
         self.embed_board = nn.Sequential(
-            # Embed the classes of each size of the tiles.
-            Rearrange("b t h w -> b h w t"),
-            nn.Embedding(n_classes, n_channels),
+            # Encode the classes.
+            ClassEncoding(embedding_dim),
             # Merge the classes of each tile into a single embedding.
-            Rearrange("b h w t e -> b (t e) h w"),
-            nn.Conv2d(N_SIDES * n_channels, n_channels, kernel_size=1, padding="same"),
-            nn.GELU(),
-            nn.GroupNorm(1, n_channels),
+            Rearrange("b t h w e -> b h w (t e)"),
+            nn.Linear(N_SIDES * embedding_dim, embedding_dim),
+            # Add the 2D positional encodings.
+            Summer(PositionalEncoding2D(embedding_dim)),
+            # To transformer layout.
+            Rearrange("b h w e -> (h w) b e"),
         )
-        self.linear = nn.Sequential(
-            nn.Linear(n_channels, embedding_dim),
-            nn.LayerNorm(embedding_dim),
-        )
-        self.positional_enc = PositionalEncoding2D(embedding_dim)
 
-        self.cnn_layers = nn.ModuleList(
-            [
-                nn.Sequential(
-                    nn.Conv2d(
-                        n_channels,
-                        n_channels,
-                        kernel_size=3,
-                        padding="same",
-                    ),
-                    nn.GELU(),
-                    nn.GroupNorm(1, n_channels),
-                )
-                for _ in range(cnn_layers)
-            ]
-        )
         self.transformer_layers = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(
                 d_model=embedding_dim,
@@ -92,16 +71,6 @@ def forward(
             The embedded game state.
             Shape of [board_height x board_width, batch_size, embedding_dim].
         """
-        batch_size, _, board_height, board_width = tiles.shape
-
-        tiles = self.embed_board(tiles)
-        for layer in self.cnn_layers:
-            tiles = layer(tiles) + tiles
-
-        tokens = rearrange(tiles, "b e h w -> b h w e")
-        tokens = self.linear(tokens)
-        tokens = self.positional_enc(tokens) + tokens
-        tokens = rearrange(tokens, "b h w e -> (h w) b e")
+        tokens = self.embed_board(tiles)
         tokens = self.transformer_layers(tokens)
-
         return tokens
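
The backbone is now transformer-only: the CNN stack, the separate linear projection, and the manual positional-encoding addition are folded into the single embed_board pipeline above. A rough standalone sketch of that pipeline, not part of the commit; N_SIDES = 4 and the one-hot stand-in for ClassEncoding are assumptions, since ClassEncoding's implementation is not shown in this diff:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from positional_encodings.torch_encodings import PositionalEncoding2D, Summer

N_SIDES, embedding_dim = 4, 64


class OneHotClassEncoding(nn.Module):
    """Stand-in for ClassEncoding: maps class ids to orthogonal vectors."""

    def forward(self, board: torch.Tensor) -> torch.Tensor:
        return F.one_hot(board.long(), num_classes=embedding_dim).float()


embed_board = nn.Sequential(
    OneHotClassEncoding(),                              # [b, t, h, w] -> [b, t, h, w, e]
    Rearrange("b t h w e -> b h w (t e)"),              # concatenate the side encodings
    nn.Linear(N_SIDES * embedding_dim, embedding_dim),  # merge into one tile embedding
    Summer(PositionalEncoding2D(embedding_dim)),        # add 2D positional encodings
    Rearrange("b h w e -> (h w) b e"),                  # sequence layout for the transformer
)

tiles = torch.randint(0, 12, (2, N_SIDES, 4, 4))        # [batch, sides, height, width]
tokens = embed_board(tiles)
print(tokens.shape)                                     # torch.Size([16, 2, 64])

The resulting layout, [board_height x board_width, batch_size, embedding_dim], matches the shape documented in forward's docstring.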
6 changes: 3 additions & 3 deletions src/model/class_encoding.py
@@ -42,13 +42,13 @@ def forward(self, board: torch.Tensor) -> torch.Tensor:
         ---
         Args:
             board: The timesteps of the game states.
-                Shape of [batch_size, board_height, board_width, N_SIDES].
+                Shape of [...].
         ---
         Returns:
             The board with encoded classes.
-                Shape of [batch_size, board_height, board_width, N_SIDES, embedding_dim].
+                Shape of [..., embedding_dim].
         """
-        assert board.max() <= self.embedding_dim, "Not enough orthogonal vectors!"
+        assert board.max() < self.embedding_dim, "Not enough orthogonal vectors!"
 
         # Encode the classes of each tile.
         return self.class_enc(board)
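
The tightened assertion is consistent with ClassEncoding mapping each class id to one of embedding_dim orthogonal vectors, as its error message suggests: valid ids run from 0 to embedding_dim - 1, so the maximum id must be strictly smaller than embedding_dim. A hypothetical one-hot illustration, not part of the commit:

import torch
import torch.nn.functional as F

embedding_dim = 4

board = torch.tensor([0, 1, 3])                     # max() == 3 < embedding_dim: encodable
print(F.one_hot(board, num_classes=embedding_dim))

board = torch.tensor([0, 1, 4])                     # max() == 4 == embedding_dim: out of range
print((board.max() <= embedding_dim).item())        # True  -> the old check would accept it
print((board.max() < embedding_dim).item())         # False -> the fixed check rejects it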
6 changes: 0 additions & 6 deletions src/model/model.py
@@ -16,13 +16,10 @@
 class Policy(nn.Module):
     def __init__(
         self,
-        n_classes: int,
         board_width: int,
         board_height: int,
-        n_channels: int,
         embedding_dim: int,
         n_heads: int,
-        backbone_cnn_layers: int,
         backbone_transformer_layers: int,
         decoder_layers: int,
         dropout: float,
@@ -33,11 +30,8 @@ def __init__(
         self.embedding_dim = embedding_dim
 
         self.backbone = Backbone(
-            n_classes,
-            n_channels,
             embedding_dim,
             n_heads,
-            backbone_cnn_layers,
             backbone_transformer_layers,
             dropout,
         )
2 changes: 2 additions & 0 deletions src/policy_gradient/trainer.py
@@ -207,10 +207,12 @@ def evaluate(self) -> dict[str, Any]:
     def save_model(self, filepath: Path | str):
         model_state = self.model.state_dict()
         optimizer_state = self.optimizer.state_dict()
+        scheduler_state = self.scheduler.state_dict()
         torch.save(
             {
                 "model": model_state,
                 "optimizer": optimizer_state,
+                "scheduler": scheduler_state,
             },
             filepath,
         )
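
Together with the reload_checkpoint change in main.py, checkpoints now round-trip the scheduler state. A self-contained sketch of that round trip, not part of the commit; the dummy model, optimizer, scheduler, and file name are placeholders:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 4)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100)

# Saving, mirroring Trainer.save_model above: the scheduler state is now included.
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    },
    "checkpoint.pt",
)

# Loading, mirroring reload_checkpoint in main.py: the optimizer state is still
# skipped on purpose (see the HACK note there).
state_dict = torch.load("checkpoint.pt")
model.load_state_dict(state_dict["model"])
scheduler.load_state_dict(state_dict["scheduler"])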
