[Feature] no_cuda_sync arg in collectors #2727

Merged
27 commits, merged on Jan 29, 2025
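
For context, a minimal usage sketch of the new argument (not part of the diff itself; the environment, policy and numbers below are placeholders):

import functools
import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

# Placeholder policy: a single linear layer living on the GPU.
policy = TensorDictModule(
    torch.nn.LazyLinear(1, device="cuda:0"),
    in_keys=["observation"],
    out_keys=["action"],
)
collector = SyncDataCollector(
    create_env_fn=functools.partial(GymEnv, "Pendulum-v1", device="cuda:0"),
    policy=policy,
    frames_per_batch=100,
    total_frames=1_000,
    env_device="cuda:0",
    policy_device="cuda:0",
    storing_device="cuda:0",
    no_cuda_sync=True,  # skip explicit torch.cuda.synchronize() calls
)
for data in collector:
    ...  # everything stays on CUDA, so no explicit synchronization is needed
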
81 changes: 81 additions & 0 deletions test/test_collector.py
@@ -5,11 +5,14 @@
from __future__ import annotations

import argparse
import contextlib
import functools
import gc
import os

import sys
from typing import Optional
from unittest.mock import patch

import numpy as np
import pytest
@@ -469,6 +472,84 @@ def test_output_device(self, main_device, storing_device):
break
assert data.device == storing_device

class CudaPolicy(TensorDictSequential):
def __init__(self, n_obs):
module = torch.nn.Linear(n_obs, n_obs, device="cuda")
module.weight.data.copy_(torch.eye(n_obs))
module.bias.data.fill_(0)
m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"])
m1 = TensorDictModule(
lambda a: a + 1, in_keys=["hidden"], out_keys=["action"]
)
super().__init__(m0, m1)

class GoesThroughEnv(EnvBase):
def __init__(self, n_obs, device):
self.observation_spec = Composite(observation=Unbounded(n_obs))
self.action_spec = Unbounded(n_obs)
self.reward_spec = Unbounded(1)
self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool))
super().__init__(device=device)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
a = tensordict["action"]
if self.device is not None:
assert a.device == self.device
out = tensordict.empty()
out["observation"] = tensordict["observation"] + (
a - tensordict["observation"]
)
out["reward"] = torch.zeros((1,), device=self.device)
out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool)
return out

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
return self.full_done_specs.zeros().update(self.observation_spec.zeros())

def _set_seed(self, seed: Optional[int]):
return seed

@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
@pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
@pytest.mark.parametrize("no_cuda_sync", [True, False])
def test_no_synchronize(self, env_device, storing_device, no_cuda_sync):
"""Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted."""
should_raise = not no_cuda_sync
should_raise = should_raise & (
(env_device == "cpu") or (storing_device == "cpu")
)
with patch("torch.cuda.synchronize") as mock_synchronize, pytest.raises(
AssertionError, match="Expected 'synchronize' to not have been called."
) if should_raise else contextlib.nullcontext():
collector = SyncDataCollector(
create_env_fn=functools.partial(
self.GoesThroughEnv, n_obs=1000, device=None
),
policy=self.CudaPolicy(n_obs=1000),
frames_per_batch=100,
total_frames=1000,
env_device=env_device,
storing_device=storing_device,
policy_device="cuda:0",
no_cuda_sync=no_cuda_sync,
)
assert collector.env.device == torch.device(env_device)
i = 0
for d in collector:
for _d in d.unbind(0):
u = _d["observation"].unique()
assert u.numel() == 1, i
assert u == i, i
i += 1
u = _d["next", "observation"].unique()
assert u.numel() == 1, i
assert u == i, i
mock_synchronize.assert_not_called()


# @pytest.mark.skipif(
# IS_WINDOWS and PYTHON_3_10,
56 changes: 48 additions & 8 deletions torchrl/collectors/collectors.py
@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
crashes.
Defaults to ``False``.

Examples:
>>> from torchrl.envs.libs.gym import GymEnv
@@ -532,6 +537,7 @@ def __init__(
trust_policy: bool = None,
compile_policy: bool | Dict[str, Any] | None = None,
cudagraph_policy: bool | Dict[str, Any] | None = None,
no_cuda_sync: bool = False,
**kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
else:
self._sync_policy = _do_nothing
self.device = device
self.no_cuda_sync = no_cuda_sync
# Check if we need to cast things from device to device
# If the policy has a None device and the env too, no need to cast (we don't know
# and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
Yields: TensorDictBase objects containing (chunks of) trajectories

"""
if self.storing_device and self.storing_device.type == "cuda":
if (
not self.no_cuda_sync
and self.storing_device
and self.storing_device.type == "cuda"
):
stream = torch.cuda.Stream(self.storing_device, priority=-1)
event = stream.record_event()
streams = [stream]
events = [event]
elif self.storing_device is None:
elif not self.no_cuda_sync and self.storing_device is None:
streams = []
events = []
# this way of checking cuda is robust to lazy stacks with mismatching shapes
@@ -1166,10 +1177,17 @@ def rollout(self) -> TensorDictBase:
else:
if self._cast_to_policy_device:
if self.policy_device is not None:
# This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
non_blocking = (
not self.no_cuda_sync
or self.policy_device.type == "cuda"
)
policy_input = self._shuttle.to(
self.policy_device, non_blocking=True
self.policy_device,
non_blocking=non_blocking,
)
self._sync_policy()
if not self.no_cuda_sync:
self._sync_policy()
elif self.policy_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
@@ -1191,8 +1209,14 @@ def rollout(self) -> TensorDictBase:

if self._cast_to_env_device:
if self.env_device is not None:
env_input = self._shuttle.to(self.env_device, non_blocking=True)
self._sync_env()
non_blocking = (
not self.no_cuda_sync or self.env_device.type == "cuda"
)
env_input = self._shuttle.to(
self.env_device, non_blocking=non_blocking
)
if not self.no_cuda_sync:
self._sync_env()
elif self.env_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
@@ -1216,10 +1240,16 @@ def rollout(self) -> TensorDictBase:
return
else:
if self.storing_device is not None:
non_blocking = (
not self.no_cuda_sync or self.storing_device.type == "cuda"
)
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=True)
self._shuttle.to(
self.storing_device, non_blocking=non_blocking
)
)
self._sync_storage()
if not self.no_cuda_sync:
self._sync_storage()
else:
tensordicts.append(self._shuttle)
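# Illustration only (not part of the diff): the policy, env and storage cast
# sites above all follow the same rule, sketched here as a hypothetical helper.
def _non_blocking_ok(no_cuda_sync: bool, target: torch.device) -> bool:
    # A non-blocking copy is safe when a synchronization follows it
    # (no_cuda_sync=False) or when the destination is a CUDA device, where
    # stream ordering guarantees the data is ready before it is read; a
    # non-blocking copy onto CPU without a later sync may expose stale memory.
    return (not no_cuda_sync) or target.type == "cuda"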

@@ -1558,6 +1588,11 @@ class _MultiDataCollector(DataCollectorBase):
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
crashes.
Defaults to ``False``.

"""

@@ -1597,6 +1632,7 @@ def __init__(
trust_policy: bool = None,
compile_policy: bool | Dict[str, Any] | None = None,
cudagraph_policy: bool | Dict[str, Any] | None = None,
no_cuda_sync: bool = False,
):
self.closed = True
self.num_workers = len(create_env_fn)
@@ -1636,6 +1672,7 @@ def __init__(
self.env_device = env_devices

del storing_device, env_device, policy_device, device
self.no_cuda_sync = no_cuda_sync

self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
@@ -1909,6 +1946,7 @@ def _run_processes(self) -> None:
"cudagraph_policy": self.cudagraphed_policy_kwargs
if self.cudagraphed_policy
else False,
"no_cuda_sync": self.no_cuda_sync,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
@@ -2914,6 +2952,7 @@ def _main_async_collector(
trust_policy: bool = False,
compile_policy: bool = False,
cudagraph_policy: bool = False,
no_cuda_sync: bool = False,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
@@ -2943,6 +2982,7 @@ def _main_async_collector(
trust_policy=trust_policy,
compile_policy=compile_policy,
cudagraph_policy=cudagraph_policy,
no_cuda_sync=no_cuda_sync,
)
use_buffers = inner_collector._use_buffers
if verbose: