Skip to content

Commit

Permalink
Test last pyRDDLGym-rl release + add pyRDDLGym-gurobi in deps
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed Nov 8, 2024
1 parent d63b1db commit 16d71eb
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 95 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -564,9 +564,9 @@ jobs:
python_version=${{ matrix.python-version }}
wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*win*.whl)
if [ "$python_version" = "3.12" ]; then
pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" "git+https://github.com/pyrddlgym-project/pyRDDLGym-gurobi"
pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17"
else
pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna "git+https://github.com/pyrddlgym-project/pyRDDLGym-gurobi"
pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna
fi
- name: Test with pytest
Expand Down
5 changes: 3 additions & 2 deletions notebooks/16_rddl_tuto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@
"\n",
"\n",
"<div class=\"alert alert-block alert-warning\"><b>Note: </b>\n",
"The solver needs a real license for Gurobi, as the free license available when installing gurobipy from PyPi is not sufficient to solve this domain.\n",
"To solve reasonably sized problems, the solver needs a real license for Gurobi, as the free license available when installing gurobipy from PyPI is not sufficient to solve this domain with a longer horizon. Here we limit the `rollout_horizon` so that it can run with the free license, because optimization variables are created for each timestep.\n",
"</div>"
]
},
Expand All @@ -557,7 +557,8 @@
"assert RDDLGurobiSolver.check_domain(domain_factory_gurobi_agent())\n",
"\n",
"with RDDLGurobiSolver(\n",
" domain_factory=domain_factory_gurobi_agent, rollout_horizon=10\n",
" domain_factory=domain_factory_gurobi_agent,\n",
" rollout_horizon=2, # increase the rollout_horizon with a real license\n",
") as solver:\n",
" solver.solve()\n",
" rollout(\n",
Expand Down
99 changes: 90 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 13 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,16 @@ pygrib = [
{ version = "<=2.1.5", platform = "linux", optional = true },
{ version = ">=2.1.5", platform = "darwin", optional = true },
]
pyRDDLGym = { version = ">=2.0, <2.1", optional = true }
pyRDDLGym-rl = { version = ">=0.1", optional = true }
pyRDDLGym = [
{ version = ">=2.1", python = "<3.12", optional = true },
{ version = "==2.0, <2.1", python = ">=3.12", optional = true },
]
pyRDDLGym-rl = [
{ version = ">=0.2", python = "<3.12", optional = true },
{ version = ">=0.1, <0.2", python = ">=3.12", optional = true },
]
pyRDDLGym-jax = { version = ">=0.3", optional = true }
pyRDDLGym-gurobi = { version = ">=0.2", optional = true }
rddlrepository = {version = ">=2.0", optional = true }

[tool.poetry.extras]
Expand Down Expand Up @@ -103,7 +110,8 @@ solvers = [
"up-enhsp",
"up-pyperplan",
"scipy",
"pyRDDLGym-jax"
"pyRDDLGym-jax",
"pyRDDLGym-gurobi"
]
all = [
"gymnasium",
Expand All @@ -125,7 +133,8 @@ all = [
"pyRDDLGym",
"pyRDDLGym-rl",
"rddlrepository",
"pyRDDLGym-jax"
"pyRDDLGym-jax",
"pyRDDLGym-gurobi"
]

[tool.poetry.plugins."skdecide.domains"]
Expand Down
112 changes: 44 additions & 68 deletions skdecide/hub/solver/rddl/rddl.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
from collections.abc import Callable
from typing import Any, Optional

from pyRDDLGym_gurobi.core.planner import (
GurobiOnlineController,
GurobiPlan,
GurobiStraightLinePlan,
)
from pyRDDLGym_jax.core.planner import (
JaxBackpropPlanner,
JaxOfflineController,
JaxOnlineController,
load_config,
)
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

from skdecide import Solver
from skdecide.builders.solver import FromInitialState, Policies
from skdecide.hub.domain.rddl import RDDLDomain

try:
from pyRDDLGym_gurobi.core.planner import (
GurobiOnlineController,
GurobiPlan,
GurobiStraightLinePlan,
)
except ImportError:
pyrddlgym_gurobi_available = False
else:
pyrddlgym_gurobi_available = True


class D(RDDLDomain):
pass
Expand Down Expand Up @@ -60,58 +52,42 @@ def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
return True


if pyrddlgym_gurobi_available:

class D(RDDLDomain):
pass

class RDDLGurobiSolver(Solver, Policies, FromInitialState):
T_domain = D

def __init__(
self,
domain_factory: Callable[[], RDDLDomain],
plan: Optional[GurobiPlan] = None,
rollout_horizon=5,
model_params: Optional[dict[str, Any]] = None,
):
Solver.__init__(self, domain_factory=domain_factory)
self._domain = domain_factory()
self.rollout_horizon = rollout_horizon
if plan is None:
self.plan = GurobiStraightLinePlan()
else:
self.plan = plan
if model_params is None:
self.model_params = {"NonConvex": 2, "OutputFlag": 0}
else:
self.model_params = model_params

@classmethod
def _check_domain_additional(cls, domain: D) -> bool:
return hasattr(domain, "rddl_gym_env")

def _solve(self, from_memory: Optional[D.T_state] = None) -> None:
self.controller = GurobiOnlineController(
rddl=self._domain.rddl_gym_env.model,
plan=self.plan,
rollout_horizon=self.rollout_horizon,
model_params=self.model_params,
)

def _sample_action(self, observation: D.T_observation) -> D.T_event:
return self.controller.sample_action(observation)

def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
return True

else:

class RDDLGurobiSolver(Solver, Policies, FromInitialState):
T_domain = D

def __init__(self, domain_factory: Callable[[], RDDLDomain], rollout_horizon=5):
raise RuntimeError(
"You need pyRDDLGym-gurobi installed for this solver. "
"See https://github.com/pyrddlgym-project/pyRDDLGym-gurobi for more information."
)
class RDDLGurobiSolver(Solver, Policies, FromInitialState):
T_domain = D

def __init__(
self,
domain_factory: Callable[[], RDDLDomain],
plan: Optional[GurobiPlan] = None,
rollout_horizon=5,
model_params: Optional[dict[str, Any]] = None,
):
Solver.__init__(self, domain_factory=domain_factory)
self._domain = domain_factory()
self.rollout_horizon = rollout_horizon
if plan is None:
self.plan = GurobiStraightLinePlan()
else:
self.plan = plan
if model_params is None:
self.model_params = {"NonConvex": 2, "OutputFlag": 0}
else:
self.model_params = model_params

@classmethod
def _check_domain_additional(cls, domain: D) -> bool:
return hasattr(domain, "rddl_gym_env")

def _solve(self, from_memory: Optional[D.T_state] = None) -> None:
self.controller = GurobiOnlineController(
rddl=self._domain.rddl_gym_env.model,
plan=self.plan,
rollout_horizon=self.rollout_horizon,
model_params=self.model_params,
)

def _sample_action(self, observation: D.T_observation) -> D.T_event:
return self.controller.sample_action(observation)

def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
return True
11 changes: 1 addition & 10 deletions tests/solvers/python/test_pyrddlgym_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
import shutil
from urllib.request import urlcleanup, urlretrieve

import pytest
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

from skdecide.hub.domain.rddl import RDDLDomain
from skdecide.hub.solver.rddl.rddl import (
RDDLGurobiSolver,
RDDLJaxSolver,
pyrddlgym_gurobi_available,
)
from skdecide.hub.solver.rddl.rddl import RDDLGurobiSolver, RDDLJaxSolver
from skdecide.utils import load_registered_solver, rollout


Expand Down Expand Up @@ -44,10 +39,6 @@ def test_pyrddlgymdomain_jax():
rollout(domain_factory(), solver, max_steps=100, render=False, verbose=False)


@pytest.mark.skipif(
not pyrddlgym_gurobi_available,
reason="You need to install pyRDDL_gurobi for this solver",
)
def test_pyrddlgymdomain_gurobi():
# domain factory (with proper backend and vectorized flag)
domain_factory = lambda: RDDLDomain(
Expand Down

0 comments on commit 16d71eb

Please sign in to comment.