Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Extend TensorDictPrimer default_value options #2071

Merged
merged 21 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 72 additions & 17 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6423,17 +6423,11 @@ def test_trans_parallel_env_check(self):
finally:
env.close()

def test_trans_serial_env_check(self):
with pytest.raises(RuntimeError, match="The leading shape of the primer specs"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])),
)
_ = env.observation_spec

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
Expand Down Expand Up @@ -6533,6 +6527,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
r1 = env.rollout(100, break_when_any_done=break_when_any_done)
tensordict.tensordict.assert_allclose_td(r0, r1)

def test_callable_default_value(self):
def create_tensor():
return torch.ones(3)

env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor
),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
assert ("next", "mykey") in env.rollout(3).keys(True)

def test_dict_default_value(self):

# Test with a dict of float default values
key1_spec = UnboundedContinuousTensorSpec([3])
key2_spec = UnboundedContinuousTensorSpec([3])
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey1=key1_spec,
mykey2=key2_spec,
default_value={
"mykey1": 1.0,
"mykey2": 2.0,
},
),
)
check_env_specs(env)
reset_td = env.reset()
assert "mykey1" in reset_td.keys()
assert "mykey2" in reset_td.keys()
rollout_td = env.rollout(3)
assert ("next", "mykey1") in rollout_td.keys(True)
assert ("next", "mykey2") in rollout_td.keys(True)
assert (rollout_td.get(("next", "mykey1")) == 1.0).all()
assert (rollout_td.get(("next", "mykey2")) == 2.0).all()

# Test with a dict of callable default values
key1_spec = UnboundedContinuousTensorSpec([3])
key2_spec = DiscreteTensorSpec(3, dtype=torch.int64)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey1=key1_spec,
mykey2=key2_spec,
default_value={
"mykey1": lambda: torch.ones(3),
"mykey2": lambda: torch.tensor(1, dtype=torch.int64),
},
),
)
check_env_specs(env)
reset_td = env.reset()
assert "mykey1" in reset_td.keys()
assert "mykey2" in reset_td.keys()
rollout_td = env.rollout(3)
assert ("next", "mykey1") in rollout_td.keys(True)
assert ("next", "mykey2") in rollout_td.keys(True)
assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all
assert (
rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
).all


class TestTimeMaxPool(TransformBase):
@pytest.mark.parametrize("T", [2, 4])
Expand Down Expand Up @@ -6813,18 +6873,13 @@ def make_env():
finally:
env.close()

def test_trans_serial_env_check(self):
@pytest.mark.parametrize("shape", [(), (2,)])
def test_trans_serial_env_check(self, shape):
state_dim = 7
action_dim = 7
with pytest.raises(RuntimeError, match="The leading shape of the primer"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=()),
)
check_env_specs(env)
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=shape),
)
try:
check_env_specs(env)
Expand Down
104 changes: 82 additions & 22 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4435,8 +4435,12 @@ class TensorDictPrimer(Transform):
random (bool, optional): if ``True``, the values will be drawn randomly from
the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed.
Defaults to `False`.
default_value (float, optional): if non-random filling is chosen, this
value will be used to populate the tensors. Defaults to `0.0`.
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
be used to generate the corresponding tensors. Defaults to `0.0`.
reset_key (NestedKey, optional): the reset key to be used as partial
reset indicator. Must be unique. If not provided, defaults to the
only reset key of the parent environment (if it has only one)
Expand Down Expand Up @@ -4493,8 +4497,11 @@ class TensorDictPrimer(Transform):
def __init__(
self,
primers: dict | CompositeSpec = None,
random: bool = False,
default_value: float = 0.0,
random: bool | None = None,
default_value: float
| Callable
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = None,
reset_key: NestedKey | None = None,
**kwargs,
):
Expand All @@ -4509,8 +4516,31 @@ def __init__(
if not isinstance(kwargs, CompositeSpec):
kwargs = CompositeSpec(kwargs)
self.primers = kwargs
if random and default_value:
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
)
default_value = (
default_value or 0.0
) # if not random and no default value, use 0.0
self.random = random
if isinstance(default_value, dict):
default_value = TensorDict(default_value, [])
default_value_keys = default_value.keys(
True,
True,
is_leaf=lambda x: issubclass(x, (NonTensorData, torch.Tensor)),
)
if set(default_value_keys) != set(self.primers.keys(True, True)):
raise ValueError(
"If a default_value dictionary is provided, it must match the primers keys."
)
else:
default_value = {
key: default_value for key in self.primers.keys(True, True)
}
self.default_value = default_value
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
self._validated = False
self.reset_key = reset_key

# sanity check
Expand Down Expand Up @@ -4563,6 +4593,9 @@ def to(self, *args, **kwargs):
self.primers = self.primers.to(device)
return super().to(*args, **kwargs)

def _expand_shape(self, spec):
return spec.expand((*self.parent.batch_size, *spec.shape))

def transform_observation_spec(
self, observation_spec: CompositeSpec
) -> CompositeSpec:
Expand All @@ -4572,15 +4605,13 @@ def transform_observation_spec(
)
for key, spec in self.primers.items():
if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
raise RuntimeError(
f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}."
)
expanded_spec = self._expand_shape(spec)
spec = expanded_spec
try:
device = observation_spec.device
except RuntimeError:
device = self.device
observation_spec[key] = spec.to(device)
observation_spec[key] = self.primers[key] = spec.to(device)
return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
Expand All @@ -4593,8 +4624,13 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
def _batch_size(self):
return self.parent.batch_size

def _validate_value_tensor(self, value, spec):
if not spec.is_in(value):
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).")
return True

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if spec.shape[: len(tensordict.shape)] != tensordict.shape:
raise RuntimeError(
"The leading shape of the spec must match the tensordict's, "
Expand All @@ -4605,11 +4641,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.random:
value = spec.rand()
else:
value = torch.full_like(
spec.zero(),
self.default_value,
)
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full(
spec.shape,
value,
device=spec.device,
)

tensordict.set(key, value)
if not self._validated:
self._validated = True
return tensordict

def _step(
Expand Down Expand Up @@ -4638,22 +4684,36 @@ def _reset(
)
_reset = _get_reset(self.reset_key, tensordict)
if _reset.any():
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if self.random:
value = spec.rand(shape)
else:
value = torch.full_like(
spec.zero(shape),
self.default_value,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(expand_as_right(_reset, value), value, prev_val)
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full(
spec.shape,
value,
device=spec.device,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(
expand_as_right(_reset, value), value, prev_val
)
tensordict_reset.set(key, value)
self._validated = True
return tensordict_reset

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(primers={self.primers}, default_value={self.default_value}, random={self.random})"
default_value = {
key: value if isinstance(value, float) else "Callable"
for key, value in self.default_value.items()
}
return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"


class PinMemoryTransform(Transform):
Expand Down
1 change: 0 additions & 1 deletion torchrl/objectives/value/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import math

import warnings
Expand Down
Loading