Skip to content

Commit

Permalink
[Feature] Custom conversion tool for gym specs
Browse files Browse the repository at this point in the history
ghstack-source-id: d38bb02f15267a9b1637b3ed25fb44ef013e2456
Pull Request resolved: #2726
  • Loading branch information
vmoens committed Jan 30, 2025
1 parent 5fd5092 commit dbc8e2e
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 119 deletions.
3 changes: 2 additions & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ in the relevant functions:
>>> print(env2._env.env.env)
<gym.envs.classic_control.pendulum.PendulumEnv at 0x1629916a0>

We can see that the two libraries modify the value returned by :func:`~.gym.gym_backend()`
We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()`
which can be further used to indicate which library needs to be used for
the current computation. :class:`~.gym.set_gym_backend` is also a decorator:
we can use it to tell to a specific function what gym backend needs to be used
Expand Down Expand Up @@ -1188,3 +1188,4 @@ the following function will return ``1`` when queried:
VmasWrapper
gym_backend
set_gym_backend
register_gym_spec_conversion
35 changes: 35 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
Composite,
MultiCategorical,
MultiOneHot,
NonTensor,
OneHot,
ReplayBuffer,
ReplayBufferEnsemble,
Expand Down Expand Up @@ -119,6 +120,7 @@
GymWrapper,
MOGymEnv,
MOGymWrapper,
register_gym_spec_conversion,
set_gym_backend,
)
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
Expand Down Expand Up @@ -337,6 +339,39 @@ def test_gym_spec_cast(self, categorical):
assert spec == recon
assert recon.shape == spec.shape

def test_gym_new_spec_reg(self):
Space = gym_backend("spaces").Space

class MySpaceParent(Space):
...

s_parent = MySpaceParent()

class MySpaceChild(MySpaceParent):
...

# We intentionally register first the child then the parent
@register_gym_spec_conversion(MySpaceChild)
def convert_myspace_child(spec, **kwargs):
return NonTensor((), example_data="child")

@register_gym_spec_conversion(MySpaceParent)
def convert_myspace_parent(spec, **kwargs):
return NonTensor((), example_data="parent")

s_child = MySpaceChild()
assert _gym_to_torchrl_spec_transform(s_parent).example_data == "parent"
assert _gym_to_torchrl_spec_transform(s_child).example_data == "child"

class NoConversionSpace(Space):
...

s_no_conv = NoConversionSpace()
with pytest.raises(
KeyError, match="No conversion tool could be found with the gym space"
):
_gym_to_torchrl_spec_transform(s_no_conv)

@pytest.mark.parametrize("order", ["tuple_seq"])
@implement_for("gym")
def test_gym_spec_cast_tuple_sequential(self, order):
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OpenSpielWrapper,
PettingZooEnv,
PettingZooWrapper,
register_gym_spec_conversion,
RoboHiveEnv,
set_gym_backend,
SMACv2Env,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
GymWrapper,
MOGymEnv,
MOGymWrapper,
register_gym_spec_conversion,
set_gym_backend,
)
from .habitat import HabitatEnv
Expand Down
Loading

0 comments on commit dbc8e2e

Please sign in to comment.