[Feature] no_cuda_sync arg in collectors
ghstack-source-id: a2a30a5ca9be16fb82804baa9aab987c16745abe
Pull Request resolved: #2727
vmoens committed Jan 29, 2025
1 parent dda0df1 commit ad1765c
Showing 2 changed files with 113 additions and 8 deletions.
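The commit adds a no_cuda_sync flag to SyncDataCollector and to the multiprocessed collectors built on _MultiDataCollector. A minimal usage sketch of the new argument follows; make_cuda_env and cuda_policy are hypothetical placeholders for an environment running directly on CUDA (e.g. IsaacLab or ManiSkill) and a CUDA-resident policy, and everything except no_cuda_sync follows the pre-existing SyncDataCollector signature.

# Sketch: enabling the new flag on a SyncDataCollector.
# make_cuda_env and cuda_policy are hypothetical stand-ins, not part of this commit.
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    create_env_fn=make_cuda_env,    # environment living on "cuda:0"
    policy=cuda_policy,             # policy module on "cuda:0"
    frames_per_batch=100,
    total_frames=1000,
    env_device="cuda:0",
    policy_device="cuda:0",
    storing_device="cuda:0",
    no_cuda_sync=True,              # skip explicit torch.cuda.synchronize() calls
)
for batch in collector:
    ...                             # consume batches as usual
collector.shutdown()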
76 changes: 76 additions & 0 deletions test/test_collector.py
@@ -10,6 +10,8 @@
import os

import sys
from typing import Optional
from unittest.mock import patch

import numpy as np
import pytest
@@ -469,6 +471,80 @@ def test_output_device(self, main_device, storing_device):
break
assert data.device == storing_device

class CudaPolicy(TensorDictSequential):
def __init__(self, n_obs):
module = torch.nn.Linear(n_obs, n_obs, device="cuda")
module.weight.data.copy_(torch.eye(n_obs))
module.bias.data.fill_(0)
m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"])
m1 = TensorDictModule(
lambda a: a + 1, in_keys=["hidden"], out_keys=["action"]
)
super().__init__(m0, m1)

class GoesThroughEnv(EnvBase):
def __init__(self, n_obs, device):
self.observation_spec = Composite(observation=Unbounded(n_obs))
self.action_spec = Unbounded(n_obs)
self.reward_spec = Unbounded(1)
self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool))
super().__init__(device=device)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
a = tensordict["action"]
if self.device is not None:
assert a.device == self.device
out = tensordict.empty()
out["observation"] = tensordict["observation"] + (
a - tensordict["observation"]
)
out["reward"] = torch.zeros((1,), device=self.device)
out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool)
return out

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
return self.full_done_specs.zeros().update(self.observation_spec.zeros())

def _set_seed(self, seed: Optional[int]):
return seed

@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
@pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
def test_no_synchronize(self, env_device, storing_device):
"""Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted."""
collector = SyncDataCollector(
create_env_fn=functools.partial(
self.GoesThroughEnv, n_obs=1000, device=None
),
policy=self.CudaPolicy(n_obs=1000),
frames_per_batch=100,
total_frames=1000,
env_device=env_device,
storing_device=storing_device,
policy_device="cuda:0",
            no_cuda_sync=True,
)
assert collector.env.device == torch.device(env_device)
i = 0
with patch("torch.cuda.synchronize") as mock_synchronize:
for d in collector:
for _d in d.unbind(0):
u = _d["observation"].unique()
assert u.numel() == 1, i
assert u == i, i
i += 1
u = _d["next", "observation"].unique()
assert u.numel() == 1, i
assert u == i, i
mock_synchronize.assert_not_called()
assert (
not mock_synchronize.called
), "torch.cuda.synchronize should not be called"


# @pytest.mark.skipif(
# IS_WINDOWS and PYTHON_3_10,
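The test above combines two ideas: unittest.mock.patch on torch.cuda.synchronize to prove that collection performs no explicit synchronization, and a pass-through environment plus an add-one policy so that every observation in batch i must equal i if no data was corrupted by the skipped syncs. The patch-and-assert part can be used on its own to check any code path for hidden synchronize calls; the sketch below is a generic, self-contained version of that idea, with run_pipeline standing in for the code under test.

# Generic sketch of the verification pattern used in test_no_synchronize:
# patch torch.cuda.synchronize and assert the code under test never calls it.
from unittest.mock import patch

import torch

def run_pipeline():
    # Placeholder for the code path being checked (e.g. a collector loop).
    x = torch.ones(8)
    return x + 1

with patch("torch.cuda.synchronize") as mock_synchronize:
    run_pipeline()
    mock_synchronize.assert_not_called()

One caveat of this pattern: patch swaps the torch.cuda.synchronize attribute only for the duration of the with block, so a reference captured beforehand (for instance inside a functools.partial) would call the original function and go unnoticed by the mock.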
45 changes: 37 additions & 8 deletions torchrl/collectors/collectors.py
@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause
            unexpected crashes.
            Defaults to ``False``.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
@@ -532,6 +537,7 @@ def __init__(
trust_policy: bool = None,
compile_policy: bool | Dict[str, Any] | None = None,
cudagraph_policy: bool | Dict[str, Any] | None = None,
no_cuda_sync: bool = False,
**kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
else:
self._sync_policy = _do_nothing
self.device = device
self.no_cuda_sync = no_cuda_sync
# Check if we need to cast things from device to device
# If the policy has a None device and the env too, no need to cast (we don't know
# and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
Yields: TensorDictBase objects containing (chunks of) trajectories
"""
-        if self.storing_device and self.storing_device.type == "cuda":
+        if (
+            not self.no_cuda_sync
+            and self.storing_device
+            and self.storing_device.type == "cuda"
+        ):
stream = torch.cuda.Stream(self.storing_device, priority=-1)
event = stream.record_event()
streams = [stream]
events = [event]
-        elif self.storing_device is None:
+        elif not self.no_cuda_sync and self.storing_device is None:
streams = []
events = []
# this way of checking cuda is robust to lazy stacks with mismatching shapes
@@ -1167,9 +1178,11 @@ def rollout(self) -> TensorDictBase:
if self._cast_to_policy_device:
if self.policy_device is not None:
policy_input = self._shuttle.to(
-                        self.policy_device, non_blocking=True
+                        self.policy_device,
+                        non_blocking=not self.no_cuda_sync,
)
-                    self._sync_policy()
+                    if not self.no_cuda_sync:
+                        self._sync_policy()
elif self.policy_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
@@ -1191,8 +1204,11 @@

if self._cast_to_env_device:
if self.env_device is not None:
-                env_input = self._shuttle.to(self.env_device, non_blocking=True)
-                self._sync_env()
+                env_input = self._shuttle.to(
+                    self.env_device, non_blocking=not self.no_cuda_sync
+                )
+                if not self.no_cuda_sync:
+                    self._sync_env()
elif self.env_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
@@ -1217,9 +1233,12 @@
else:
if self.storing_device is not None:
tensordicts.append(
-                        self._shuttle.to(self.storing_device, non_blocking=True)
+                        self._shuttle.to(
+                            self.storing_device, non_blocking=not self.no_cuda_sync
+                        )
)
-                    self._sync_storage()
+                    if not self.no_cuda_sync:
+                        self._sync_storage()
else:
tensordicts.append(self._shuttle)

@@ -1558,6 +1577,11 @@ class _MultiDataCollector(DataCollectorBase):
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause
            unexpected crashes.
            Defaults to ``False``.
"""

@@ -1597,6 +1621,7 @@ def __init__(
trust_policy: bool = None,
compile_policy: bool | Dict[str, Any] | None = None,
cudagraph_policy: bool | Dict[str, Any] | None = None,
no_cuda_sync: bool = False,
):
self.closed = True
self.num_workers = len(create_env_fn)
@@ -1636,6 +1661,7 @@ def __init__(
self.env_device = env_devices

del storing_device, env_device, policy_device, device
self.no_cuda_sync = no_cuda_sync

self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
@@ -1909,6 +1935,7 @@ def _run_processes(self) -> None:
"cudagraph_policy": self.cudagraphed_policy_kwargs
if self.cudagraphed_policy
else False,
"no_cuda_sync": self.no_cuda_sync,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
@@ -2914,6 +2941,7 @@ def _main_async_collector(
trust_policy: bool = False,
compile_policy: bool = False,
cudagraph_policy: bool = False,
no_cuda_sync: bool = False,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
@@ -2943,6 +2971,7 @@ def _main_async_collector(
trust_policy=trust_policy,
compile_policy=compile_policy,
cudagraph_policy=cudagraph_policy,
no_cuda_sync=no_cuda_sync,
)
use_buffers = inner_collector._use_buffers
if verbose:
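Because _run_processes forwards no_cuda_sync to each worker (and _main_async_collector passes it to the inner SyncDataCollector), the multiprocessed collectors expose the flag with the same semantics. A minimal sketch, again using the hypothetical make_cuda_env and cuda_policy placeholders, this time with MultiSyncDataCollector:

# Sketch: the same flag on a multiprocessed collector; it is handed down to each
# worker's SyncDataCollector via the kwargs assembled in _run_processes.
from torchrl.collectors import MultiSyncDataCollector

collector = MultiSyncDataCollector(
    create_env_fn=[make_cuda_env, make_cuda_env],  # one env constructor per worker
    policy=cuda_policy,
    frames_per_batch=200,
    total_frames=2000,
    env_device="cuda:0",
    policy_device="cuda:0",
    no_cuda_sync=True,    # forwarded to every worker
)
for batch in collector:
    ...
collector.shutdown()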
