diff --git a/test/test_collector.py b/test/test_collector.py
index 413ce57ffe3..490b847b3ae 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -10,6 +10,8 @@
import os
import sys
+from typing import Optional
+from unittest.mock import patch
import numpy as np
import pytest
@@ -469,6 +471,74 @@ def test_output_device(self, main_device, storing_device):
break
assert data.device == storing_device
+ class CudaPolicy(TensorDictSequential):
+ def __init__(self, n_obs):
+ module = torch.nn.Linear(n_obs, n_obs, device="cuda")
+ module.weight.data.copy_(torch.eye(n_obs))
+ module.bias.data.fill_(0)
+ m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"])
+ m1 = TensorDictModule(
+ lambda a: a + 1, in_keys=["hidden"], out_keys=["action"]
+ )
+ super().__init__(m0, m1)
+
+ class GoesThroughEnv(EnvBase):
+ def __init__(self, n_obs, device):
+ self.observation_spec = Composite(observation=Unbounded(n_obs))
+ self.action_spec = Unbounded(n_obs)
+ self.reward_spec = Unbounded(1)
+ self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool))
+ super().__init__(device=device)
+
+ def _step(
+ self,
+ tensordict: TensorDictBase,
+ ) -> TensorDictBase:
+ a = tensordict["action"]
+ if self.device is not None:
+ assert a.device == self.device
+ out = tensordict.empty()
+ out["observation"] = tensordict["observation"] + (
+ a - tensordict["observation"]
+ )
+ out["reward"] = torch.zeros((1,), device=self.device)
+ out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool)
+ return out
+
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+ return self.full_done_specs.zeros().update(self.observation_spec.zeros())
+
+ def _set_seed(self, seed: Optional[int]):
+ return seed
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
+ @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
+ @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
+ def test_no_synchronize(self, env_device, storing_device):
+ """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted."""
+ collector = SyncDataCollector(
+ create_env_fn=functools.partial(self.GoesThroughEnv, n_obs=1000, device=None),
+ policy=self.CudaPolicy(n_obs=1000),
+ frames_per_batch=100,
+ total_frames=1000,
+ env_device=env_device,
+ storing_device=storing_device,
+ policy_device="cuda:0"
+ )
+ assert collector.env.device == torch.device(env_device)
+ i = 0
+ with patch("torch.cuda.synchronize") as mock_synchronize:
+ for d in collector:
+ for _d in d.unbind(0):
+ u = _d["observation"].unique()
+ assert u.numel() == 1
+ assert u == i
+ i += 1
+ u = _d["next", "observation"].unique()
+ assert u.numel() == 1
+ assert u == i
+ mock_synchronize.assert_not_called()
+
# @pytest.mark.skipif(
# IS_WINDOWS and PYTHON_3_10,
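The new test above verifies the behavior indirectly: it patches torch.cuda.synchronize and asserts the mock is never hit while iterating over the collector. A minimal, self-contained sketch of that verification pattern (run_collection is a hypothetical stand-in, not part of this diff):

import torch
from unittest.mock import patch

def run_collection():
    # Hypothetical stand-in for any loop that must not synchronize explicitly.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.arange(10, device=device) + 1

with patch("torch.cuda.synchronize") as mock_synchronize:
    run_collection()
    # Any explicit torch.cuda.synchronize() call above would have been recorded here.
    mock_synchronize.assert_not_called()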
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 7e687b02999..2ff3aaf746e 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab `_
+            or `ManiSkills `_), CUDA synchronization may cause unexpected
+            crashes.
+ Defaults to ``False``.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
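A hedged usage sketch of the new argument, mirroring the docstring's GymEnv example; the random-policy fallback for ``policy=None`` and the exact frame counts are assumptions, not part of this diff:

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=None,             # the collector falls back to a random policy
    frames_per_batch=64,
    total_frames=256,
    no_cuda_sync=True,       # skip explicit torch.cuda.synchronize() calls
)
for batch in collector:
    print(batch.shape)       # e.g. torch.Size([64]) for a single, unbatched env
collector.shutdown()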
@@ -532,6 +537,7 @@ def __init__(
trust_policy: bool = None,
compile_policy: bool | Dict[str, Any] | None = None,
cudagraph_policy: bool | Dict[str, Any] | None = None,
+ no_cuda_sync: bool = False,
**kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
else:
self._sync_policy = _do_nothing
self.device = device
+ self.no_cuda_sync = no_cuda_sync
# Check if we need to cast things from device to device
# If the policy has a None device and the env too, no need to cast (we don't know
# and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
Yields: TensorDictBase objects containing (chunks of) trajectories
"""
- if self.storing_device and self.storing_device.type == "cuda":
+ if (
+ not self.no_cuda_sync
+ and self.storing_device
+ and self.storing_device.type == "cuda"
+ ):
stream = torch.cuda.Stream(self.storing_device, priority=-1)
event = stream.record_event()
streams = [stream]
events = [event]
- elif self.storing_device is None:
+ elif not self.no_cuda_sync and self.storing_device is None:
streams = []
events = []
# this way of checking cuda is robust to lazy stacks with mismatching shapes
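For context on what the guarded branch sets up when ``no_cuda_sync`` is left at ``False``: the iterator creates a high-priority stream on the storing device and records an event on it for later synchronization. A standalone sketch of that stream/event pattern (not the collector's exact code):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # priority=-1 requests a high-priority stream, as in the branch above.
    stream = torch.cuda.Stream(device, priority=-1)
    with torch.cuda.stream(stream):
        data = torch.randn(1024, device=device) * 2  # work queued on the side stream
    event = stream.record_event()  # marks completion of the queued work
    event.synchronize()            # explicit wait; the kind of sync no_cuda_sync avoids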
@@ -1558,6 +1569,11 @@ class _MultiDataCollector(DataCollectorBase):
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab `_
+            or `ManiSkills `_), CUDA synchronization may cause unexpected
+            crashes.
+ Defaults to ``False``.
"""
@@ -1597,6 +1613,7 @@ def __init__(
trust_policy: bool = None,
compile_policy: bool | Dict[str, Any] | None = None,
cudagraph_policy: bool | Dict[str, Any] | None = None,
+ no_cuda_sync: bool = False,
):
self.closed = True
self.num_workers = len(create_env_fn)
@@ -1636,6 +1653,7 @@ def __init__(
self.env_device = env_devices
del storing_device, env_device, policy_device, device
+ self.no_cuda_sync = no_cuda_sync
self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
@@ -1909,6 +1927,7 @@ def _run_processes(self) -> None:
"cudagraph_policy": self.cudagraphed_policy_kwargs
if self.cudagraphed_policy
else False,
+ "no_cuda_sync": self.no_cuda_sync,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
@@ -2914,6 +2933,7 @@ def _main_async_collector(
trust_policy: bool = False,
compile_policy: bool = False,
cudagraph_policy: bool = False,
+ no_cuda_sync: bool = False,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
@@ -2943,6 +2963,7 @@ def _main_async_collector(
trust_policy=trust_policy,
compile_policy=compile_policy,
cudagraph_policy=cudagraph_policy,
+ no_cuda_sync=no_cuda_sync,
)
use_buffers = inner_collector._use_buffers
if verbose: