diff --git a/.github/workflows/tests_linters.yml b/.github/workflows/tests_linters.yml index 258378b7a..fe6b148cb 100644 --- a/.github/workflows/tests_linters.yml +++ b/.github/workflows/tests_linters.yml @@ -1,34 +1,51 @@ name: Tests and Linters ๐Ÿงช -on: [ push, pull_request ] +on: [ pull_request ] jobs: tests-and-linters: name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" runs-on: "${{ matrix.os }}" + timeout-minutes: 20 strategy: matrix: - python-version: ["3.8", "3.9"] + python-version: ["3.10", "3.11", "3.12"] os: [ubuntu-latest] steps: - name: Install dependencies for viewer test run: sudo apt-get update && sudo apt-get install -y xvfb + - name: Checkout jumanji ๐Ÿ - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "0.4.26" + enable-cache: true + cache-dependency-glob: "requirements/requirements**.txt" # invalidate cache when requirements file changes + + - uses: actions/setup-python@v5 with: python-version: "${{ matrix.python-version }}" + - name: Install python dependencies ๐Ÿ”ง - run: pip install .[dev,train] + run: uv pip install .[dev,train] + env: + UV_SYSTEM_PYTHON: 1 + - name: Run linters ๐Ÿ–Œ๏ธ run: pre-commit run --all-files --verbose + - name: Run tests ๐Ÿงช run: pytest -n 2 --cov=jumanji --cov-report=term-missing --junit-xml=test-results.xml -vv jumanji + - name: Run coverage run: | coverage html --directory=coverage_html_report coverage report --fail-under=0.97 + - name: Test build docs ๐Ÿ“– run: mkdocs build --verbose --site-dir docs_public diff --git a/.gitignore b/.gitignore index 7a4e033c2..ec09148f1 100644 --- a/.gitignore +++ b/.gitignore @@ -150,7 +150,7 @@ cython_debug/ # MacBook Finder .DS_Store -3.8/ +3.10/ jumanji_env/ **/outputs/ *.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 928835b3d..f62452ace 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: name: "Trailing whitespace fixer" - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 7.1.1 hooks: - id: flake8 name: "Linter" diff --git a/README.md b/README.md index 5eb7be203..f18ca16b0 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ Alternatively, you can install the latest development version directly from GitH pip install git+https://github.com/instadeepai/jumanji.git ``` -Jumanji has been tested on Python 3.8 and 3.9. +Jumanji has been tested on Python 3.10, 3.11 and 3.12. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the [official installation guide](https://github.com/google/jax#installation)). diff --git a/jumanji/env.py b/jumanji/env.py index 48035a992..9674960c8 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -110,7 +110,7 @@ def reward_spec(self) -> specs.Array: @cached_property def discount_spec(self) -> specs.BoundedArray: - """Returns the discount spec. By default, this is assumed to be a single float between 0 and 1. + """Returns the discount spec. By default, this is assumed to be a float between 0 and 1. Returns: discount_spec: a `specs.BoundedArray` spec. diff --git a/jumanji/environments/logic/game_2048/viewer.py b/jumanji/environments/logic/game_2048/viewer.py index 819d1c251..b64b48b20 100644 --- a/jumanji/environments/logic/game_2048/viewer.py +++ b/jumanji/environments/logic/game_2048/viewer.py @@ -123,7 +123,7 @@ def make_frame(state_index: int) -> None: return self._animation def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - """This function returns a `Matplotlib` figure and axes object for displaying the 2048 game board. + """This function returns a `Matplotlib` figure and axes for displaying the 2048 game board. Returns: A tuple containing the figure and axes objects. diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index b5e65a3e5..32de81019 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -296,7 +296,7 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> animation.FuncAnimation: - """Creates an animated gif of the `GraphColoring` environment based on the sequence of game states. + """Creates an animated gif of the `GraphColoring` environment based on a sequence of states. Args: states: is a list of `State` objects representing the sequence of game states. diff --git a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py index 6596a323d..7fa905fcf 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/viewer.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/viewer.py @@ -71,7 +71,7 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Creates an animated gif of the sliding tiles puzzle game based on the sequence of game states. + """Creates an animated gif of the sliding tiles puzzle game based on a sequence of states. Args: states: is a list of `State` objects representing the sequence of game states. @@ -101,7 +101,7 @@ def make_frame(state_index: int) -> None: return self._animation def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - """This function returns a `Matplotlib` figure and axes object for displaying the game puzzle. + """This function returns a `Matplotlib` figure and axes for displaying the puzzle. Returns: A tuple containing the figure and axes objects. diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 5b2b2c7cf..3506fa0b0 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -397,9 +397,9 @@ def close(self) -> None: def _make_observation_and_extras( self, state: State ) -> Tuple[State, Observation, Dict]: - """Computes the observation and the environment metrics to include in `timestep.extras`. Also - updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is obtained - by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. + """Computes the observation and the environment metrics to include in `timestep.extras`. + Also updates the `action_mask` and `sorted_ems_indexes` in the state. The observation is + obtained by selecting a subset of all EMSs, namely the `obs_num_ems` largest ones. """ obs_ems, obs_ems_mask, sorted_ems_indexes = self._get_set_of_largest_ems( state.ems, state.ems_mask diff --git a/jumanji/environments/packing/flat_pack/generator.py b/jumanji/environments/packing/flat_pack/generator.py index 7ea8495d5..412c4d9e1 100644 --- a/jumanji/environments/packing/flat_pack/generator.py +++ b/jumanji/environments/packing/flat_pack/generator.py @@ -28,8 +28,8 @@ class InstanceGenerator(abc.ABC): - """Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible - for generating a problem instance when the environment is reset. + """Base class for generators for the flat_pack environment. An `InstanceGenerator` is + responsible for generating a problem instance when the environment is reset. """ def __init__( diff --git a/jumanji/environments/routing/mmst/generator.py b/jumanji/environments/routing/mmst/generator.py index c71d01054..be0e4685a 100644 --- a/jumanji/environments/routing/mmst/generator.py +++ b/jumanji/environments/routing/mmst/generator.py @@ -88,7 +88,7 @@ def __call__(self, key: chex.PRNGKey) -> State: class SplitRandomGenerator(Generator): - """Generates a random environments that is solvable by spliting the graph into multiple sub graphs. + """Generates a random environments that is solvable by spliting the graph into sub graphs. Returns a graph and with a desired number of edges and nodes to connect per agent. """ diff --git a/jumanji/environments/routing/multi_cvrp/viewer.py b/jumanji/environments/routing/multi_cvrp/viewer.py index 1299cb22d..705d545a0 100644 --- a/jumanji/environments/routing/multi_cvrp/viewer.py +++ b/jumanji/environments/routing/multi_cvrp/viewer.py @@ -210,10 +210,10 @@ def _draw_route(self, ax: plt.Axes, coords: chex.Array, col_id: int) -> None: ax.scatter(x, y, s=self.NODE_SIZE, color=self._cmap(col_id)) def _add_tour(self, ax: plt.Axes, state: State) -> None: - """Add the customers and the depot to the plot, and draw each route in the tour in a different - colour. The tour is the entire trajectory between the visited customers and a route is a - trajectory either starting and ending at the depot or starting at the depot and ending at - the current city.""" + """Add the customers and the depot to the plot, and draw each route in the tour in a + different colour. The tour is the entire trajectory between the visited customers and a + route is a trajectory either starting and ending at the depot or starting at the depot + and ending at the current city.""" x_coords, y_coords = ( state.nodes.coordinates[:, 0] / self._map_max, state.nodes.coordinates[:, 1] / self._map_max, diff --git a/jumanji/environments/routing/robot_warehouse/conftest.py b/jumanji/environments/routing/robot_warehouse/conftest.py index 95ed58271..68d815705 100644 --- a/jumanji/environments/routing/robot_warehouse/conftest.py +++ b/jumanji/environments/routing/robot_warehouse/conftest.py @@ -31,8 +31,8 @@ @pytest.fixture(scope="module") def robot_warehouse_env() -> RobotWarehouse: - """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns, - a column height of 2, sensor range of 1 and a request queue size of 4.""" + """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf + columns, a column height of 2, sensor range of 1 and a request queue size of 4.""" generator = RandomGenerator( shelf_rows=1, shelf_columns=3, diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index eb9c2c578..8ab107bc4 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -362,7 +362,9 @@ def observation_spec(self) -> specs.Spec[Observation]: @cached_property def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + """Returns the action spec. 5 actions: + [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,). """ diff --git a/jumanji/specs.py b/jumanji/specs.py index 6dc40237b..6cfacd546 100644 --- a/jumanji/specs.py +++ b/jumanji/specs.py @@ -44,9 +44,9 @@ class Spec(abc.ABC, Generic[T]): - """Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for nested - specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from the - `dm_env` object.""" + """Adapted from `dm_env.spec.Array`. This is an augmentation of the `Array` spec to allow for + nested specs. `self.name`, `self.generate_value` and `self.validate` methods are adapted from + the `dm_env` object.""" def __init__( self, @@ -139,7 +139,7 @@ def __getitem__(self, item: str) -> "Spec": class Array(Spec[chex.Array]): - """Describes a jax array spec. This is adapted from `dm_env.specs.Array` to suit Jax environments. + """Describes a jax array spec. This is adapted from `dm_env.specs.Array` for Jax environments. An `Array` spec allows an API to describe the arrays that it accepts or returns, before that array exists. diff --git a/jumanji/specs_test.py b/jumanji/specs_test.py index 74e95b512..09b9f48b1 100644 --- a/jumanji/specs_test.py +++ b/jumanji/specs_test.py @@ -589,7 +589,7 @@ def test_array(self) -> None: converted_spec: dm_env.specs.Array = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -602,7 +602,7 @@ def test_bounded_array(self) -> None: converted_spec: dm_env.specs.BoundedArray = specs.jumanji_specs_to_dm_env_specs( jumanji_spec ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -615,7 +615,7 @@ def test_discrete_array(self) -> None: converted_spec: dm_env.specs.DiscreteArray = ( specs.jumanji_specs_to_dm_env_specs(jumanji_spec) ) - assert type(converted_spec) == type(dm_env_spec) + assert type(converted_spec) is type(dm_env_spec) assert converted_spec.shape == dm_env_spec.shape assert converted_spec.dtype == dm_env_spec.dtype assert converted_spec.name == dm_env_spec.name @@ -675,7 +675,7 @@ def test_array(self) -> None: jumanji_spec = specs.Array((1, 2), jnp.int32) gym_space = gym.spaces.Box(-np.inf, np.inf, (1, 2), jnp.int32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert_trees_all_equal(converted_spec.low, gym_space.low) assert_trees_all_equal(converted_spec.high, gym_space.high) assert converted_spec.shape == gym_space.shape @@ -687,7 +687,7 @@ def test_bounded_array(self) -> None: ) gym_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, 2), dtype=jnp.float32) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert_trees_all_equal(converted_spec.low, gym_space.low) @@ -697,7 +697,7 @@ def test_discrete_array(self) -> None: jumanji_spec = specs.DiscreteArray(num_values=5, dtype=jnp.int32) gym_space = gym.spaces.Discrete(n=5) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert converted_spec.n == gym_space.n @@ -708,7 +708,7 @@ def test_multi_discrete_array(self) -> None: ) gym_space = gym.spaces.MultiDiscrete(nvec=[5, 6]) converted_spec = specs.jumanji_specs_to_gym_spaces(jumanji_spec) - assert type(converted_spec) == type(gym_space) + assert type(converted_spec) is type(gym_space) assert converted_spec.shape == gym_space.shape assert converted_spec.dtype == gym_space.dtype assert jnp.array_equal(converted_spec.nvec, gym_space.nvec) diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py index 6e2e336f6..62187d709 100644 --- a/jumanji/training/networks/graph_coloring/actor_critic.py +++ b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -219,7 +219,7 @@ def __call__(self, observation: Observation) -> chex.Array: mlp_units=self.transformer_mlp_units, w_init_scale=2 / self.num_transformer_layers, model_size=self.model_size, - name=f"cross_attention_color_node_block_{block_id+1}", + name=f"cross_attention_color_node_block_{block_id + 1}", )(color_embeddings, current_node_embeddings, current_node_embeddings) return new_embedding diff --git a/jumanji/training/networks/mmst/actor_critic.py b/jumanji/training/networks/mmst/actor_critic.py index 45e776b4c..18b21dec3 100644 --- a/jumanji/training/networks/mmst/actor_critic.py +++ b/jumanji/training/networks/mmst/actor_critic.py @@ -207,7 +207,7 @@ def __call__(self, observation: Observation) -> chex.Array: mlp_units=self.transformer_mlp_units, w_init_scale=2 / self.num_transformer_layers, model_size=self.model_size, - name=f"cross_attention_agent_node_block_{block_id+1}", + name=f"cross_attention_agent_node_block_{block_id + 1}", )(agents_embeddings, current_node_embeddings, current_node_embeddings) return new_embedding diff --git a/jumanji/training/networks/tsp/actor_critic.py b/jumanji/training/networks/tsp/actor_critic.py index cff891c5e..1d720743b 100644 --- a/jumanji/training/networks/tsp/actor_critic.py +++ b/jumanji/training/networks/tsp/actor_critic.py @@ -72,8 +72,8 @@ def __init__( transformer_mlp_units: Sequence[int], name: Optional[str] = None, ): - """Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks of self - attention. + """Linear embedding of all cities' coordinates followed by `transformer_num_blocks` blocks + of self attention. """ super().__init__(name=name) self.transformer_num_blocks = transformer_num_blocks diff --git a/pyproject.toml b/pyproject.toml index 79ed22fe2..636dcae26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,51 @@ -[tool.isort] -profile = "black" +[build-system] +requires=["setuptools>=62.6"] +build-backend="setuptools.build_meta" + +[project] +name="jumanji" +authors=[{name="InstaDeep Ltd", email="clement.bonnet16@gmail.com"}] +dynamic=["version", "dependencies", "optional-dependencies"] +license={file="LICENSE"} +description="A diverse suite of scalable reinforcement learning environments in JAX" +readme ="README.md" +requires-python=">=3.10" +keywords=["reinforcement-learning", "python", "jax"] +classifiers=[ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: Apache Software License", +] + +[tool.setuptools.packages.find] +include=["jumanji*"] + +[tool.setuptools.package-data] +"jumanji" = ["py.typed"] + +[tool.setuptools.dynamic] +version={attr="jumanji.version.__version__"} +dependencies={file="requirements/requirements.txt"} +optional-dependencies.dev={file=["requirements/requirements-dev.txt"]} +optional-dependencies.train={file=["requirements/requirements-train.txt"]} + + +[project.urls] +"Homepage"="https://github.com/instadeep/jumanji" +"Bug Tracker"="https://github.com/instadeep/jumanji/issues" +"Documentation"="https://instadeepai.github.io/jumanji" [tool.mypy] -python_version = 3.8 +python_version = "3.10" namespace_packages = true incremental = false cache_dir = "" @@ -47,3 +90,6 @@ module = [ "PIL.*", ] ignore_missing_imports = true + +[tool.isort] +profile = "black" diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index e2a0aadd2..03d70be87 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,7 +1,6 @@ black==22.3.0 coverage -flake8==3.9.2 -importlib-metadata<5.0 +flake8 isort==5.11.5 livereload mkdocs==1.2.3 @@ -22,9 +21,5 @@ pytest-cov pytest-mock pytest-parallel pytest-xdist -pytype scipy>=1.7.3 testfixtures -types-Pillow -types-requests<1.27 -types-setuptools diff --git a/setup.cfg b/setup.cfg index d032b15c6..47a19d216 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,11 +20,21 @@ docstring-convention = google per-file-ignores = __init__.py:F401 ignore = - A002 # Argument shadowing a Python builtin. - A003 # Class attribute shadowing a Python builtin. - D107 # Do not require docstrings for __init__. - E266 # Do not require block comments to only have a single leading #. - E731 # Do not assign a lambda expression, use a def. - W503 # Line break before binary operator (not compatible with black). - B017 # assertRaises(Exception): or pytest.raises(Exception) should be considered evil. - E203 # black and flake8 disagree on whitespace before ':'. +# Argument shadowing a Python builtin. + A002 +# Class attribute shadowing a Python builtin. + A003 +# Module shadowing a Python builtin. + A005 +# Do not require docstrings for __init__. + D107 +# Do not require block comments to only have a single leading #. + E266 +# Do not assign a lambda expression, use a def. + E731 +# Line break before binary operator (not compatible with black). + W503 +# assertRaises(Exception): or pytest.raises(Exception) should be considered evil. + B017 +# black and flake8 disagree on whitespace before ':'. + E203 diff --git a/setup.py b/setup.py deleted file mode 100644 index d6f0c038c..000000000 --- a/setup.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import List - -import setuptools -from setuptools import setup - - -def _parse_requirements(path: str) -> List[str]: - """Returns content of given requirements file.""" - with open(os.path.join(path)) as f: - return [ - line.rstrip() for line in f if not (line.isspace() or line.startswith("#")) - ] - - -def _get_version() -> str: - """Grabs the package version from jumanji/version.py.""" - dict_ = {} - with open("jumanji/version.py") as f: - exec(f.read(), dict_) - return dict_["__version__"] - - -setup( - name="jumanji", - version=_get_version(), - author="InstaDeep", - author_email="clement.bonnet16@gmail.com", - description="A diverse suite of scalable reinforcement learning environments in JAX", - license="Apache 2.0", - url="https://github.com/instadeepai/jumanji/", - long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown", - keywords="reinforcement-learning python jax", - packages=setuptools.find_packages(), - python_requires=">=3.8", - install_requires=_parse_requirements("requirements/requirements.txt"), - extras_require={ - "dev": _parse_requirements("requirements/requirements-dev.txt"), - "train": _parse_requirements("requirements/requirements-train.txt"), - }, - package_data={"jumanji": ["py.typed"]}, - classifiers=[ - "Development Status :: 4 - Beta", - "Environment :: Console", - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", - "License :: OSI Approved :: Apache Software License", - ], - zip_safe=False, - include_package_data=True, -)