diff --git a/test/test_collector.py b/test/test_collector.py
index 413ce57ffe3..490b847b3ae 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -10,6 +10,8 @@
 import os
 import sys
 
+from typing import Optional
+from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -469,6 +471,74 @@ def test_output_device(self, main_device, storing_device):
             break
         assert data.device == storing_device
 
+    class CudaPolicy(TensorDictSequential):
+        def __init__(self, n_obs):
+            module = torch.nn.Linear(n_obs, n_obs, device="cuda")
+            module.weight.data.copy_(torch.eye(n_obs))
+            module.bias.data.fill_(0)
+            m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"])
+            m1 = TensorDictModule(
+                lambda a: a + 1, in_keys=["hidden"], out_keys=["action"]
+            )
+            super().__init__(m0, m1)
+
+    class GoesThroughEnv(EnvBase):
+        def __init__(self, n_obs, device):
+            self.observation_spec = Composite(observation=Unbounded(n_obs))
+            self.action_spec = Unbounded(n_obs)
+            self.reward_spec = Unbounded(1)
+            self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool))
+            super().__init__(device=device)
+
+        def _step(
+            self,
+            tensordict: TensorDictBase,
+        ) -> TensorDictBase:
+            a = tensordict["action"]
+            if self.device is not None:
+                assert a.device == self.device
+            out = tensordict.empty()
+            out["observation"] = tensordict["observation"] + (
+                a - tensordict["observation"]
+            )
+            out["reward"] = torch.zeros((1,), device=self.device)
+            out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool)
+            return out
+
+        def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+            return self.full_done_specs.zeros().update(self.observation_spec.zeros())
+
+        def _set_seed(self, seed: Optional[int]):
+            return seed
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
+    @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
+    @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
+    def test_no_synchronize(self, env_device, storing_device):
+        """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted."""
+        collector = SyncDataCollector(
+            create_env_fn=functools.partial(
+                self.GoesThroughEnv, n_obs=1000, device=None
+            ),
+            policy=self.CudaPolicy(n_obs=1000),
+            frames_per_batch=100,
+            total_frames=1000,
+            env_device=env_device,
+            storing_device=storing_device,
+            policy_device="cuda:0",
+        )
+        assert collector.env.device == torch.device(env_device)
+        i = 0
+        with patch("torch.cuda.synchronize") as mock_synchronize:
+            for d in collector:
+                for _d in d.unbind(0):
+                    u = _d["observation"].unique()
+                    assert u.numel() == 1
+                    assert u == i
+                    i += 1
+                    u = _d["next", "observation"].unique()
+                    assert u.numel() == 1
+                    assert u == i
+        mock_synchronize.assert_not_called()
 
     # @pytest.mark.skipif(
     #     IS_WINDOWS and PYTHON_3_10,
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 7e687b02999..2ff3aaf746e 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. If a dictionary of kwargs is passed,
             it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab `_
+            or `ManiSkills `_) CUDA synchronization may cause unexpected
+            crashes.
+            Defaults to ``False``.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -532,6 +537,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
         **kwargs,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
         else:
             self._sync_policy = _do_nothing
         self.device = device
+        self.no_cuda_sync = no_cuda_sync
         # Check if we need to cast things from device to device
         # If the policy has a None device and the env too, no need to cast (we don't know
         # and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
         Yields:
             TensorDictBase objects containing (chunks of) trajectories
         """
-        if self.storing_device and self.storing_device.type == "cuda":
+        if (
+            not self.no_cuda_sync
+            and self.storing_device
+            and self.storing_device.type == "cuda"
+        ):
             stream = torch.cuda.Stream(self.storing_device, priority=-1)
             event = stream.record_event()
             streams = [stream]
             events = [event]
-        elif self.storing_device is None:
+        elif not self.no_cuda_sync and self.storing_device is None:
             streams = []
             events = []
             # this way of checking cuda is robust to lazy stacks with mismatching shapes
@@ -1558,6 +1569,11 @@ class _MultiDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. If a dictionary of kwargs is passed,
             it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab `_
+            or `ManiSkills `_) CUDA synchronization may cause unexpected
+            crashes.
+            Defaults to ``False``.
 
     """
 
@@ -1597,6 +1613,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
     ):
         self.closed = True
         self.num_workers = len(create_env_fn)
@@ -1636,6 +1653,7 @@ def __init__(
         self.env_device = env_devices
 
         del storing_device, env_device, policy_device, device
+        self.no_cuda_sync = no_cuda_sync
 
         self._use_buffers = use_buffers
         self.replay_buffer = replay_buffer
@@ -1909,6 +1927,7 @@ def _run_processes(self) -> None:
                 "cudagraph_policy": self.cudagraphed_policy_kwargs
                 if self.cudagraphed_policy
                 else False,
+                "no_cuda_sync": self.no_cuda_sync,
             }
             proc = _ProcessNoWarn(
                 target=_main_async_collector,
@@ -2914,6 +2933,7 @@ def _main_async_collector(
     trust_policy: bool = False,
     compile_policy: bool = False,
     cudagraph_policy: bool = False,
+    no_cuda_sync: bool = False,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -2943,6 +2963,7 @@
         trust_policy=trust_policy,
         compile_policy=compile_policy,
         cudagraph_policy=cudagraph_policy,
+        no_cuda_sync=no_cuda_sync,
     )
     use_buffers = inner_collector._use_buffers
     if verbose:
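
For reference, a minimal usage sketch of the new ``no_cuda_sync`` flag with ``SyncDataCollector``. The collector arguments are the ones introduced in this diff; the environment and policy are hypothetical stand-ins (a ``GymEnv`` and a one-layer ``TensorDictModule``) for the CUDA-native simulators (IsaacLab / ManiSkill style) the docstring mentions, and the snippet assumes a CUDA-capable machine with gymnasium installed:

    # Illustrative sketch only: run SyncDataCollector with no_cuda_sync=True.
    # The env and policy below are placeholders; the flag targets envs that
    # already live on CUDA and may crash on explicit torch.cuda.synchronize().
    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    def make_env():
        # Stand-in env; a CUDA-native simulator is the intended use case.
        return GymEnv("Pendulum-v1", device="cuda:0")

    policy = TensorDictModule(
        torch.nn.LazyLinear(1, device="cuda:0"),  # Pendulum's action is 1-dim
        in_keys=["observation"],
        out_keys=["action"],
    )

    collector = SyncDataCollector(
        create_env_fn=make_env,
        policy=policy,
        frames_per_batch=200,
        total_frames=1_000,
        env_device="cuda:0",
        policy_device="cuda:0",
        storing_device="cuda:0",
        no_cuda_sync=True,  # bypass explicit CUDA synchronization in the collector
    )
    for data in collector:
        pass  # consume batches; data is already on storing_device
    collector.shutdown()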