Skip to content

Commit

Permalink
[Feature] TD3-bc compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 0db0897be82b4c663ff596db895e1a63fc0c5b5d
Pull Request resolved: #2657
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 5dead6a commit 081f123
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 67 deletions.
5 changes: 5 additions & 0 deletions sota-implementations/td3_bc/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ logger:
eval_steps: 1000
eval_envs: 1
video: False

compile:
compile: False
compile_mode:
cudagraphs: False
101 changes: 66 additions & 35 deletions sota-implementations/td3_bc/td3_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup, timeit

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
Expand Down Expand Up @@ -72,7 +75,16 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, device=device)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create agent
model, _ = make_td3_agent(cfg, eval_env, device)
Expand All @@ -83,67 +95,86 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create optimizer
optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps
delayed_updates = cfg.optim.policy_update_delay
update_counter = 0
pbar = tqdm.tqdm(range(gradient_steps))
# Training loop
start_time = time.time()
for i in pbar:
pbar.update(1)
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

def update(sampled_tensordict, update_actor):
# Compute loss
q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_loss.item()

to_log = {"q_loss": q_loss.item()}
optimizer_critic.zero_grad(set_to_none=True)

# Update actor
if update_actor:
actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
optimizer_actor.zero_grad(set_to_none=True)

# Update target params
target_net_updater.step()
else:
actorloss_metadata = {}
actor_loss = q_loss.new_zeros(())
metadata = TensorDict(actorloss_metadata)
metadata.set("q_loss", q_loss.detach())
metadata.set("actor_loss", actor_loss.detach())
return metadata

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps
delayed_updates = cfg.optim.policy_update_delay
pbar = tqdm.tqdm(range(gradient_steps))
# Training loop
for update_counter in pbar:
timeit.printevery(num_prints=1000, total_count=gradient_steps, erase=True)

to_log["actor_loss"] = actor_loss.item()
to_log.update(actorloss_metadata)
# Update actor every delayed_updates
update_actor = update_counter % delayed_updates == 0

with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
metadata = update(sampled_tensordict, update_actor).clone()

to_log = {}
if update_actor:
to_log.update(metadata.to_dict())
else:
to_log.update(metadata.exclude("actor_loss").to_dict())

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
if update_counter % evaluation_interval == 0:
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
if logger is not None:
log_metrics(logger, to_log, i)
to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, to_log, update_counter)

if not eval_env.is_closed:
eval_env.close()
pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
Expand Down
51 changes: 19 additions & 32 deletions sota-implementations/td3_bc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import functools

import torch
from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
Expand All @@ -26,14 +26,7 @@
)
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianModule,
MLP,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator

from torchrl.objectives import SoftUpdate
from torchrl.objectives.td3_bc import TD3BCLoss
Expand Down Expand Up @@ -98,7 +91,7 @@ def make_environment(cfg, logger=None):
# ---------------------------


def make_offline_replay_buffer(rb_cfg):
def make_offline_replay_buffer(rb_cfg, device):
data = D4RLExperienceReplay(
dataset_id=rb_cfg.dataset,
split_trajs=False,
Expand All @@ -109,6 +102,7 @@ def make_offline_replay_buffer(rb_cfg):
)

data.append_transform(DoubleToFloat())
data.append_transform(lambda td: td.to(device))

return data

Expand All @@ -122,26 +116,22 @@ def make_td3_agent(cfg, train_env, device):
"""Make TD3 agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
"activation_class": get_activation(cfg),
}
action_spec = train_env.action_spec_unbatched.to(device)

actor_net = MLP(**actor_net_kwargs)
actor_net = MLP(
num_cells=cfg.network.hidden_sizes,
out_features=action_spec.shape[-1],
activation_class=get_activation(cfg),
device=device,
)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
Expand All @@ -151,22 +141,19 @@ def make_td3_agent(cfg, train_env, device):
)

# Define Critic Network
qvalue_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": 1,
"activation_class": get_activation(cfg),
}

qvalue_net = MLP(
**qvalue_net_kwargs,
num_cells=cfg.network.hidden_sizes,
out_features=1,
activation_class=get_activation(cfg),
device=device,
)

qvalue = ValueOperator(
in_keys=["action"] + in_keys,
module=qvalue_net,
)

model = nn.ModuleList([actor, qvalue]).to(device)
model = nn.ModuleList([actor, qvalue])

# init nets
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
Expand Down

0 comments on commit 081f123

Please sign in to comment.