diff --git a/configs/exp/hard.yaml b/configs/exp/hard.yaml
index 7e4eefc..7205cd8 100644
--- a/configs/exp/hard.yaml
+++ b/configs/exp/hard.yaml
@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 100
     epochs: -1
-    batches: 100
     batch_size: 5000
 
   checkpoint:
diff --git a/configs/exp/normal.yaml b/configs/exp/normal.yaml
index d21b23e..1ef4b74 100644
--- a/configs/exp/normal.yaml
+++ b/configs/exp/normal.yaml
@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 60
     epochs: -1
-    batches: 60
     batch_size: 5000
 
   checkpoint:
diff --git a/configs/exp/trivial.yaml b/configs/exp/trivial.yaml
index 76d135a..6378c5b 100644
--- a/configs/exp/trivial.yaml
+++ b/configs/exp/trivial.yaml
@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 10
     epochs: 500
-    batches: 10
     batch_size: 2500
 
   checkpoint:
diff --git a/configs/exp/trivial_B.yaml b/configs/exp/trivial_B.yaml
index 0b13e14..133e9eb 100644
--- a/configs/exp/trivial_B.yaml
+++ b/configs/exp/trivial_B.yaml
@@ -37,7 +37,6 @@ trainer:
   iterations:
     rollouts: 40
     epochs: 500
-    batches: 40
     batch_size: 5000
 
   checkpoint:
diff --git a/main.py b/main.py
index 057f626..049ba80 100644
--- a/main.py
+++ b/main.py
@@ -130,7 +130,7 @@ def init_replay_buffer(config: DictConfig) -> ReplayBuffer:
     max_size = exp.env.batch_size * exp.iterations.rollouts
     return TensorDictReplayBuffer(
         storage=LazyTensorStorage(max_size=max_size, device=config.device),
-        # sampler=SamplerWithoutReplacement(drop_last=True),
+        sampler=SamplerWithoutReplacement(drop_last=True),
         batch_size=exp.iterations.batch_size,
         pin_memory=True,
     )
@@ -158,7 +158,6 @@ def init_trainer(
         trainer.clip_value,
         trainer.scramble_size,
         iterations.rollouts,
-        iterations.batches,
         iterations.epochs,
     )
 
diff --git a/notes/todo.norg b/notes/todo.norg
index 75a33ce..4443602 100644
--- a/notes/todo.norg
+++ b/notes/todo.norg
@@ -13,11 +13,11 @@ updated: 2023-09-16T01:26:36+0100
   -- ( ) Try a pure on-policy learning.
   -- ( ) Separate actor and critic.
   -- (x) Reset terminated envs during rollouts.
-  -- ( ) Handle GAE multiple terminated envs in a batch sample.
+  -- (x) Handle GAE multiple terminated envs in a batch sample.
   -- (x) Handle truncated envs.
   -- (x) Remove masks.
-  -- ( ) Sample without replacement during training.
-  -- ( ) Why the value target is going above 1?
+  -- (x) Sample without replacement during training.
+  -- (x) Why the value target is going above 1?
 - Curriculum learning.
   -- trivial < 2x2 randoms < trivial_B < 3x3 randoms < ...
   -- Randomly sample problem difficulties.
diff --git a/src/policy_gradient/rollout.py b/src/policy_gradient/rollout.py
index 612bb79..b6703fc 100644
--- a/src/policy_gradient/rollout.py
+++ b/src/policy_gradient/rollout.py
@@ -127,7 +127,6 @@ def split_reset_rollouts(traces: TensorDictBase) -> TensorDictBase:
             device=tensor.device,
         )
         split_tensor[masks] = einops.rearrange(tensor, "b s ... -> (b s) ...")
-
         split_traces[name] = split_tensor
 
     split_traces["masks"] = masks
diff --git a/src/policy_gradient/trainer.py b/src/policy_gradient/trainer.py
index b7f66b4..c652612 100644
--- a/src/policy_gradient/trainer.py
+++ b/src/policy_gradient/trainer.py
@@ -31,7 +31,6 @@ def __init__(
         clip_value: float,
         scramble_size: float,
         rollouts: int,
-        batches: int,
         epochs: int,
     ):
         self.env = env
@@ -42,7 +41,6 @@
         self.replay_buffer = replay_buffer
         self.clip_value = clip_value
         self.rollouts = rollouts
-        self.batches = batches
         self.epochs = epochs
 
         self.scramble_size = int(scramble_size * self.env.batch_size)
@@ -145,13 +143,13 @@ def launch_training(
             self.model.train()
             self.do_rollouts(sampling_mode="softmax", disable_logs=disable_logs)
 
-            for _ in tqdm(
-                range(self.batches),
+            for batch in tqdm(
+                self.replay_buffer,
+                total=len(self.replay_buffer) // self.replay_buffer._batch_size,
                 desc="Batch",
                 leave=False,
                 disable=disable_logs,
            ):
-                batch = self.replay_buffer.sample()
                 self.do_batch_update(batch)
 
             self.scheduler.step()
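
Note (outside the patch): a minimal sketch of the replay-buffer iteration the new trainer loop relies on, assuming torchrl's TensorDictReplayBuffer configured with SamplerWithoutReplacement as in main.py above. The storage size, the "obs" field, and the shapes below are illustrative, not taken from the repository.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# Illustrative buffer: 100 stored transitions, drawn in batches of 25.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=100),
    sampler=SamplerWithoutReplacement(drop_last=True),
    batch_size=25,
)
buffer.extend(TensorDict({"obs": torch.randn(100, 4)}, batch_size=[100]))

# Iterating the buffer performs one pass without replacement: each stored
# transition is yielded exactly once, and drop_last=True skips a trailing
# partial batch. This is why the trainer can write `for batch in self.replay_buffer`
# instead of calling `sample()` a fixed `batches` number of times; the patch uses
# len(buffer) // buffer._batch_size (4 here) as the tqdm total.
for batch in buffer:
    print(batch["obs"].shape)  # torch.Size([25, 4])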