diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index a4d980bac7304..5eb83eb5f46d0 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -1,8 +1,8 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() @@ -84,8 +84,9 @@ def _value(self, obs): return self.model.value_function()[0] -A3CTorchPolicy = build_torch_policy( +A3CTorchPolicy = build_policy_class( name="A3CTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, loss_fn=actor_critic_loss, stats_fn=loss_and_entropy_stats, diff --git a/rllib/agents/ars/ars_torch_policy.py b/rllib/agents/ars/ars_torch_policy.py index 809b435b6765a..7b7140a489325 100644 --- a/rllib/agents/ars/ars_torch_policy.py +++ b/rllib/agents/ars/ars_torch_policy.py @@ -4,10 +4,11 @@ import ray from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \ make_model_and_action_dist -from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.policy.policy_template import build_policy_class -ARSTorchPolicy = build_torch_policy( +ARSTorchPolicy = build_policy_class( name="ARSTorchPolicy", + framework="torch", loss_fn=None, get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG, before_init=before_init, diff --git a/rllib/agents/ddpg/ddpg_torch_model.py b/rllib/agents/ddpg/ddpg_torch_model.py index a24b949207a91..66d910ebf07f3 100644 --- a/rllib/agents/ddpg/ddpg_torch_model.py +++ b/rllib/agents/ddpg/ddpg_torch_model.py @@ -2,7 +2,7 @@ from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils.framework import try_import_torch, get_activation_fn +from ray.rllib.utils.framework import get_activation_fn, try_import_torch torch, nn = try_import_torch() diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 39123783a421d..79be4cce823c8 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -7,8 +7,8 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ PRIO_WEIGHTS from ray.rllib.models.torch.torch_action_dist import TorchDeterministic +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import huber_loss, l2_loss @@ -264,8 +264,9 @@ def setup_late_mixins(policy, obs_space, action_space, config): TargetNetworkMixin.__init__(policy) -DDPGTorchPolicy = build_torch_policy( +DDPGTorchPolicy = build_policy_class( name="DDPGTorchPolicy", + framework="torch", loss_fn=ddpg_actor_critic_loss, get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, stats_fn=build_ddpg_stats, diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 9b9d535d95b46..1ed468e1d883b 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -14,9 +14,9 @@ from ray.rllib.models.torch.torch_action_dist import (TorchCategorical, 
TorchDistributionWrapper) from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.framework import try_import_torch @@ -384,8 +384,9 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, return {"q_values": policy.q_values} -DQNTorchPolicy = build_torch_policy( +DQNTorchPolicy = build_policy_class( name="DQNTorchPolicy", + framework="torch", loss_fn=build_q_losses, get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, make_model_and_action_dist=build_q_model_and_distribution, diff --git a/rllib/agents/dqn/simple_q_torch_policy.py b/rllib/agents/dqn/simple_q_torch_policy.py index b9ec0f0c41f28..9862f82b79749 100644 --- a/rllib/agents/dqn/simple_q_torch_policy.py +++ b/rllib/agents/dqn/simple_q_torch_policy.py @@ -11,8 +11,8 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchDistributionWrapper from ray.rllib.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import huber_loss from ray.rllib.utils.typing import TensorType, TrainerConfigDict @@ -127,8 +127,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, TargetNetworkMixin.__init__(policy, obs_space, action_space, config) -SimpleQTorchPolicy = build_torch_policy( +SimpleQTorchPolicy = build_policy_class( name="SimpleQPolicy", + framework="torch", loss_fn=build_q_losses, get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, extra_action_out_fn=extra_action_out_fn, diff --git a/rllib/agents/dreamer/dreamer_torch_policy.py b/rllib/agents/dreamer/dreamer_torch_policy.py index f9abd10c871ad..d23ad9c3088db 100644 --- a/rllib/agents/dreamer/dreamer_torch_policy.py +++ b/rllib/agents/dreamer/dreamer_torch_policy.py @@ -1,11 +1,11 @@ import logging import ray -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.agents.dreamer.utils import FreezeParameters +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() if torch: @@ -236,8 +236,9 @@ def dreamer_optimizer_fn(policy, config): return (model_opt, actor_opt, critic_opt) -DreamerTorchPolicy = build_torch_policy( +DreamerTorchPolicy = build_policy_class( name="DreamerTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG, action_sampler_fn=action_sampler_fn, loss_fn=dreamer_loss, diff --git a/rllib/agents/es/es_torch_policy.py b/rllib/agents/es/es_torch_policy.py index 6f7e374c98732..444735e0b090b 100644 --- a/rllib/agents/es/es_torch_policy.py +++ b/rllib/agents/es/es_torch_policy.py @@ -7,8 +7,8 @@ import ray from ray.rllib.models import ModelCatalog +from 
ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.filter import get_filter from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ @@ -126,8 +126,9 @@ def make_model_and_action_dist(policy, observation_space, action_space, return model, dist_class -ESTorchPolicy = build_torch_policy( +ESTorchPolicy = build_policy_class( name="ESTorchPolicy", + framework="torch", loss_fn=None, get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG, before_init=before_init, diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index 22077202ee868..c6b8c26342f6c 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -6,10 +6,10 @@ from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ EntropyCoeffSchedule -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ sequence_mask @@ -260,8 +260,9 @@ def setup_mixins(policy, obs_space, action_space, config): LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) -VTraceTorchPolicy = build_torch_policy( +VTraceTorchPolicy = build_policy_class( name="VTraceTorchPolicy", + framework="torch", loss_fn=build_vtrace_loss, get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, stats_fn=stats, diff --git a/rllib/agents/maml/maml_tf_policy.py b/rllib/agents/maml/maml_tf_policy.py index 7aff4c426575b..d07de1495f3c8 100644 --- a/rllib/agents/maml/maml_tf_policy.py +++ b/rllib/agents/maml/maml_tf_policy.py @@ -1,13 +1,13 @@ import logging import ray +from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ + vf_preds_fetches, compute_and_clip_gradients, setup_config, \ + ValueNetworkMixin from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils import try_import_tf -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ - vf_preds_fetches, compute_and_clip_gradients, setup_config, \ - ValueNetworkMixin from ray.rllib.utils.framework import get_activation_fn tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py index 182ac8c25af33..478d95ba65ab6 100644 --- a/rllib/agents/maml/maml_torch_policy.py +++ b/rllib/agents/maml/maml_torch_policy.py @@ -2,8 +2,8 @@ import ray from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ setup_config from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ @@ -347,8 +347,9 @@ def setup_mixins(policy, 
obs_space, action_space, config): KLCoeffMixin.__init__(policy, config) -MAMLTorchPolicy = build_torch_policy( +MAMLTorchPolicy = build_policy_class( name="MAMLTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG, loss_fn=maml_loss, stats_fn=maml_stats, diff --git a/rllib/agents/marwil/marwil_torch_policy.py b/rllib/agents/marwil/marwil_torch_policy.py index e88e5e312f403..a64194abf9534 100644 --- a/rllib/agents/marwil/marwil_torch_policy.py +++ b/rllib/agents/marwil/marwil_torch_policy.py @@ -1,8 +1,8 @@ import ray from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import explained_variance @@ -75,8 +75,9 @@ def setup_mixins(policy, obs_space, action_space, config): ValueNetworkMixin.__init__(policy) -MARWILTorchPolicy = build_torch_policy( +MARWILTorchPolicy = build_policy_class( name="MARWILTorchPolicy", + framework="torch", loss_fn=marwil_loss, get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG, stats_fn=stats, diff --git a/rllib/agents/mbmpo/mbmpo_torch_policy.py b/rllib/agents/mbmpo/mbmpo_torch_policy.py index a4682ba81fe7a..f43d06ebec5a2 100644 --- a/rllib/agents/mbmpo/mbmpo_torch_policy.py +++ b/rllib/agents/mbmpo/mbmpo_torch_policy.py @@ -13,7 +13,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy -from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TrainerConfigDict @@ -76,8 +76,9 @@ def make_model_and_action_dist( # Build a child class of `TorchPolicy`, given the custom functions defined # above. 
-MBMPOTorchPolicy = build_torch_policy( +MBMPOTorchPolicy = build_policy_class( name="MBMPOTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG, make_model_and_action_dist=make_model_and_action_dist, loss_fn=maml_loss, diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py index be65f9e91c847..d707f01f2364e 100644 --- a/rllib/agents/pg/pg_torch_policy.py +++ b/rllib/agents/pg/pg_torch_policy.py @@ -10,8 +10,8 @@ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TensorType @@ -72,8 +72,9 @@ def pg_loss_stats(policy: Policy, # Build a child class of `TFPolicy`, given the extra options: # - trajectory post-processing function (to calculate advantages) # - PG loss function -PGTorchPolicy = build_torch_policy( +PGTorchPolicy = build_policy_class( name="PGTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, loss_fn=pg_torch_loss, stats_fn=pg_loss_stats, diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index f24dfc6d1b54f..461886dbec2be 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -23,9 +23,9 @@ from ray.rllib.models.torch.torch_action_dist import \ TorchDistributionWrapper, TorchCategorical from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ sequence_mask @@ -322,8 +322,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, # Build a child class of `TorchPolicy`, given the custom functions defined # above. 
-AsyncPPOTorchPolicy = build_torch_policy( +AsyncPPOTorchPolicy = build_policy_class( name="AsyncPPOTorchPolicy", + framework="torch", loss_fn=appo_surrogate_loss, stats_fn=stats, postprocess_fn=postprocess_trajectory, diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index d99251298c0c4..d73f53666f7cf 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -14,10 +14,10 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \ LearningRateSchedule -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \ explained_variance, sequence_mask @@ -111,6 +111,9 @@ def reduce_mean_valid(t): policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_vf_loss = mean_vf_loss + policy._vf_explained_var = explained_variance( + train_batch[Postprocessing.VALUE_TARGETS], + policy.model.value_function()) policy._mean_entropy = mean_entropy policy._mean_kl = mean_kl @@ -134,9 +137,7 @@ def kl_and_loss_stats(policy: Policy, "total_loss": policy._total_loss, "policy_loss": policy._mean_policy_loss, "vf_loss": policy._mean_vf_loss, - "vf_explained_var": explained_variance( - train_batch[Postprocessing.VALUE_TARGETS], - policy.model.value_function()), + "vf_explained_var": policy._vf_explained_var, "kl": policy._mean_kl, "entropy": policy._mean_entropy, "entropy_coeff": policy.entropy_coeff, @@ -271,8 +272,9 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, # Build a child class of `TorchPolicy`, given the custom functions defined # above. -PPOTorchPolicy = build_torch_policy( +PPOTorchPolicy = build_policy_class( name="PPOTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, loss_fn=ppo_surrogate_loss, stats_fn=kl_and_loss_stats, diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 6b533378ac8aa..539efa2c5e359 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -143,7 +143,7 @@ def forward(self, return loss, mask, masked_td_error, chosen_action_qvals, targets -# TODO(sven): Make this a TorchPolicy child via `build_torch_policy`. +# TODO(sven): Make this a TorchPolicy child via `build_policy_class`. class QMixTorchPolicy(Policy): """QMix impl. Assumes homogeneous agents for now. 
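Note: the central change in this patch is the replacement of the torch-only build_torch_policy template with the framework-agnostic build_policy_class, which takes an explicit framework argument ("torch" or "jax"); the per-agent edits in this diff all follow that one pattern. A minimal sketch of a custom policy built against the new helper, modeled on rllib/examples/custom_torch_policy.py (the loss body below is illustrative, since the example file's loss function is not shown in this diff):

import ray
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch


def policy_gradient_loss(policy, model, dist_class, train_batch):
    # Illustrative REINFORCE-style loss: -sum(reward * logp(action)).
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -train_batch[SampleBatch.REWARDS].dot(log_probs)


# "framework" is now required; the old helper hard-coded torch.
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",  # "jax" is the other value accepted by the new builder
    loss_fn=policy_gradient_loss,
)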
diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index a1a8f996bc23c..d1d53697ba2f2 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -17,8 +17,8 @@ from ray.rllib.models.torch.torch_action_dist import \ TorchDistributionWrapper, TorchDirichlet from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.models.torch.torch_action_dist import ( TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta) from ray.rllib.utils.framework import try_import_torch @@ -480,8 +480,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, # Build a child class of `TorchPolicy`, given the custom functions defined # above. -SACTorchPolicy = build_torch_policy( +SACTorchPolicy = build_policy_class( name="SACTorchPolicy", + framework="torch", loss_fn=actor_critic_loss, get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG, stats_fn=stats, diff --git a/rllib/agents/slateq/slateq_torch_policy.py b/rllib/agents/slateq/slateq_torch_policy.py index 0afb7cb12031a..19638d65767a4 100644 --- a/rllib/agents/slateq/slateq_torch_policy.py +++ b/rllib/agents/slateq/slateq_torch_policy.py @@ -11,8 +11,8 @@ TorchDistributionWrapper) from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import (ModelConfigDict, TensorType, TrainerConfigDict) @@ -403,8 +403,9 @@ def postprocess_fn_add_next_actions_for_sarsa(policy: Policy, return batch -SlateQTorchPolicy = build_torch_policy( +SlateQTorchPolicy = build_policy_class( name="SlateQTorchPolicy", + framework="torch", get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG, # build model, loss functions, and optimizers diff --git a/rllib/contrib/bandits/agents/policy.py b/rllib/contrib/bandits/agents/policy.py index 2a9b50137381b..e47c91005232c 100644 --- a/rllib/contrib/bandits/agents/policy.py +++ b/rllib/contrib/bandits/agents/policy.py @@ -10,9 +10,9 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import restore_original_dimensions from ray.rllib.policy.policy import LEARNER_STATS_KEY +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.annotations import override from ray.util.debug import log_once @@ -109,8 +109,9 @@ def init_cum_regret(policy, *args): policy.regrets = [] -BanditPolicy = build_torch_policy( +BanditPolicy = build_policy_class( name="BanditPolicy", + framework="torch", get_default_config=lambda: DEFAULT_CONFIG, loss_fn=None, after_init=init_cum_regret, diff --git a/rllib/examples/custom_torch_policy.py b/rllib/examples/custom_torch_policy.py index 1cea6aa1cf515..3c88e89f16462 100644 --- a/rllib/examples/custom_torch_policy.py +++ b/rllib/examples/custom_torch_policy.py @@ -4,8 +4,8 @@ import ray from ray import tune from ray.rllib.agents.trainer_template import build_trainer 
+from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_template import build_torch_policy parser = argparse.ArgumentParser() parser.add_argument("--stop-iters", type=int, default=200) @@ -20,8 +20,8 @@ def policy_gradient_loss(policy, model, dist_class, train_batch): # -MyTorchPolicy = build_torch_policy( - name="MyTorchPolicy", loss_fn=policy_gradient_loss) +MyTorchPolicy = build_policy_class( + name="MyTorchPolicy", framework="torch", loss_fn=policy_gradient_loss) # MyTrainer = build_trainer( diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index fc45149a564e2..85d99129482de 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -9,6 +9,7 @@ from ray.rllib.models.repeated_values import RepeatedValues from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils import NullContextManager from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType @@ -280,7 +281,7 @@ def variables(self, as_dict: bool = False Args: as_dict(bool): Whether variables should be returned as dict-values - (using descriptive keys). + (using descriptive str keys). Returns: Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is @@ -375,19 +376,6 @@ def get_input_dict(self, sample_batch, return input_dict -class NullContextManager: - """No-op context manager""" - - def __init__(self): - pass - - def __enter__(self): - pass - - def __exit__(self, *args): - pass - - @DeveloperAPI def flatten(obs: TensorType, framework: str) -> TensorType: """Flatten the given tensor.""" diff --git a/rllib/policy/__init__.py b/rllib/policy/__init__.py index 348fe187da4c6..67868182a07af 100644 --- a/rllib/policy/__init__.py +++ b/rllib/policy/__init__.py @@ -1,6 +1,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.policy.tf_policy_template import build_tf_policy @@ -8,6 +9,7 @@ "Policy", "TFPolicy", "TorchPolicy", + "build_policy_class", "build_tf_policy", "build_torch_policy", ] diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py new file mode 100644 index 0000000000000..9c955741379d5 --- /dev/null +++ b/rllib/policy/policy_template.py @@ -0,0 +1,400 @@ +import gym +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.jax.jax_modelv2 import JAXModelV2 +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils import add_mixins, force_list, NullContextManager +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.framework import try_import_torch, try_import_jax +from ray.rllib.utils.torch_ops import convert_to_non_torch_type +from ray.rllib.utils.typing import 
TensorType, TrainerConfigDict + +jax, _ = try_import_jax() +torch, _ = try_import_torch() + + +# TODO: (sven) Unify this with `build_tf_policy` as well. +@DeveloperAPI +def build_policy_class( + name: str, + framework: str, + *, + loss_fn: Optional[Callable[[ + Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch + ], Union[TensorType, List[TensorType]]]], + get_default_config: Optional[Callable[[], TrainerConfigDict]] = None, + stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[ + str, TensorType]]] = None, + postprocess_fn: Optional[Callable[[ + Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[ + "MultiAgentEpisode"] + ], SampleBatch]] = None, + extra_action_out_fn: Optional[Callable[[ + Policy, Dict[str, TensorType], List[TensorType], ModelV2, + TorchDistributionWrapper + ], Dict[str, TensorType]]] = None, + extra_grad_process_fn: Optional[Callable[[ + Policy, "torch.optim.Optimizer", TensorType + ], Dict[str, TensorType]]] = None, + # TODO: (sven) Replace "fetches" with "process". + extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[ + str, TensorType]]] = None, + optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict], + "torch.optim.Optimizer"]] = None, + validate_spaces: Optional[Callable[ + [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, + before_init: Optional[Callable[ + [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, + before_loss_init: Optional[Callable[[ + Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict + ], None]] = None, + after_init: Optional[Callable[ + [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None, + _after_loss_init: Optional[Callable[[ + Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict + ], None]] = None, + action_sampler_fn: Optional[Callable[[TensorType, List[ + TensorType]], Tuple[TensorType, TensorType]]] = None, + action_distribution_fn: Optional[Callable[[ + Policy, ModelV2, TensorType, TensorType, TensorType + ], Tuple[TensorType, type, List[TensorType]]]] = None, + make_model: Optional[Callable[[ + Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict + ], ModelV2]] = None, + make_model_and_action_dist: Optional[Callable[[ + Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict + ], Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None, + apply_gradients_fn: Optional[Callable[ + [Policy, "torch.optim.Optimizer"], None]] = None, + mixins: Optional[List[type]] = None, + view_requirements_fn: Optional[Callable[[Policy], Dict[ + str, ViewRequirement]]] = None, + get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None +) -> Type[TorchPolicy]: + """Helper function for creating a new Policy class at runtime. + + Supports frameworks JAX and PyTorch. + + Args: + name (str): name of the policy (e.g., "PPOTorchPolicy") + framework (str): Either "jax" or "torch". + loss_fn (Optional[Callable[[Policy, ModelV2, + Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, + List[TensorType]]]]): Callable that returns a loss tensor. + get_default_config (Optional[Callable[[None], TrainerConfigDict]]): + Optional callable that returns the default config to merge with any + overrides. If None, uses only(!) the user-provided + PartialTrainerConfigDict as dict for this Policy. 
+ postprocess_fn (Optional[Callable[[Policy, SampleBatch, + Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]], + SampleBatch]]): Optional callable for post-processing experience + batches (called after the super's `postprocess_trajectory` method). + stats_fn (Optional[Callable[[Policy, SampleBatch], + Dict[str, TensorType]]]): Optional callable that returns a dict of + values given the policy and training batch. If None, + will use `TorchPolicy.extra_grad_info()` instead. The stats dict is + used for logging (e.g. in TensorBoard). + extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType], + List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str, + TensorType]]]): Optional callable that returns a dict of extra + values to include in experiences. If None, no extra computations + will be performed. + extra_grad_process_fn (Optional[Callable[[Policy, + "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]): + Optional callable that is called after gradients are computed and + returns a processing info dict. If None, will call the + `TorchPolicy.extra_grad_process()` method instead. + # TODO: (sven) dissolve naming mismatch between "learn" and "compute.." + extra_learn_fetches_fn (Optional[Callable[[Policy], + Dict[str, TensorType]]]): Optional callable that returns a dict of + extra tensors from the policy after loss evaluation. If None, + will call the `TorchPolicy.extra_compute_grad_fetches()` method + instead. + optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict], + "torch.optim.Optimizer"]]): Optional callable that returns a + torch optimizer given the policy and config. If None, will call + the `TorchPolicy.optimizer()` method instead (which returns a + torch Adam optimizer). + validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space, + TrainerConfigDict], None]]): Optional callable that takes the + Policy, observation_space, action_space, and config to check for + correctness. If None, no spaces checking will be done. + before_init (Optional[Callable[[Policy, gym.Space, gym.Space, + TrainerConfigDict], None]]): Optional callable to run at the + beginning of `Policy.__init__` that takes the same arguments as + the Policy constructor. If None, this step will be skipped. + before_loss_init (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to + run prior to loss init. If None, this step will be skipped. + after_init (Optional[Callable[[Policy, gym.Space, gym.Space, + TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init` + instead. + _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to + run after the loss init. If None, this step will be skipped. + This will be deprecated at some point and renamed into `after_init` + to match `build_tf_policy()` behavior. + action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]], + Tuple[TensorType, TensorType]]]): Optional callable returning a + sampled action and its log-likelihood given some (obs and state) + inputs. If None, will either use `action_distribution_fn` or + compute actions by calling self.model, then sampling from the + so parameterized action distribution. 
+ action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType, + TensorType, TensorType], Tuple[TensorType, + Type[TorchDistributionWrapper], List[TensorType]]]]): A callable + that takes the Policy, Model, the observation batch, an + explore-flag, a timestep, and an is_training flag and returns a + tuple of a) distribution inputs (parameters), b) a dist-class to + generate an action distribution object from, and c) internal-state + outputs (empty list if not applicable). If None, will either use + `action_sampler_fn` or compute actions by calling self.model, + then sampling from the parameterized action distribution. + make_model (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable + that takes the same arguments as Policy.__init__ and returns a + model instance. The distribution class will be determined + automatically. Note: Only one of `make_model` or + `make_model_and_action_dist` should be provided. If both are None, + a default Model will be created. + make_model_and_action_dist (Optional[Callable[[Policy, + gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], + Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional + callable that takes the same arguments as Policy.__init__ and + returns a tuple of model instance and torch action distribution + class. + Note: Only one of `make_model` or `make_model_and_action_dist` + should be provided. If both are None, a default Model will be + created. + apply_gradients_fn (Optional[Callable[[Policy, + "torch.optim.Optimizer"], None]]): Optional callable that + takes a grads list and applies these to the Model's parameters. + If None, will call the `TorchPolicy.apply_gradients()` method + instead. + mixins (Optional[List[type]]): Optional list of any class mixins for + the returned policy class. These mixins will be applied in order + and will have higher precedence than the TorchPolicy class. + view_requirements_fn (Optional[Callable[[Policy], + Dict[str, ViewRequirement]]]): An optional callable to retrieve + additional train view requirements for this policy. + get_batch_divisibility_req (Optional[Callable[[Policy], int]]): + Optional callable that returns the divisibility requirement for + sample batches. If None, will assume a value of 1. + + Returns: + Type[TorchPolicy]: TorchPolicy child class constructed from the + specified args. + """ + + original_kwargs = locals().copy() + parent_cls = TorchPolicy + base = add_mixins(parent_cls, mixins) + + class policy_cls(base): + def __init__(self, obs_space, action_space, config): + # Set up the config from possible default-config fn and given + # config arg. + if get_default_config: + config = dict(get_default_config(), **config) + self.config = config + + # Set the DL framework for this Policy. + self.framework = self.config["framework"] = framework + + # Validate observation- and action-spaces. + if validate_spaces: + validate_spaces(self, obs_space, action_space, self.config) + + # Do some pre-initialization steps. + if before_init: + before_init(self, obs_space, action_space, self.config) + + # Model is customized (use default action dist class). + if make_model: + assert make_model_and_action_dist is None, \ + "Either `make_model` or `make_model_and_action_dist`" \ + " must be None!" + self.model = make_model(self, obs_space, action_space, config) + dist_class, _ = ModelCatalog.get_action_dist( + action_space, self.config["model"], framework=framework) + # Model and action dist class are customized. 
+ elif make_model_and_action_dist: + self.model, dist_class = make_model_and_action_dist( + self, obs_space, action_space, config) + # Use default model and default action dist. + else: + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], framework=framework) + self.model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=logit_dim, + model_config=self.config["model"], + framework=framework) + + # Make sure, we passed in a correct Model factory. + model_cls = TorchModelV2 if framework == "torch" else JAXModelV2 + assert isinstance(self.model, model_cls), \ + "ERROR: Generated Model must be a TorchModelV2 object!" + + # Call the framework-specific Policy constructor. + self.parent_cls = parent_cls + self.parent_cls.__init__( + self, + observation_space=obs_space, + action_space=action_space, + config=config, + model=self.model, + loss=loss_fn, + action_distribution_class=dist_class, + action_sampler_fn=action_sampler_fn, + action_distribution_fn=action_distribution_fn, + max_seq_len=config["model"]["max_seq_len"], + get_batch_divisibility_req=get_batch_divisibility_req, + ) + + # Update this Policy's ViewRequirements (if function given). + if callable(view_requirements_fn): + self.view_requirements.update(view_requirements_fn(self)) + # Merge Model's view requirements into Policy's. + self.view_requirements.update( + self.model.inference_view_requirements) + + _before_loss_init = before_loss_init or after_init + if _before_loss_init: + _before_loss_init(self, self.observation_space, + self.action_space, config) + + # Perform test runs through postprocessing- and loss functions. + self._initialize_loss_from_dummy_batch( + auto_remove_unneeded_view_reqs=True, + stats_fn=stats_fn, + ) + + if _after_loss_init: + _after_loss_init(self, obs_space, action_space, config) + + # Got to reset global_timestep again after this fake run-through. + self.global_timestep = 0 + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + with self._no_grad_context(): + # Call super's postprocess_trajectory first. + sample_batch = super().postprocess_trajectory( + sample_batch, other_agent_batches, episode) + if postprocess_fn: + return postprocess_fn(self, sample_batch, + other_agent_batches, episode) + + return sample_batch + + @override(parent_cls) + def extra_grad_process(self, optimizer, loss): + """Called after optimizer.zero_grad() and loss.backward() calls. + + Allows for gradient processing before optimizer.step() is called. + E.g. for gradient clipping. + """ + if extra_grad_process_fn: + return extra_grad_process_fn(self, optimizer, loss) + else: + return parent_cls.extra_grad_process(self, optimizer, loss) + + @override(parent_cls) + def extra_compute_grad_fetches(self): + if extra_learn_fetches_fn: + fetches = convert_to_non_torch_type( + extra_learn_fetches_fn(self)) + # Auto-add empty learner stats dict if needed. 
+ return dict({LEARNER_STATS_KEY: {}}, **fetches) + else: + return parent_cls.extra_compute_grad_fetches(self) + + @override(parent_cls) + def apply_gradients(self, gradients): + if apply_gradients_fn: + apply_gradients_fn(self, gradients) + else: + parent_cls.apply_gradients(self, gradients) + + @override(parent_cls) + def extra_action_out(self, input_dict, state_batches, model, + action_dist): + with self._no_grad_context(): + if extra_action_out_fn: + stats_dict = extra_action_out_fn( + self, input_dict, state_batches, model, action_dist) + else: + stats_dict = parent_cls.extra_action_out( + self, input_dict, state_batches, model, action_dist) + return self._convert_to_non_torch_type(stats_dict) + + @override(parent_cls) + def optimizer(self): + if optimizer_fn: + optimizers = optimizer_fn(self, self.config) + else: + optimizers = parent_cls.optimizer(self) + optimizers = force_list(optimizers) + if getattr(self, "exploration", None): + optimizers = self.exploration.get_exploration_optimizer( + optimizers) + return optimizers + + @override(parent_cls) + def extra_grad_info(self, train_batch): + with self._no_grad_context(): + if stats_fn: + stats_dict = stats_fn(self, train_batch) + else: + stats_dict = self.parent_cls.extra_grad_info( + self, train_batch) + return self._convert_to_non_torch_type(stats_dict) + + def _no_grad_context(self): + if self.framework == "torch": + return torch.no_grad() + return NullContextManager() + + def _convert_to_non_torch_type(self, data): + if self.framework == "torch": + return convert_to_non_torch_type(data) + return data + + def with_updates(**overrides): + """Creates a Torch|JAXPolicy cls based on settings of another one. + + Keyword Args: + **overrides: The settings (passed into `build_torch_policy`) that + should be different from the class that this method is called + on. + + Returns: + type: A new Torch|JAXPolicy sub-class. + + Examples: + >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates( + .. name="MySpecialDQNPolicyClass", + .. loss_function=[some_new_loss_function], + .. 
) + """ + return build_policy_class(**dict(original_kwargs, **overrides)) + + policy_cls.with_updates = staticmethod(with_updates) + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index 2c698fefa1de2..78777f6ba1412 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -1,17 +1,16 @@ import gym from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from ray.rllib.models.catalog import ModelCatalog +from ray.util import log_once from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.utils import add_mixins, force_list -from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import convert_to_non_torch_type from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, _ = try_import_torch() @@ -38,7 +37,6 @@ def build_torch_policy( extra_grad_process_fn: Optional[Callable[[ Policy, "torch.optim.Optimizer", TensorType ], Dict[str, TensorType]]] = None, - # TODO: (sven) Replace "fetches" with "process". extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[ str, TensorType]]] = None, optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict], @@ -71,291 +69,13 @@ def build_torch_policy( mixins: Optional[List[type]] = None, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None ) -> Type[TorchPolicy]: - """Helper function for creating a torch policy class at runtime. - Args: - name (str): name of the policy (e.g., "PPOTorchPolicy") - loss_fn (Optional[Callable[[Policy, ModelV2, - Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, - List[TensorType]]]]): Callable that returns a loss tensor. - get_default_config (Optional[Callable[[None], TrainerConfigDict]]): - Optional callable that returns the default config to merge with any - overrides. If None, uses only(!) the user-provided - PartialTrainerConfigDict as dict for this Policy. - postprocess_fn (Optional[Callable[[Policy, SampleBatch, - Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]], - SampleBatch]]): Optional callable for post-processing experience - batches (called after the super's `postprocess_trajectory` method). - stats_fn (Optional[Callable[[Policy, SampleBatch], - Dict[str, TensorType]]]): Optional callable that returns a dict of - values given the policy and training batch. If None, - will use `TorchPolicy.extra_grad_info()` instead. The stats dict is - used for logging (e.g. in TensorBoard). - extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType], - List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str, - TensorType]]]): Optional callable that returns a dict of extra - values to include in experiences. If None, no extra computations - will be performed. 
- extra_grad_process_fn (Optional[Callable[[Policy, - "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]): - Optional callable that is called after gradients are computed and - returns a processing info dict. If None, will call the - `TorchPolicy.extra_grad_process()` method instead. - # TODO: (sven) dissolve naming mismatch between "learn" and "compute.." - extra_learn_fetches_fn (Optional[Callable[[Policy], - Dict[str, TensorType]]]): Optional callable that returns a dict of - extra tensors from the policy after loss evaluation. If None, - will call the `TorchPolicy.extra_compute_grad_fetches()` method - instead. - optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict], - "torch.optim.Optimizer"]]): Optional callable that returns a - torch optimizer given the policy and config. If None, will call - the `TorchPolicy.optimizer()` method instead (which returns a - torch Adam optimizer). - validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space, - TrainerConfigDict], None]]): Optional callable that takes the - Policy, observation_space, action_space, and config to check for - correctness. If None, no spaces checking will be done. - before_init (Optional[Callable[[Policy, gym.Space, gym.Space, - TrainerConfigDict], None]]): Optional callable to run at the - beginning of `Policy.__init__` that takes the same arguments as - the Policy constructor. If None, this step will be skipped. - before_loss_init (Optional[Callable[[Policy, gym.spaces.Space, - gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to - run prior to loss init. If None, this step will be skipped. - after_init (Optional[Callable[[Policy, gym.Space, gym.Space, - TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init` - instead. - _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space, - gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to - run after the loss init. If None, this step will be skipped. - This will be deprecated at some point and renamed into `after_init` - to match `build_tf_policy()` behavior. - action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]], - Tuple[TensorType, TensorType]]]): Optional callable returning a - sampled action and its log-likelihood given some (obs and state) - inputs. If None, will either use `action_distribution_fn` or - compute actions by calling self.model, then sampling from the - so parameterized action distribution. - action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType, - TensorType, TensorType], Tuple[TensorType, - Type[TorchDistributionWrapper], List[TensorType]]]]): A callable - that takes the Policy, Model, the observation batch, an - explore-flag, a timestep, and an is_training flag and returns a - tuple of a) distribution inputs (parameters), b) a dist-class to - generate an action distribution object from, and c) internal-state - outputs (empty list if not applicable). If None, will either use - `action_sampler_fn` or compute actions by calling self.model, - then sampling from the parameterized action distribution. - make_model (Optional[Callable[[Policy, gym.spaces.Space, - gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable - that takes the same arguments as Policy.__init__ and returns a - model instance. The distribution class will be determined - automatically. Note: Only one of `make_model` or - `make_model_and_action_dist` should be provided. If both are None, - a default Model will be created. 
- make_model_and_action_dist (Optional[Callable[[Policy, - gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], - Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional - callable that takes the same arguments as Policy.__init__ and - returns a tuple of model instance and torch action distribution - class. - Note: Only one of `make_model` or `make_model_and_action_dist` - should be provided. If both are None, a default Model will be - created. - apply_gradients_fn (Optional[Callable[[Policy, - "torch.optim.Optimizer"], None]]): Optional callable that - takes a grads list and applies these to the Model's parameters. - If None, will call the `TorchPolicy.apply_gradients()` method - instead. - mixins (Optional[List[type]]): Optional list of any class mixins for - the returned policy class. These mixins will be applied in order - and will have higher precedence than the TorchPolicy class. - get_batch_divisibility_req (Optional[Callable[[Policy], int]]): - Optional callable that returns the divisibility requirement for - sample batches. If None, will assume a value of 1. - - Returns: - Type[TorchPolicy]: TorchPolicy child class constructed from the - specified args. - """ - - original_kwargs = locals().copy() - base = add_mixins(TorchPolicy, mixins) - - class policy_cls(base): - def __init__(self, obs_space, action_space, config): - if get_default_config: - config = dict(get_default_config(), **config) - self.config = config - - if validate_spaces: - validate_spaces(self, obs_space, action_space, self.config) - - if before_init: - before_init(self, obs_space, action_space, self.config) - - # Model is customized (use default action dist class). - if make_model: - assert make_model_and_action_dist is None, \ - "Either `make_model` or `make_model_and_action_dist`" \ - " must be None!" - self.model = make_model(self, obs_space, action_space, config) - dist_class, _ = ModelCatalog.get_action_dist( - action_space, self.config["model"], framework="torch") - # Model and action dist class are customized. - elif make_model_and_action_dist: - self.model, dist_class = make_model_and_action_dist( - self, obs_space, action_space, config) - # Use default model and default action dist. - else: - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"], framework="torch") - self.model = ModelCatalog.get_model_v2( - obs_space=obs_space, - action_space=action_space, - num_outputs=logit_dim, - model_config=self.config["model"], - framework="torch") - - # Make sure, we passed in a correct Model factory. - assert isinstance(self.model, TorchModelV2), \ - "ERROR: Generated Model must be a TorchModelV2 object!" - - TorchPolicy.__init__( - self, - observation_space=obs_space, - action_space=action_space, - config=config, - model=self.model, - loss=loss_fn, - action_distribution_class=dist_class, - action_sampler_fn=action_sampler_fn, - action_distribution_fn=action_distribution_fn, - max_seq_len=config["model"]["max_seq_len"], - get_batch_divisibility_req=get_batch_divisibility_req, - ) - - # Merge Model's view requirements into Policy's. - self.view_requirements.update( - self.model.inference_view_requirements) - - _before_loss_init = before_loss_init or after_init - if _before_loss_init: - _before_loss_init(self, self.observation_space, - self.action_space, config) - - # Perform test runs through postprocessing- and loss functions. 
- self._initialize_loss_from_dummy_batch( - auto_remove_unneeded_view_reqs=True, - stats_fn=stats_fn, - ) - - if _after_loss_init: - _after_loss_init(self, obs_space, action_space, config) - - # Got to reset global_timestep again after this fake run-through. - self.global_timestep = 0 - - @override(Policy) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - # Do all post-processing always with no_grad(). - # Not using this here will introduce a memory leak (issue #6962). - with torch.no_grad(): - # Call super's postprocess_trajectory first. - sample_batch = super().postprocess_trajectory( - sample_batch, other_agent_batches, episode) - if postprocess_fn: - return postprocess_fn(self, sample_batch, - other_agent_batches, episode) - - return sample_batch - - @override(TorchPolicy) - def extra_grad_process(self, optimizer, loss): - """Called after optimizer.zero_grad() and loss.backward() calls. - - Allows for gradient processing before optimizer.step() is called. - E.g. for gradient clipping. - """ - if extra_grad_process_fn: - return extra_grad_process_fn(self, optimizer, loss) - else: - return TorchPolicy.extra_grad_process(self, optimizer, loss) - - @override(TorchPolicy) - def extra_compute_grad_fetches(self): - if extra_learn_fetches_fn: - fetches = convert_to_non_torch_type( - extra_learn_fetches_fn(self)) - # Auto-add empty learner stats dict if needed. - return dict({LEARNER_STATS_KEY: {}}, **fetches) - else: - return TorchPolicy.extra_compute_grad_fetches(self) - - @override(TorchPolicy) - def apply_gradients(self, gradients): - if apply_gradients_fn: - apply_gradients_fn(self, gradients) - else: - TorchPolicy.apply_gradients(self, gradients) - - @override(TorchPolicy) - def extra_action_out(self, input_dict, state_batches, model, - action_dist): - with torch.no_grad(): - if extra_action_out_fn: - stats_dict = extra_action_out_fn( - self, input_dict, state_batches, model, action_dist) - else: - stats_dict = TorchPolicy.extra_action_out( - self, input_dict, state_batches, model, action_dist) - return convert_to_non_torch_type(stats_dict) - - @override(TorchPolicy) - def optimizer(self): - if optimizer_fn: - optimizers = optimizer_fn(self, self.config) - else: - optimizers = TorchPolicy.optimizer(self) - optimizers = force_list(optimizers) - if getattr(self, "exploration", None): - optimizers = self.exploration.get_exploration_optimizer( - optimizers) - return optimizers - - @override(TorchPolicy) - def extra_grad_info(self, train_batch): - with torch.no_grad(): - if stats_fn: - stats_dict = stats_fn(self, train_batch) - else: - stats_dict = TorchPolicy.extra_grad_info(self, train_batch) - return convert_to_non_torch_type(stats_dict) - - def with_updates(**overrides): - """Allows creating a TorchPolicy cls based on settings of another one. - - Keyword Args: - **overrides: The settings (passed into `build_torch_policy`) that - should be different from the class that this method is called - on. - - Returns: - type: A new TorchPolicy sub-class. - - Examples: - >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates( - .. name="MySpecialDQNPolicyClass", - .. loss_function=[some_new_loss_function], - .. 
) - """ - return build_torch_policy(**dict(original_kwargs, **overrides)) - - policy_cls.with_updates = staticmethod(with_updates) - policy_cls.__name__ = name - policy_cls.__qualname__ = name - return policy_cls + if log_once("deprecation_warning_build_torch_policy"): + deprecation_warning( + old="build_torch_policy", + new="build_policy_class(framework='torch')", + error=False) + kwargs = locals().copy() + # Set to torch and call new function. + kwargs["framework"] = "torch" + return build_policy_class(**kwargs) diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index e0c2dda3b9794..276cdeaf7061e 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -53,6 +53,19 @@ def force_list(elements=None, to_tuple=False): if type(elements) in [list, tuple] else ctor([elements]) +class NullContextManager: + """No-op context manager""" + + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + force_tuple = partial(force_list, to_tuple=True) __all__ = [ diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index 175ce80193a0b..a9434e1a11748 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -4,12 +4,13 @@ from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.modelv2 import ModelV2, NullContextManager +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchMultiCategorical from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import NullContextManager from ray.rllib.utils.annotations import override from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import get_activation_fn, try_import_tf, \ diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 1d0a8afcd8ff2..8c7d223074b6a 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -16,13 +16,13 @@ def try_import_jax(error=False): - """Tries importing JAX and returns the module (or None). + """Tries importing JAX and FLAX and returns both modules (or Nones). Args: - error (bool): Whether to raise an error if JAX cannot be imported. + error (bool): Whether to raise an error if JAX/FLAX cannot be imported. Returns: - The jax module. + Tuple: The jax- and the flax modules. Raises: ImportError: If error=True and JAX is not installed.
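Note: for backward compatibility, build_torch_policy is kept (and still exported from ray.rllib.policy) as a thin wrapper: it logs a one-time deprecation warning, forces the framework to "torch", and forwards all of its arguments to build_policy_class. A minimal call just to show the forwarding (loss_fn=None is legal, as the ES/ARS/Bandit policies above demonstrate):

from ray.rllib.policy import build_torch_policy

# Equivalent to build_policy_class(name="LegacyPolicy", framework="torch",
# loss_fn=None), plus a single "build_torch_policy -> build_policy_class"
# deprecation warning on first use.
LegacyPolicy = build_torch_policy(name="LegacyPolicy", loss_fn=None)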
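Note: NullContextManager moves from rllib/models/modelv2.py to rllib/utils/__init__.py because the new template needs it outside the models package. With framework="torch" the generated policy still wraps postprocessing and stats computation in torch.no_grad() (the memory-leak workaround referenced for issue #6962), while other frameworks get a no-op context. A standalone sketch of that _no_grad_context() pattern, assuming nothing beyond the utilities touched in this diff:

from ray.rllib.utils import NullContextManager
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()


def no_grad_context(framework: str):
    # Mirrors policy_cls._no_grad_context(): a real no_grad() guard for
    # torch, a do-nothing context manager for any other framework.
    if framework == "torch" and torch is not None:
        return torch.no_grad()
    return NullContextManager()


with no_grad_context("jax"):
    pass  # body runs without touching torch's autograd machinery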
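Note: the try_import_jax docstring fix at the bottom documents the return shape the new template relies on: the helper yields a (jax, flax) pair rather than a single module, which is why policy_template.py unpacks it as jax, _ = try_import_jax(). Typical caller-side usage (the (None, None) fallback when the libraries are missing is assumed from the error=False default; the function body is not part of this hunk):

from ray.rllib.utils.framework import try_import_jax

jax, flax = try_import_jax()  # both entries assumed None if JAX/FLAX are absent
if jax is not None:
    print(jax.numpy.arange(3))  # jax.numpy is usable as normal once imported this way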
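Note: the generated class keeps the with_updates() escape hatch, now described as producing a Torch|JAXPolicy subclass. Override keys must match the keyword arguments of build_policy_class (loss_fn, stats_fn, get_default_config, ...). A small, hypothetical derivation of a DQN variant with a different default learning rate:

import ray
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy

# MyDQNTorchPolicy is a hypothetical name; framework="torch" is inherited
# from the kwargs DQNTorchPolicy was originally built with.
MyDQNTorchPolicy = DQNTorchPolicy.with_updates(
    name="MyDQNTorchPolicy",
    get_default_config=lambda: dict(
        ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, lr=1e-4),
)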