[Feature] no_cuda_sync arg in collectors
ghstack-source-id: c2e76d6397d9ebc215c976fcf9824df0d5e54d51
Pull Request resolved: #2727
vmoens committed Jan 28, 2025
1 parent dda0df1 commit 7722385
Showing 2 changed files with 93 additions and 2 deletions.
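In short, the commit adds a ``no_cuda_sync`` argument to ``SyncDataCollector`` and to the multiprocessed collectors so that explicit ``torch.cuda.synchronize()`` calls can be skipped during collection. A minimal usage sketch, assuming a CUDA-capable machine and the Gym ``Pendulum-v1`` environment (the linear policy below is only a placeholder, not part of the commit):

import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

# Pendulum-v1 has a 3-dimensional observation and a 1-dimensional action.
policy = TensorDictModule(
    torch.nn.Linear(3, 1, device="cuda:0"),
    in_keys=["observation"],
    out_keys=["action"],
)
collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    frames_per_batch=200,
    total_frames=2000,
    env_device="cpu",
    policy_device="cuda:0",
    no_cuda_sync=True,  # skip explicit torch.cuda.synchronize() calls
)
for data in collector:
    pass  # consume batches as usual; the data layout is unchanged
collector.shutdown()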
70 changes: 70 additions & 0 deletions test/test_collector.py
@@ -10,6 +10,8 @@
import os

import sys
from typing import Optional
from unittest.mock import patch

import numpy as np
import pytest
@@ -469,6 +471,74 @@ def test_output_device(self, main_device, storing_device):
            break
        assert data.device == storing_device

    class CudaPolicy(TensorDictSequential):
        def __init__(self, n_obs):
            module = torch.nn.Linear(n_obs, n_obs, device="cuda")
            module.weight.data.copy_(torch.eye(n_obs))
            module.bias.data.fill_(0)
            m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"])
            m1 = TensorDictModule(
                lambda a: a + 1, in_keys=["hidden"], out_keys=["action"]
            )
            super().__init__(m0, m1)

    class GoesThroughEnv(EnvBase):
        def __init__(self, n_obs, device):
            self.observation_spec = Composite(observation=Unbounded(n_obs))
            self.action_spec = Unbounded(n_obs)
            self.reward_spec = Unbounded(1)
            self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool))
            super().__init__(device=device)

        def _step(
            self,
            tensordict: TensorDictBase,
        ) -> TensorDictBase:
            a = tensordict["action"]
            if self.device is not None:
                assert a.device == self.device
            out = tensordict.empty()
            out["observation"] = tensordict["observation"] + (
                a - tensordict["observation"]
            )
            out["reward"] = torch.zeros((1,), device=self.device)
            out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool)
            return out

        def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
            return self.full_done_specs.zeros().update(self.observation_spec.zeros())

        def _set_seed(self, seed: Optional[int]):
            return seed

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
    @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
    @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
    def test_no_synchronize(self, env_device, storing_device):
        """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted."""
        collector = SyncDataCollector(
            create_env_fn=functools.partial(self.GoesThroughEnv, n_obs=1000, device=None),
            policy=self.CudaPolicy(n_obs=1000),
            frames_per_batch=100,
            total_frames=1000,
            env_device=env_device,
            storing_device=storing_device,
            policy_device="cuda:0",
            no_cuda_sync=True,
        )
        assert collector.env.device == torch.device(env_device)
        i = 0
        with patch("torch.cuda.synchronize") as mock_synchronize:
            for d in collector:
                for _d in d.unbind(0):
                    u = _d["observation"].unique()
                    assert u.numel() == 1
                    assert u == i
                    i += 1
                    u = _d["next", "observation"].unique()
                    assert u.numel() == 1
                    assert u == i
                mock_synchronize.assert_not_called()


# @pytest.mark.skipif(
# IS_WINDOWS and PYTHON_3_10,
25 changes: 23 additions & 2 deletions torchrl/collectors/collectors.py
@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
            If a dictionary of kwargs is passed, it will be used to wrap the policy.
        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
            crashes.
            Defaults to ``False``.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
@@ -532,6 +537,7 @@ def __init__(
        trust_policy: bool = None,
        compile_policy: bool | Dict[str, Any] | None = None,
        cudagraph_policy: bool | Dict[str, Any] | None = None,
        no_cuda_sync: bool = False,
        **kwargs,
    ):
        from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
        else:
            self._sync_policy = _do_nothing
        self.device = device
        self.no_cuda_sync = no_cuda_sync
        # Check if we need to cast things from device to device
        # If the policy has a None device and the env too, no need to cast (we don't know
        # and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
        Yields: TensorDictBase objects containing (chunks of) trajectories
        """
        if self.storing_device and self.storing_device.type == "cuda":
        if (
            not self.no_cuda_sync
            and self.storing_device
            and self.storing_device.type == "cuda"
        ):
            stream = torch.cuda.Stream(self.storing_device, priority=-1)
            event = stream.record_event()
            streams = [stream]
            events = [event]
        elif self.storing_device is None:
        elif not self.no_cuda_sync and self.storing_device is None:
            streams = []
            events = []
            # this way of checking cuda is robust to lazy stacks with mismatching shapes
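The hunk above is the core of the feature: the collector only sets up a dedicated CUDA stream (and the events used to synchronize it) when ``no_cuda_sync`` is left at its default of ``False``. The pattern can be sketched in isolation as follows; this is an illustrative simplification under the stated assumptions, not the collector's actual code:

import torch

def collect_on_side_stream(step_fn, storing_device, no_cuda_sync=False):
    """Run `step_fn` on a side CUDA stream, synchronizing explicitly unless
    no_cuda_sync is set (illustrative sketch; storing_device is a torch.device)."""
    if not no_cuda_sync and storing_device.type == "cuda":
        stream = torch.cuda.Stream(storing_device, priority=-1)
        with torch.cuda.stream(stream):
            out = step_fn()
        # Explicit sync: make the side-stream work visible to the default stream.
        stream.synchronize()
    else:
        # With no_cuda_sync=True the explicit synchronization is skipped entirely,
        # which matters for frameworks such as IsaacLab or ManiSkill that manage
        # their own CUDA state.
        out = step_fn()
    return out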
@@ -1558,6 +1569,11 @@ class _MultiDataCollector(DataCollectorBase):
        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
            If a dictionary of kwargs is passed, it will be used to wrap the policy.
        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
            crashes.
            Defaults to ``False``.
    """

@@ -1597,6 +1613,7 @@ def __init__(
        trust_policy: bool = None,
        compile_policy: bool | Dict[str, Any] | None = None,
        cudagraph_policy: bool | Dict[str, Any] | None = None,
        no_cuda_sync: bool = False,
    ):
        self.closed = True
        self.num_workers = len(create_env_fn)
@@ -1636,6 +1653,7 @@ def __init__(
        self.env_device = env_devices

        del storing_device, env_device, policy_device, device
        self.no_cuda_sync = no_cuda_sync

        self._use_buffers = use_buffers
        self.replay_buffer = replay_buffer
@@ -1909,6 +1927,7 @@ def _run_processes(self) -> None:
"cudagraph_policy": self.cudagraphed_policy_kwargs
if self.cudagraphed_policy
else False,
"no_cuda_sync": self.no_cuda_sync,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
@@ -2914,6 +2933,7 @@ def _main_async_collector(
    trust_policy: bool = False,
    compile_policy: bool = False,
    cudagraph_policy: bool = False,
    no_cuda_sync: bool = False,
) -> None:
    pipe_parent.close()
    # init variables that will be cleared when closing
@@ -2943,6 +2963,7 @@ def _main_async_collector(
        trust_policy=trust_policy,
        compile_policy=compile_policy,
        cudagraph_policy=cudagraph_policy,
        no_cuda_sync=no_cuda_sync,
    )
    use_buffers = inner_collector._use_buffers
    if verbose:
