Skip to content

Commit

Permalink
[Feature] TD3-bc compatibility with compile
Browse files (browse the repository at this point in the history)
ghstack-source-id: a210a36df2e3da3426f1c06766f6185817b0ed29
Pull Request resolved: #2657
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 4718fd2 commit 4a06b11
Show file tree
Hide file tree
Showing 15 changed files with 499 additions and 500 deletions.
6 changes: 1 addition & 5 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
evaluation_interval = cfg.logger.log_interval
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def update(sampled_tensordict):
collected_frames = 0

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
-num_updates = int(
-    cfg.collector.env_per_collector
-    * cfg.collector.frames_per_batch
-    * cfg.optim.utd_ratio
-)
+num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/td3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,8 @@ logger:
mode: online
eval_iter: 25000
video: False
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
Loading

0 comments on commit 4a06b11

Please sign in to comment.