From c46e0fd8cb1001ed2a0b50fe638898bc13532d54 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 16:34:22 +0200 Subject: [PATCH 01/20] extend_primer_transform --- torchrl/envs/transforms/transforms.py | 61 ++++++++++++++++++++------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b83de8b71f8..26ad2c76074 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4431,8 +4431,12 @@ class TensorDictPrimer(Transform): random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. - default_value (float, optional): if non-random filling is chosen, this - value will be used to populate the tensors. Defaults to `0.0`. + default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random + filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float, + all elements of the tensors will be set to that value. If it is a callable, this callable is expected to + return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value` + is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will + be used to generate the corresponding tensors. Defaults to `0.0`. reset_key (NestedKey, optional): the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) @@ -4489,8 +4493,11 @@ class TensorDictPrimer(Transform): def __init__( self, primers: dict | CompositeSpec = None, - random: bool = False, - default_value: float = 0.0, + random: bool | None = None, + default_value: float + | Callable + | Dict[NestedKey, float] + | Dict[NestedKey, Callable] = 0.0, reset_key: NestedKey | None = None, **kwargs, ): @@ -4505,8 +4512,17 @@ def __init__( if not isinstance(kwargs, CompositeSpec): kwargs = CompositeSpec(kwargs) self.primers = kwargs + if (random is not None) and ( + isinstance(default_value, (dict, float, Callable)) and default_value != 0.0 + ): + raise ValueError( + "Setting random to True and providing a default_value are incompatible." + ) self.random = random + if isinstance(self.default_value, dict): + default_value = {key: default_value for key in primers.keys(True, True)} self.default_value = default_value + self._validated = False self.reset_key = reset_key # sanity check @@ -4559,6 +4575,9 @@ def to(self, *args, **kwargs): self.primers = self.primers.to(device) return super().to(*args, **kwargs) + def _maybe_expand_shape(self, spec): + return spec.expand((*self.parent.batch_size, *spec.shape)) + def transform_observation_spec( self, observation_spec: CompositeSpec ) -> CompositeSpec: @@ -4568,10 +4587,14 @@ def transform_observation_spec( ) for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." - ) + import ipdb; ipdb.set_trace() + try: + spec = self._maybe_expand_shape(spec) + except AttributeError: + raise RuntimeError( + f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " + f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." + ) try: device = observation_spec.device except RuntimeError: @@ -4590,7 +4613,7 @@ def _batch_size(self): return self.parent.batch_size def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - for key, spec in self.primers.items(): + for key, spec in self.primers.items(True, True): if spec.shape[: len(tensordict.shape)] != tensordict.shape: raise RuntimeError( "The leading shape of the spec must match the tensordict's, " @@ -4601,10 +4624,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.random: value = spec.rand() else: - value = torch.full_like( - spec.zero(), - self.default_value, - ) + import ipdb; ipdb.set_trace() + if callable(self.default_value[key]): + value = self.default_value[key]() + # validate the value + if not self._validated: + self.validate(value) + self._validated = True + else: + value = torch.full_like( + spec.zero(), + self.default_value[key], + ) tensordict.set(key, value) return tensordict @@ -4634,13 +4665,13 @@ def _reset( ) _reset = _get_reset(self.reset_key, tensordict) if _reset.any(): - for key, spec in self.primers.items(): + for key, spec in self.primers.items(True, True): if self.random: value = spec.rand(shape) else: value = torch.full_like( spec.zero(shape), - self.default_value, + self.default_value[key], ) prev_val = tensordict.get(key, 0.0) value = torch.where(expand_as_right(_reset, value), value, prev_val) From 822c04ced9515fc9ab39329d610b3c2fa6f2ba84 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 16:59:49 +0200 Subject: [PATCH 02/20] extend_primer_transform --- torchrl/envs/transforms/transforms.py | 33 +++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 26ad2c76074..09303047141 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4519,7 +4519,13 @@ def __init__( "Setting random to True and providing a default_value are incompatible." ) self.random = random - if isinstance(self.default_value, dict): + if isinstance(default_value, dict): + if len(default_value) != len(primers) and set(dict.keys()) != set( + primers.keys(True, True) + ): + raise ValueError( + "If a default_value dictionary is provided, it must match the primers keys." + ) default_value = {key: default_value for key in primers.keys(True, True)} self.default_value = default_value self._validated = False @@ -4612,6 +4618,10 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: def _batch_size(self): return self.parent.batch_size + def validate(self, value, spec): + # TODO: implement this + return True + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: for key, spec in self.primers.items(True, True): if spec.shape[: len(tensordict.shape)] != tensordict.shape: @@ -4624,18 +4634,23 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.random: value = spec.rand() else: - import ipdb; ipdb.set_trace() - if callable(self.default_value[key]): - value = self.default_value[key]() - # validate the value - if not self._validated: - self.validate(value) - self._validated = True + if isinstance(self.default_value, dict): + value = self.default_value[key] + if callable(value): + value = value() + if not self._validated: + self.validate(value, self.primers[key]) + else: + value = torch.full_like( + spec.zero(), + value, + ) else: value = torch.full_like( spec.zero(), - self.default_value[key], + self.default_value, ) + self._validated = True tensordict.set(key, value) return tensordict From c5f1a474544ed9a3d3a8350af0e3e78641744a40 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 17:08:43 +0200 Subject: [PATCH 03/20] extend_primer_transform --- torchrl/envs/transforms/transforms.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 09303047141..ee13fa2851c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4646,9 +4646,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: value, ) else: + value = self.default_value + if callable(value): + value = value() value = torch.full_like( spec.zero(), - self.default_value, + value, ) self._validated = True tensordict.set(key, value) @@ -4684,10 +4687,23 @@ def _reset( if self.random: value = spec.rand(shape) else: - value = torch.full_like( - spec.zero(shape), - self.default_value[key], - ) + if isinstance(self.default_value, dict): + value = self.default_value[key] + if callable(value): + value = value() + else: + value = torch.full_like( + spec.zero(shape), + value, + ) + else: + value = self.default_value + if callable(value): + value = value() + value = torch.full_like( + spec.zero(shape), + value, + ) prev_val = tensordict.get(key, 0.0) value = torch.where(expand_as_right(_reset, value), value, prev_val) tensordict_reset.set(key, value) From 667e21ee1decc6b5c07be12a275db558c9e1379f Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 17:35:39 +0200 Subject: [PATCH 04/20] extend_primer_transform --- torchrl/envs/transforms/transforms.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ee13fa2851c..1e721da9c09 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4581,7 +4581,7 @@ def to(self, *args, **kwargs): self.primers = self.primers.to(device) return super().to(*args, **kwargs) - def _maybe_expand_shape(self, spec): + def _try_expand_shape(self, spec): return spec.expand((*self.parent.batch_size, *spec.shape)) def transform_observation_spec( @@ -4593,14 +4593,20 @@ def transform_observation_spec( ) for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - import ipdb; ipdb.set_trace() try: - spec = self._maybe_expand_shape(spec) + expanded_spec = self._try_expand_shape(spec) except AttributeError: + pass + if ( + expanded_spec.shape[: len(observation_spec.shape)] + != observation_spec.shape + ): raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." + f"The leading shape of the primer specs ({self.__class__}) should match the one of the " + f"parent env. Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's " + f"shape is {expanded_spec.shape}." ) + spec = expanded_spec try: device = observation_spec.device except RuntimeError: From e088240ebc3f35ab9d919ec633a29322f912885c Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 17:36:31 +0200 Subject: [PATCH 05/20] extend_primer_transform --- torchrl/envs/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1e721da9c09..63891d66b6f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4513,7 +4513,7 @@ def __init__( kwargs = CompositeSpec(kwargs) self.primers = kwargs if (random is not None) and ( - isinstance(default_value, (dict, float, Callable)) and default_value != 0.0 + isinstance(default_value, (dict, float, Callable)) # and default_value != 0.0 ): raise ValueError( "Setting random to True and providing a default_value are incompatible." From 28525cddf273edc5349fddb3fe443ff215443219 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 17:38:24 +0200 Subject: [PATCH 06/20] extend_primer_transform --- torchrl/envs/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 63891d66b6f..371ad70aa77 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4513,7 +4513,7 @@ def __init__( kwargs = CompositeSpec(kwargs) self.primers = kwargs if (random is not None) and ( - isinstance(default_value, (dict, float, Callable)) # and default_value != 0.0 + isinstance(default_value, (dict, Callable)) # and default_value != 0.0 ): raise ValueError( "Setting random to True and providing a default_value are incompatible." From 9e6c415716b8e400a0a5a8eb468d2dd023446547 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 17:50:57 +0200 Subject: [PATCH 07/20] fix test --- test/test_transforms.py | 12 +++--------- torchrl/envs/transforms/transforms.py | 7 +++---- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8a11e849e30..960fb393f6a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6406,17 +6406,11 @@ def test_trans_parallel_env_check(self): finally: env.close() - def test_trans_serial_env_check(self): - with pytest.raises(RuntimeError, match="The leading shape of the primer specs"): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), - ) - _ = env.observation_spec - + @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) + def test_trans_serial_env_check(self, spec_shape): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)), ) check_env_specs(env) assert "mykey" in env.reset().keys() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 371ad70aa77..6edd550c1cd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4512,9 +4512,7 @@ def __init__( if not isinstance(kwargs, CompositeSpec): kwargs = CompositeSpec(kwargs) self.primers = kwargs - if (random is not None) and ( - isinstance(default_value, (dict, Callable)) # and default_value != 0.0 - ): + if (random is not None) and isinstance(default_value, (dict, Callable)): raise ValueError( "Setting random to True and providing a default_value are incompatible." ) @@ -4594,7 +4592,8 @@ def transform_observation_spec( for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: try: - expanded_spec = self._try_expand_shape(spec) + # expanded_spec = self._try_expand_shape(spec) + expanded_spec = spec except AttributeError: pass if ( From 657bf9f308d6c3da6bd0dca8e17650f30b39aefa Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 18:20:37 +0200 Subject: [PATCH 08/20] fix test --- test/test_transforms.py | 12 +++++++++--- torchrl/envs/transforms/transforms.py | 8 +------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 960fb393f6a..8a11e849e30 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6406,11 +6406,17 @@ def test_trans_parallel_env_check(self): finally: env.close() - @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) - def test_trans_serial_env_check(self, spec_shape): + def test_trans_serial_env_check(self): + with pytest.raises(RuntimeError, match="The leading shape of the primer specs"): + env = TransformedEnv( + SerialEnv(2, ContinuousActionVecMockEnv), + TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), + ) + _ = env.observation_spec + env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)), + TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), ) check_env_specs(env) assert "mykey" in env.reset().keys() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6edd550c1cd..5cc80fc670e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4592,14 +4592,8 @@ def transform_observation_spec( for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: try: - # expanded_spec = self._try_expand_shape(spec) - expanded_spec = spec + expanded_spec = self._try_expand_shape(spec) except AttributeError: - pass - if ( - expanded_spec.shape[: len(observation_spec.shape)] - != observation_spec.shape - ): raise RuntimeError( f"The leading shape of the primer specs ({self.__class__}) should match the one of the " f"parent env. Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's " From 25fc14a9bc1474c5d15667a67c5731fa1973f871 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 18:35:11 +0200 Subject: [PATCH 09/20] fix test --- test/test_transforms.py | 12 +++--------- torchrl/envs/transforms/transforms.py | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8a11e849e30..960fb393f6a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6406,17 +6406,11 @@ def test_trans_parallel_env_check(self): finally: env.close() - def test_trans_serial_env_check(self): - with pytest.raises(RuntimeError, match="The leading shape of the primer specs"): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), - ) - _ = env.observation_spec - + @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) + def test_trans_serial_env_check(self, spec_shape): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)), ) check_env_specs(env) assert "mykey" in env.reset().keys() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5cc80fc670e..3bd352d87a2 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4604,7 +4604,7 @@ def transform_observation_spec( device = observation_spec.device except RuntimeError: device = self.device - observation_spec[key] = spec.to(device) + observation_spec[key] = self.primers[key] = spec.to(device) return observation_spec def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: From b97f72781f9399e7fd23b65ebe9889050d5095c6 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 18:43:10 +0200 Subject: [PATCH 10/20] __repr__ --- torchrl/envs/transforms/transforms.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3bd352d87a2..1f7495be337 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4617,7 +4617,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: def _batch_size(self): return self.parent.batch_size - def validate(self, value, spec): + def _validate_value_tensor(self, value, spec): # TODO: implement this return True @@ -4638,7 +4638,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if callable(value): value = value() if not self._validated: - self.validate(value, self.primers[key]) + self._validate_value_tensor(value, self.primers[key]) else: value = torch.full_like( spec.zero(), @@ -4652,7 +4652,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: spec.zero(), value, ) - self._validated = True + if not self._validated: + self._validated = True tensordict.set(key, value) return tensordict @@ -4710,7 +4711,12 @@ def _reset( def __repr__(self) -> str: class_name = self.__class__.__name__ - return f"{class_name}(primers={self.primers}, default_value={self.default_value}, random={self.random})" + default_value = ( + self.default_value + if isinstance(self.default_value, float) + else self.default_value.__class__.__name__ + ) + return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})" class PinMemoryTransform(Transform): From 79f87722d3e9e1d57803ca0fed48aeb6312298ff Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 9 Apr 2024 19:35:42 +0200 Subject: [PATCH 11/20] mv validated --- torchrl/envs/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1f7495be337..c7e15d753f3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4652,9 +4652,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: spec.zero(), value, ) - if not self._validated: - self._validated = True tensordict.set(key, value) + if not self._validated: + self._validated = True return tensordict def _step( From 8895b6212ae653726bafff2a383cca70f4efe62c Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 10 Apr 2024 13:06:58 +0200 Subject: [PATCH 12/20] fix tests --- test/test_transforms.py | 66 +++++++++++++++++++++++++++ torchrl/envs/transforms/transforms.py | 56 ++++++++++++----------- 2 files changed, 96 insertions(+), 26 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 960fb393f6a..fdc2c191e90 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6510,6 +6510,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): r1 = env.rollout(100, break_when_any_done=break_when_any_done) tensordict.tensordict.assert_allclose_td(r0, r1) + def test_callable_default_value(self): + def create_tensor(): + return torch.ones(3) + + env = TransformedEnv( + ContinuousActionVecMockEnv(), + TensorDictPrimer( + mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor + ), + ) + check_env_specs(env) + assert "mykey" in env.reset().keys() + assert ("next", "mykey") in env.rollout(3).keys(True) + + def test_dict_default_value(self): + + # Test with a dict of float default values + key1_spec = UnboundedContinuousTensorSpec([3]) + key2_spec = UnboundedContinuousTensorSpec([3]) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + TensorDictPrimer( + mykey1=key1_spec, + mykey2=key2_spec, + default_value={ + "mykey1": 1.0, + "mykey2": 2.0, + }, + ), + ) + check_env_specs(env) + reset_td = env.reset() + assert "mykey1" in reset_td.keys() + assert "mykey2" in reset_td.keys() + rollout_td = env.rollout(3) + assert ("next", "mykey1") in rollout_td.keys(True) + assert ("next", "mykey2") in rollout_td.keys(True) + assert (rollout_td.get(("next", "mykey1")) == 1.0).all() + assert (rollout_td.get(("next", "mykey2")) == 2.0).all() + + # Test with a dict of callable default values + key1_spec = UnboundedContinuousTensorSpec([3]) + key2_spec = DiscreteTensorSpec(3, dtype=torch.int64) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + TensorDictPrimer( + mykey1=key1_spec, + mykey2=key2_spec, + default_value={ + "mykey1": lambda: torch.ones(3), + "mykey2": lambda: torch.tensor(1, dtype=torch.int64), + }, + ), + ) + check_env_specs(env) + reset_td = env.reset() + assert "mykey1" in reset_td.keys() + assert "mykey2" in reset_td.keys() + rollout_td = env.rollout(3) + assert ("next", "mykey1") in rollout_td.keys(True) + assert ("next", "mykey2") in rollout_td.keys(True) + assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all + assert ( + rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64) + ).all + class TestTimeMaxPool(TransformBase): @pytest.mark.parametrize("T", [2, 4]) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index c7e15d753f3..4dd9c06233b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4518,13 +4518,15 @@ def __init__( ) self.random = random if isinstance(default_value, dict): - if len(default_value) != len(primers) and set(dict.keys()) != set( - primers.keys(True, True) + if len(default_value) != len(self.primers) and set(dict.keys()) != set( + self.primers.keys(True, True) ): raise ValueError( "If a default_value dictionary is provided, it must match the primers keys." ) - default_value = {key: default_value for key in primers.keys(True, True)} + default_value = { + key: default_value[key] for key in self.primers.keys(True, True) + } self.default_value = default_value self._validated = False self.reset_key = reset_key @@ -4618,7 +4620,20 @@ def _batch_size(self): return self.parent.batch_size def _validate_value_tensor(self, value, spec): - # TODO: implement this + if value.shape != spec.shape: + raise RuntimeError( + f"Value shape ({value.shape}) does not match the spec shape ({spec.shape})." + ) + if value.dtype != spec.dtype: + raise RuntimeError( + f"Value dtype ({value.dtype}) does not match the spec dtype ({spec.dtype})." + ) + if value.device != spec.device: + raise RuntimeError( + f"Value device ({value.device}) does not match the spec device ({spec.device})." + ) + if not spec.is_in(value): + raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).") return True def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -4635,19 +4650,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: if isinstance(self.default_value, dict): value = self.default_value[key] - if callable(value): - value = value() - if not self._validated: - self._validate_value_tensor(value, self.primers[key]) - else: - value = torch.full_like( - spec.zero(), - value, - ) else: value = self.default_value - if callable(value): - value = value() + if callable(value): + value = value() + if not self._validated: + self._validate_value_tensor(value, spec) + else: value = torch.full_like( spec.zero(), value, @@ -4689,24 +4698,19 @@ def _reset( else: if isinstance(self.default_value, dict): value = self.default_value[key] - if callable(value): - value = value() - else: - value = torch.full_like( - spec.zero(shape), - value, - ) else: value = self.default_value - if callable(value): - value = value() + if callable(value): + value = value() + if not self._validated: + self._validate_value_tensor(value, spec) + else: value = torch.full_like( spec.zero(shape), value, ) - prev_val = tensordict.get(key, 0.0) - value = torch.where(expand_as_right(_reset, value), value, prev_val) tensordict_reset.set(key, value) + self._validated = True return tensordict_reset def __repr__(self) -> str: From e3dcfb9a0a6f4500d9525897f4504b37197460cc Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 10 Apr 2024 13:17:22 +0200 Subject: [PATCH 13/20] minor fix --- torchrl/envs/transforms/transforms.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4dd9c06233b..071a0ac3826 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4709,6 +4709,10 @@ def _reset( spec.zero(shape), value, ) + prev_val = tensordict.get(key, 0.0) + value = torch.where( + expand_as_right(_reset, value), value, prev_val + ) tensordict_reset.set(key, value) self._validated = True return tensordict_reset From a1cb9a1a4cdc13f4dd295b550d92dfd7fbd23caa Mon Sep 17 00:00:00 2001 From: albert bou Date: Sun, 14 Apr 2024 17:20:54 +0200 Subject: [PATCH 14/20] minor fix --- torchrl/envs/transforms/transforms.py | 50 ++++++++++----------------- 1 file changed, 18 insertions(+), 32 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 071a0ac3826..bba734b4c27 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4518,15 +4518,19 @@ def __init__( ) self.random = random if isinstance(default_value, dict): - if len(default_value) != len(self.primers) and set(dict.keys()) != set( - self.primers.keys(True, True) - ): + primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)} + default_value_keys = {unravel_key(key) for key in default_value.keys()} + if primer_keys != default_value_keys: raise ValueError( "If a default_value dictionary is provided, it must match the primers keys." ) default_value = { key: default_value[key] for key in self.primers.keys(True, True) } + else: + default_value = { + key: default_value for key in self.primers.keys(True, True) + } self.default_value = default_value self._validated = False self.reset_key = reset_key @@ -4620,18 +4624,6 @@ def _batch_size(self): return self.parent.batch_size def _validate_value_tensor(self, value, spec): - if value.shape != spec.shape: - raise RuntimeError( - f"Value shape ({value.shape}) does not match the spec shape ({spec.shape})." - ) - if value.dtype != spec.dtype: - raise RuntimeError( - f"Value dtype ({value.dtype}) does not match the spec dtype ({spec.dtype})." - ) - if value.device != spec.device: - raise RuntimeError( - f"Value device ({value.device}) does not match the spec device ({spec.device})." - ) if not spec.is_in(value): raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).") return True @@ -4648,19 +4640,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.random: value = spec.rand() else: - if isinstance(self.default_value, dict): - value = self.default_value[key] - else: - value = self.default_value + value = self.default_value[key] if callable(value): value = value() if not self._validated: self._validate_value_tensor(value, spec) else: - value = torch.full_like( - spec.zero(), + value = torch.full( + spec.shape, value, ) + tensordict.set(key, value) if not self._validated: self._validated = True @@ -4696,17 +4686,14 @@ def _reset( if self.random: value = spec.rand(shape) else: - if isinstance(self.default_value, dict): - value = self.default_value[key] - else: - value = self.default_value + value = self.default_value[key] if callable(value): value = value() if not self._validated: self._validate_value_tensor(value, spec) else: - value = torch.full_like( - spec.zero(shape), + value = torch.full( + spec.shape, value, ) prev_val = tensordict.get(key, 0.0) @@ -4719,11 +4706,10 @@ def _reset( def __repr__(self) -> str: class_name = self.__class__.__name__ - default_value = ( - self.default_value - if isinstance(self.default_value, float) - else self.default_value.__class__.__name__ - ) + default_value = { + key: value if isinstance(value, float) else "Callable" + for key, value in self.default_value.items() + } return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})" From 5a86c256672debd283d4f11028463161afa59ed1 Mon Sep 17 00:00:00 2001 From: albert bou Date: Sun, 14 Apr 2024 17:46:45 +0200 Subject: [PATCH 15/20] suggested changes --- torchrl/envs/transforms/transforms.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index bba734b4c27..a5b45360bd4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4497,7 +4497,7 @@ def __init__( default_value: float | Callable | Dict[NestedKey, float] - | Dict[NestedKey, Callable] = 0.0, + | Dict[NestedKey, Callable] = None, reset_key: NestedKey | None = None, **kwargs, ): @@ -4512,10 +4512,13 @@ def __init__( if not isinstance(kwargs, CompositeSpec): kwargs = CompositeSpec(kwargs) self.primers = kwargs - if (random is not None) and isinstance(default_value, (dict, Callable)): + if random and default_value: raise ValueError( "Setting random to True and providing a default_value are incompatible." ) + default_value = ( + default_value or 0.0 + ) # if not random and no default value, use 0.0 self.random = random if isinstance(default_value, dict): primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)} @@ -4649,6 +4652,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: value = torch.full( spec.shape, value, + device=spec.device, ) tensordict.set(key, value) @@ -4695,6 +4699,7 @@ def _reset( value = torch.full( spec.shape, value, + device=spec.device, ) prev_val = tensordict.get(key, 0.0) value = torch.where( From a64a97e539b124da2165f46ebd848e410e45da69 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 16 Apr 2024 11:04:03 +0200 Subject: [PATCH 16/20] feedback changes --- torchrl/envs/transforms/transforms.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a5b45360bd4..625edf9e14d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4521,7 +4521,7 @@ def __init__( ) # if not random and no default value, use 0.0 self.random = random if isinstance(default_value, dict): - primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)} + primer_keys = self.primers.keys(True, True) default_value_keys = {unravel_key(key) for key in default_value.keys()} if primer_keys != default_value_keys: raise ValueError( @@ -4588,7 +4588,7 @@ def to(self, *args, **kwargs): self.primers = self.primers.to(device) return super().to(*args, **kwargs) - def _try_expand_shape(self, spec): + def _expand_shape(self, spec): return spec.expand((*self.parent.batch_size, *spec.shape)) def transform_observation_spec( @@ -4600,14 +4600,7 @@ def transform_observation_spec( ) for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - try: - expanded_spec = self._try_expand_shape(spec) - except AttributeError: - raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the " - f"parent env. Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's " - f"shape is {expanded_spec.shape}." - ) + expanded_spec = self._expand_shape(spec) spec = expanded_spec try: device = observation_spec.device From 311f40b026e14e4a47e2413fbcc0162ed3d9f65a Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 16 Apr 2024 11:38:24 +0200 Subject: [PATCH 17/20] feedback changes --- torchrl/envs/transforms/transforms.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 625edf9e14d..5e73c81c7f5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4521,15 +4521,11 @@ def __init__( ) # if not random and no default value, use 0.0 self.random = random if isinstance(default_value, dict): - primer_keys = self.primers.keys(True, True) - default_value_keys = {unravel_key(key) for key in default_value.keys()} - if primer_keys != default_value_keys: + default_value = TensorDict(default_value, []).to_dict() + if set(default_value.keys()) != set(self.primers.keys(True, True)): raise ValueError( "If a default_value dictionary is provided, it must match the primers keys." ) - default_value = { - key: default_value[key] for key in self.primers.keys(True, True) - } else: default_value = { key: default_value for key in self.primers.keys(True, True) From 3219150c26687b4548c383412e6016429b036180 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 16 Apr 2024 11:56:02 +0200 Subject: [PATCH 18/20] feedback changes --- torchrl/envs/transforms/transforms.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5e73c81c7f5..25ade081f6b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4521,8 +4521,13 @@ def __init__( ) # if not random and no default value, use 0.0 self.random = random if isinstance(default_value, dict): - default_value = TensorDict(default_value, []).to_dict() - if set(default_value.keys()) != set(self.primers.keys(True, True)): + default_value = TensorDict(default_value, []) + default_value_keys = default_value.keys( + True, + True, + is_leaf=lambda x: issubclass(x, (NonTensorData, torch.Tensor)), + ) + if set(default_value_keys) != set(self.primers.keys(True, True)): raise ValueError( "If a default_value dictionary is provided, it must match the primers keys." ) From 891bd4243179cd05f79a6870be1d11e38202632a Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 16 Apr 2024 16:39:28 +0200 Subject: [PATCH 19/20] fix tests --- test/test_transforms.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fdc2c191e90..4e31be00d39 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6856,18 +6856,13 @@ def make_env(): finally: env.close() - def test_trans_serial_env_check(self): + @pytest.mark.parametrize("shape", [(), (2,)]) + def test_trans_serial_env_check(self, shape): state_dim = 7 action_dim = 7 - with pytest.raises(RuntimeError, match="The leading shape of the primer"): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=()), - ) - check_env_specs(env) env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), + gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=shape), ) try: check_env_specs(env) From 1625514ec5960ae168e2093f32b3b17d54403a93 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 16:03:21 +0100 Subject: [PATCH 20/20] amend --- torchrl/objectives/value/functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index e93386b34ef..082c0ae9e9a 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import functools import math import warnings