diff --git a/docs/02_notebooks/L3_Vectorized__Environment.ipynb b/docs/02_notebooks/L3_Vectorized__Environment.ipynb
index d2f6b915c..a83b0bd9c 100644
--- a/docs/02_notebooks/L3_Vectorized__Environment.ipynb
+++ b/docs/02_notebooks/L3_Vectorized__Environment.ipynb
@@ -197,10 +197,10 @@
     "* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n",
     "* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n",
     "\n",
-    "Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n",
+    "Check the [documentation](https://tianshou.org/en/master/03_api/env/venvs.html) for details.\n",
     "\n",
     "### Difference between synchronous and asynchronous mode (How to choose?)\n",
-    "Explanation can be found at the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial."
+    "Explanation can be found at the [Parallel Sampling](https://tianshou.org/en/master/01_tutorials/07_cheatsheet.html#parallel-sampling) tutorial."
    ]
   }
  ],
@@ -223,7 +223,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.5"
+   "version": "3.11.7"
   }
  },
 "nbformat": 4,
diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb
index 7da98a5cf..aa86d685c 100644
--- a/docs/02_notebooks/L5_Collector.ipynb
+++ b/docs/02_notebooks/L5_Collector.ipynb
@@ -247,7 +247,7 @@
    },
    "source": [
     "## Further Reading\n",
-    "The above collector actually collects 52 data at a time because 52 % 4 = 0. There is one asynchronous collector which allows you collect exactly 50 steps. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#asynccollector) for details."
+    "The above collector actually collects 52 steps at a time because 52 % 4 = 0 (the requested 50 steps are rounded up to a multiple of the 4 parallel environments). There is an asynchronous collector which allows you to collect exactly 50 steps. Check the [documentation](https://tianshou.org/en/master/03_api/data/collector.html#tianshou.data.collector.AsyncCollector) for details."
    ]
   }
  ],
diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb
index d5423bd01..edc161981 100644
--- a/docs/02_notebooks/L6_Trainer.ipynb
+++ b/docs/02_notebooks/L6_Trainer.ipynb
@@ -54,7 +54,12 @@
   },
   {
    "cell_type": "code",
+   "execution_count": 1,
    "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-05-06T15:34:02.969675Z",
+     "start_time": "2024-05-06T15:34:00.747309Z"
+    },
     "editable": true,
     "id": "do-xZ-8B7nVH",
     "slideshow": {
@@ -63,12 +68,9 @@
     "tags": [
      "hide-cell",
      "remove-output"
-    ],
-    "ExecuteTime": {
-     "end_time": "2024-05-06T15:34:02.969675Z",
-     "start_time": "2024-05-06T15:34:00.747309Z"
-    }
+    ]
    },
+   "outputs": [],
    "source": [
     "%%capture\n",
     "\n",
@@ -82,18 +84,18 @@
     "from tianshou.utils.net.common import Net\n",
     "from tianshou.utils.net.discrete import Actor\n",
     "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
-   ],
-   "outputs": [],
-   "execution_count": 1
+   ]
   },
   {
    "cell_type": "code",
+   "execution_count": 2,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2024-05-06T15:34:07.536452Z",
      "start_time": "2024-05-06T15:34:03.636670Z"
     }
    },
+   "outputs": [],
    "source": [
     "train_env_num = 4\n",
     "buffer_size = (\n",
@@ -131,9 +133,7 @@
     "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
     "test_collector = Collector(policy, test_envs)\n",
     "train_collector = Collector(policy, train_envs, replayBuffer)"
-   ],
-   "outputs": [],
-   "execution_count": 2
+   ]
   },
   {
    "cell_type": "markdown",
@@ -252,10 +252,10 @@
    "source": [
     "## Further Reading\n",
     "### Logger usages\n",
-    "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n",
+    "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger class which allows you to define your own logger. Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n",
     "\n",
     "### Learn more about the APIs of Trainers\n",
-    "[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)"
+    "[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)"
    ]
   }
  ],
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 7531616a5..4df6b584a 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -36,7 +36,7 @@
 from sensai.util.logging import datetime_tag
 from sensai.util.string import ToStringMixin
 
-from tianshou.data import Collector, InfoStats
+from tianshou.data import BaseCollector, Collector, InfoStats
 from tianshou.env import BaseVectorEnv
 from tianshou.highlevel.agent import (
     A2CAgentFactory,
@@ -111,13 +111,14 @@
 from tianshou.policy import BasePolicy
 from tianshou.utils import LazyLogger
 from tianshou.utils.net.common import ModuleType
+from tianshou.utils.print import DataclassPPrintMixin
 from tianshou.utils.warning import deprecation
 
 log = logging.getLogger(__name__)
 
 
 @dataclass
-class ExperimentConfig:
+class ExperimentConfig(ToStringMixin, DataclassPPrintMixin):
     """Generic config for setting up the experiment, not RL or training specific."""
 
     seed: int = 42
@@ -161,7 +162,7 @@ class ExperimentResult:
     """dataclass of results as returned by the trainer (if any)"""
 
 
-class Experiment(ToStringMixin):
+class Experiment(ToStringMixin, DataclassPPrintMixin):
     """Represents a reinforcement learning experiment.
 
     An experiment is composed only of configuration and factory objects, which themselves
@@ -333,12 +334,16 @@ def create_experiment_world(
         # create policy and collectors
         log.info("Creating policy")
         policy = self.agent_factory.create_policy(envs, self.config.device)
+
         log.info("Creating collectors")
-        train_collector, test_collector = self.agent_factory.create_train_test_collector(
-            policy,
-            envs,
-            reset_collectors=reset_collectors,
-        )
+        train_collector: BaseCollector | None = None
+        test_collector: BaseCollector | None = None
+        if self.config.train:
+            train_collector, test_collector = self.agent_factory.create_train_test_collector(
+                policy,
+                envs,
+                reset_collectors=reset_collectors,
+            )
 
         # create context object with all relevant instances (except trainer; added later)
         world = World(
@@ -414,6 +419,10 @@ def run(
         ):
             trainer_result: InfoStats | None = None
             if self.config.train:
+                assert world.trainer is not None
+                assert world.train_collector is not None
+                assert world.test_collector is not None
+
                 # prefilling buffers with either random or current agent's actions
                 if self.sampling_config.start_timesteps > 0:
                     log.info(
@@ -426,7 +435,6 @@
                     )
 
                 log.info("Starting training")
-                assert world.trainer is not None
                 world.trainer.run()
                 if use_persistence:
                     world.logger.finalize()
diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py
index c32ef9cbc..6db216b15 100644
--- a/tianshou/highlevel/world.py
+++ b/tianshou/highlevel/world.py
@@ -10,14 +10,14 @@
 from tianshou.trainer import BaseTrainer
 
 
-@dataclass
+@dataclass(kw_only=True)
 class World:
     """Container for instances and configuration items that are relevant to an experiment."""
 
     envs: "Environments"
     policy: "BasePolicy"
-    train_collector: "BaseCollector"
-    test_collector: "BaseCollector"
+    train_collector: Optional["BaseCollector"] = None
+    test_collector: Optional["BaseCollector"] = None
     logger: "TLogger"
     persist_directory: str
     restore_directory: str | None
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index a6679fa20..c2f65cb7a 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -406,19 +406,14 @@ def test_step(self) -> tuple[CollectStats, bool]:
         self.best_reward_std = rew_std
         if self.save_best_fn:
             self.save_best_fn(self.policy)
+        cur_info, best_info = "", ""
         if score != rew:
-            # use custom score calculater
-            log_msg = (
-                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
-                f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
-            )
-        else:
-            log_msg = (
-                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
-                f" best_reward: {self.best_reward:.6f} ± "
-                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
-            )
+            cur_info, best_info = f", score: {score:.6f}", f", best_score: {self.best_score:.6f}"
+        log_msg = (
+            f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}{cur_info},"
+            f" best_reward: {self.best_reward:.6f} ± "
+            f"{self.best_reward_std:.6f}{best_info} in #{self.best_epoch}"
+        )
         log.info(log_msg)
         if self.verbose:
             print(log_msg, flush=True)
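The world.py change above pairs @dataclass(kw_only=True) with collector fields that now default to None. Below is a minimal, self-contained sketch of why kw_only=True is needed for that field layout; WorldSketch is a hypothetical stand-in with placeholder types rather than the real Tianshou classes (Environments, BasePolicy, BaseCollector, TLogger).

# Sketch only: the field layout mirrors World, the types do not.
from dataclasses import dataclass

@dataclass(kw_only=True)
class WorldSketch:
    envs: object
    policy: object
    # optional now, so an evaluation-only experiment can skip building collectors
    train_collector: object | None = None
    test_collector: object | None = None
    # required fields may follow defaulted ones only because kw_only=True
    # removes the positional-ordering constraint of a plain @dataclass
    logger: object
    persist_directory: str
    restore_directory: str | None

# every field must be passed by keyword; the collectors can simply be omitted
world = WorldSketch(
    envs="envs",
    policy="policy",
    logger="logger",
    persist_directory="log/exp",
    restore_directory=None,
)
assert world.train_collector is None and world.test_collector is None

With a plain @dataclass, the defaulted collector fields could not keep their position ahead of the required logger, persist_directory, and restore_directory fields (the class definition would raise a TypeError), so kw_only=True preserves the original field order while letting call sites such as the World(...) construction in create_experiment_world pass everything by keyword.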