feat(trainer): sampling without replacement

Pierrot LeCon committed Sep 18, 2023
1 parent 3b34557 commit fdf72a3

Showing 8 changed files with 7 additions and 15 deletions.
1 change: 0 additions & 1 deletion configs/exp/hard.yaml

@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 100
     epochs: -1
-    batches: 100
     batch_size: 5000
 
   checkpoint:
1 change: 0 additions & 1 deletion configs/exp/normal.yaml

@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 60
     epochs: -1
-    batches: 60
     batch_size: 5000
 
   checkpoint:
1 change: 0 additions & 1 deletion configs/exp/trivial.yaml

@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 10
     epochs: 500
-    batches: 10
     batch_size: 2500
 
   checkpoint:
1 change: 0 additions & 1 deletion configs/exp/trivial_B.yaml

@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 40
     epochs: 500
-    batches: 40
     batch_size: 5000
 
   checkpoint:
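The `batches` key disappears from all four experiment configs because the batch count is now derived from the replay buffer instead of being pinned by hand (see the main.py and trainer.py hunks below). A rough sketch of the arithmetic; the env batch size of 5000 is an assumed illustrative value, not something this commit specifies:

    # Illustrative arithmetic only; env_batch_size is an assumption.
    env_batch_size = 5000  # transitions collected per rollout (assumed)
    rollouts = 100         # configs/exp/hard.yaml
    batch_size = 5000      # configs/exp/hard.yaml

    # init_replay_buffer sizes the storage to hold every rollout transition,
    max_size = env_batch_size * rollouts
    # so one pass without replacement yields this many batches -- the count
    # that `batches: 100` used to hardcode.
    batches_per_epoch = max_size // batch_size  # == 100 under these assumptions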
3 changes: 1 addition & 2 deletions main.py

@@ -130,7 +130,7 @@ def init_replay_buffer(config: DictConfig) -> ReplayBuffer:
     max_size = exp.env.batch_size * exp.iterations.rollouts
     return TensorDictReplayBuffer(
         storage=LazyTensorStorage(max_size=max_size, device=config.device),
-        # sampler=SamplerWithoutReplacement(drop_last=True),
+        sampler=SamplerWithoutReplacement(drop_last=True),
         batch_size=exp.iterations.batch_size,
         pin_memory=True,
     )
@@ -158,7 +158,6 @@ def init_trainer(
         trainer.clip_value,
         trainer.scramble_size,
         iterations.rollouts,
-        iterations.batches,
         iterations.epochs,
     )
 
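Un-commenting the sampler is the functional change here: torchrl's SamplerWithoutReplacement makes one pass over the buffer visit every stored transition exactly once, and drop_last=True discards a trailing incomplete batch. A minimal self-contained sketch of that behaviour (the buffer size, batch size, and "obs" key are illustrative, not from this repo):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(max_size=10),
        sampler=SamplerWithoutReplacement(drop_last=True),
        batch_size=4,
    )
    buffer.extend(TensorDict({"obs": torch.arange(10).unsqueeze(-1)}, batch_size=[10]))

    # One pass yields 10 // 4 = 2 batches (drop_last discards the 2 leftover
    # transitions) and never repeats a transition within the pass.
    seen = torch.cat([batch["obs"].squeeze(-1) for batch in buffer])
    assert seen.unique().numel() == seen.numel()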
6 changes: 3 additions & 3 deletions notes/todo.norg

@@ -13,11 +13,11 @@ updated: 2023-09-16T01:26:36+0100
 -- ( ) Try a pure on-policy learning.
 -- ( ) Separate actor and critic.
 -- (x) Reset terminated envs during rollouts.
--- ( ) Handle GAE multiple terminated envs in a batch sample.
+-- (x) Handle GAE multiple terminated envs in a batch sample.
 -- (x) Handle truncated envs.
 -- (x) Remove masks.
--- ( ) Sample without replacement during training.
--- ( ) Why the value target is going above 1?
+-- (x) Sample without replacement during training.
+-- (x) Why the value target is going above 1?
 - Curriculum learning.
 -- trivial < 2x2 randoms < trivial_B < 3x3 randoms < ...
 -- Randomly sample problem difficulties.
1 change: 0 additions & 1 deletion src/policy_gradient/rollout.py

@@ -127,7 +127,6 @@ def split_reset_rollouts(traces: TensorDictBase) -> TensorDictBase:
                 device=tensor.device,
             )
             split_tensor[masks] = einops.rearrange(tensor, "b s ... -> (b s) ...")
-
             split_traces[name] = split_tensor
 
     split_traces["masks"] = masks
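The hunk itself only drops a blank line, but the surrounding context is dense: einops flattens the (batch, steps) dims and a boolean mask scatters the flattened rows into per-episode slots. A small sketch with illustrative shapes:

    import torch
    import einops

    tensor = torch.randn(4, 3, 7)  # (batch, steps, features)
    flat = einops.rearrange(tensor, "b s ... -> (b s) ...")
    assert flat.shape == (12, 7)

    lengths = [4, 3, 2, 2, 1]  # per-episode lengths, sum == 4 * 3
    masks = torch.zeros(5, 4, dtype=torch.bool)
    for row, length in enumerate(lengths):
        masks[row, :length] = True

    split_tensor = torch.zeros(5, 4, 7)
    split_tensor[masks] = flat  # 12 True slots receive the 12 flattened rows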
8 changes: 3 additions & 5 deletions src/policy_gradient/trainer.py

@@ -31,7 +31,6 @@ def __init__(
         clip_value: float,
         scramble_size: float,
         rollouts: int,
-        batches: int,
         epochs: int,
     ):
         self.env = env
@@ -42,7 +41,6 @@
         self.replay_buffer = replay_buffer
         self.clip_value = clip_value
         self.rollouts = rollouts
-        self.batches = batches
         self.epochs = epochs
 
         self.scramble_size = int(scramble_size * self.env.batch_size)
@@ -145,13 +143,13 @@ def launch_training(
         self.model.train()
         self.do_rollouts(sampling_mode="softmax", disable_logs=disable_logs)
 
-        for _ in tqdm(
-            range(self.batches),
+        for batch in tqdm(
+            self.replay_buffer,
+            total=len(self.replay_buffer) // self.replay_buffer._batch_size,
             desc="Batch",
             leave=False,
             disable=disable_logs,
         ):
-            batch = self.replay_buffer.sample()
             self.do_batch_update(batch)
 
         self.scheduler.step()
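With the sampler in place, the training loop no longer needs a hand-tuned `batches` count: iterating the buffer ends by itself once every transition has been consumed, and `len(buffer) // buffer._batch_size` gives tqdm its total up front. A condensed sketch of the pattern (`run_epoch` and `do_batch_update` are illustrative stand-ins, not the repo's API):

    from tqdm import tqdm

    def run_epoch(buffer, do_batch_update):
        # len(buffer) counts stored transitions; _batch_size is the same
        # private attribute the commit reads to size the progress bar.
        n_batches = len(buffer) // buffer._batch_size
        for batch in tqdm(buffer, total=n_batches, desc="Batch", leave=False):
            do_batch_update(batch)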
