From e94a5c04cf93085c3df18e2d623f7ba465d20489 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 6 May 2024 16:50:48 +0200 Subject: [PATCH] New context manager: policy_within_training_step Adjusted notebooks, log messages and docs accordingly. Removed now obsolete in_eval_mode and the private context manager in Trainer --- CHANGELOG.md | 2 +- docs/02_notebooks/L0_overview.ipynb | 4 +-- docs/02_notebooks/L4_Policy.ipynb | 29 ++++++++++--------- docs/02_notebooks/L6_Trainer.ipynb | 44 ++++++++++++++++++++--------- test/base/test_utils.py | 6 ++-- tianshou/data/collector.py | 4 +-- tianshou/policy/base.py | 6 ++-- tianshou/trainer/base.py | 18 ++++-------- tianshou/utils/torch_utils.py | 43 ++++++++++++++++++---------- 9 files changed, 90 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d74e990..807a9da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,7 +44,7 @@ the actions concatenated), which is essential for the case where we want to reuse the actor's preprocessing network #1128 - `torch_utils` (new module) - - Added contextmanagers `in` + - Added context managers `torch_train_mode` and `policy_within_training_step` #1123 ### Fixes - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics, diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 59d6fd2..0ce6df1 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -18,9 +18,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "# !pip install tianshou gym" diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 00f7f27..eed8ea3 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -74,7 +74,8 @@ ")\n", "from tianshou.utils import RunningMeanStd\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" ] }, { @@ -644,7 +645,10 @@ "source": [ "obs, info = env.reset()\n", "for i in range(3, 10):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", + " # For retrieving actions to be used for training, we set the policy to training mode,\n", + " # but the wrapped torch module should be in eval mode.\n", + " with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n", + " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", " obs_next, rew, _, truncated, info = env.step(act)\n", " # pretend this episode never end\n", " terminated = False\n", @@ -695,7 +699,11 @@ }, "source": [ "#### Updates\n", - "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train." + "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n", + "\n", + "However, we need to manually set the torch module to training mode prior to that, \n", + "and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n", + "but users need to consider it when calling `.update` outside of the trainer." 
] }, { @@ -711,16 +719,11 @@ "outputs": [], "source": [ "# 0 means sample all data from the buffer\n", - "policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "enqlFQLSJrQl" - }, - "source": [ - "Not that difficult, right?" + "\n", + "# For updating the policy, the policy should be in training mode\n", + "# and the wrapped torch module should also be in training mode (unlike when collecting data).\n", + "with policy_within_training_step(policy), torch_train_mode(policy):\n", + " policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" ] }, { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index 75aea47..d5423bd 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -54,7 +54,6 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": { "editable": true, "id": "do-xZ-8B7nVH", @@ -64,9 +63,12 @@ "tags": [ "hide-cell", "remove-output" - ] + ], + "ExecuteTime": { + "end_time": "2024-05-06T15:34:02.969675Z", + "start_time": "2024-05-06T15:34:00.747309Z" + } }, - "outputs": [], "source": [ "%%capture\n", "\n", @@ -78,14 +80,20 @@ "from tianshou.policy import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" - ] + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-05-06T15:34:07.536452Z", + "start_time": "2024-05-06T15:34:03.636670Z" + } + }, "source": [ "train_env_num = 4\n", "buffer_size = (\n", @@ -123,7 +131,9 @@ "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", "test_collector = Collector(policy, test_envs)\n", "train_collector = Collector(policy, train_envs, replayBuffer)" - ] + ], + "outputs": [], + "execution_count": 2 }, { "cell_type": "markdown", @@ -154,11 +164,19 @@ "\n", "n_episode = 10\n", "for _i in range(n_episode):\n", - " evaluation_result = test_collector.collect(n_episode=n_episode)\n", + " # for test collector, we set the wrapped torch module to evaluation mode\n", + " # by default, the policy object itself is not within the training step\n", + " with torch_train_mode(policy, enabled=False):\n", + " evaluation_result = test_collector.collect(n_episode=n_episode)\n", " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", + " # for collecting data for training, the policy object should be within the training step\n", + " # (affecting e.g. 
whether the policy is stochastic or deterministic)\n", + " with policy_within_training_step(policy):\n", + " train_collector.collect(n_step=2000)\n", + " # 0 means taking all data stored in train_collector.buffer\n", + " # for updating the policy, the wrapped torch module should be in training mode\n", + " with torch_train_mode(policy):\n", + " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", " train_collector.reset_buffer(keep_statistics=True)" ] }, diff --git a/test/base/test_utils.py b/test/base/test_utils.py index ac3b2fa..f8e5938 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -6,7 +6,7 @@ from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic -from tianshou.utils.torch_utils import in_eval_mode, in_train_mode +from tianshou.utils.torch_utils import torch_train_mode def test_noise() -> None: @@ -137,7 +137,7 @@ def test_lr_schedulers() -> None: def test_in_eval_mode() -> None: module = nn.Linear(3, 4) module.train() - with in_eval_mode(module): + with torch_train_mode(module, False): assert not module.training assert module.training @@ -145,6 +145,6 @@ def test_in_eval_mode() -> None: def test_in_train_mode() -> None: module = nn.Linear(3, 4) module.eval() - with in_train_mode(module): + with torch_train_mode(module): assert module.training assert not module.training diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 5bce6c0..6773a63 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -27,7 +27,7 @@ from tianshou.data.types import ( from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin -from tianshou.utils.torch_utils import in_eval_mode +from tianshou.utils.torch_utils import torch_train_mode log = logging.getLogger(__name__) @@ -300,7 +300,7 @@ class BaseCollector(ABC): if reset_before_collect: self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - with in_eval_mode(self.policy): # safety precaution only + with torch_train_mode(self.policy, False): return self._collect( n_step=n_step, n_episode=n_episode, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index bee9f9b..b7ae5f2 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -25,7 +25,7 @@ from tianshou.data.types import ( ) from tianshou.utils import MultipleLRSchedulers from tianshou.utils.print import DataclassPPrintMixin -from tianshou.utils.torch_utils import in_train_mode +from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode logger = logging.getLogger(__name__) @@ -532,7 +532,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): raise RuntimeError( f"update() was called outside of a training step as signalled by {self.is_within_training_step=} " f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " - f"flag yourself.", + f"flag yourself. 
You can do this, e.g., by using the context manager {policy_within_training_step.__name__}.",
             )
 
         if buffer is None:
@@ -541,7 +541,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        with in_train_mode(self):
+        with torch_train_mode(self):
             training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index b738df5..242f2b0 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -2,8 +2,7 @@ import logging
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
-from collections.abc import Callable, Iterator
-from contextlib import contextmanager
+from collections.abc import Callable
 from dataclasses import asdict
 
 import numpy as np
@@ -29,6 +28,7 @@ from tianshou.utils import (
     tqdm_config,
 )
 from tianshou.utils.logging import set_numerical_fields_to_precision
+from tianshou.utils.torch_utils import policy_within_training_step
 
 log = logging.getLogger(__name__)
 
@@ -404,15 +404,6 @@ class BaseTrainer(ABC):
 
         return test_stat, stop_fn_flag
 
-    @contextmanager
-    def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
-        old_value = self.policy.is_within_training_step
-        try:
-            self.policy.is_within_training_step = is_within_training_step
-            yield
-        finally:
-            self.policy.is_within_training_step = old_value
-
     def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
         """Perform one training iteration.
 
@@ -422,7 +413,7 @@ class BaseTrainer(ABC):
         :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
             If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
         """
-        with self._is_within_training_step_enabled(True):
+        with policy_within_training_step(self.policy):
            should_stop_training = False
 
            collect_stats: CollectStatsBase | CollectStats
@@ -474,6 +465,7 @@ class BaseTrainer(ABC):
 
         return collect_stats
 
+    # TODO (maybe): separate out side effect, simplify name?
    def _update_best_reward_and_return_should_stop_training(
         self,
         collect_stats: CollectStats,
@@ -492,7 +484,7 @@
         should_stop_training = False
 
         # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics
-        with self._is_within_training_step_enabled(False):
+        with policy_within_training_step(self.policy, enabled=False):
             if (
                 collect_stats.n_collected_episodes > 0
                 and self.test_in_train
diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py
index 2fb70da..430d174 100644
--- a/tianshou/utils/torch_utils.py
+++ b/tianshou/utils/torch_utils.py
@@ -1,26 +1,39 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import TYPE_CHECKING
 
 from torch import nn
 
-
-@contextmanager
-def in_eval_mode(module: nn.Module) -> Iterator[None]:
-    """Temporarily switch to evaluation mode."""
-    train = module.training
-    try:
-        module.eval()
-        yield
-    finally:
-        module.train(train)
+if TYPE_CHECKING:
+    from tianshou.policy import BasePolicy
 
 
 @contextmanager
-def in_train_mode(module: nn.Module) -> Iterator[None]:
-    """Temporarily switch to training mode."""
-    train = module.training
+def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
+    original_mode = module.training
     try:
-        module.train()
+        module.train(enabled)
         yield
     finally:
-        module.train(train)
+        module.train(original_mode)
+
+
+@contextmanager
+def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `policy.is_within_training_step=enabled`.
+
+    Enabling this ensures that the policy is able to adapt its behavior,
+    allowing it to differentiate between training and inference/evaluation,
+    e.g., to sample actions instead of using the most probable action (where applicable).
+    Note that for rollout, which also happens within a training step, one would usually want
+    the wrapped torch module to be in evaluation mode, which can be achieved using
+    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both
+    within a training step and in torch train mode.
+    """
+    original_mode = policy.is_within_training_step
+    try:
+        policy.is_within_training_step = enabled
+        yield
+    finally:
+        policy.is_within_training_step = original_mode
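
A minimal usage sketch of the two new context managers outside the Trainer, mirroring the notebook
changes above. It assumes `policy`, `train_collector`, and `test_collector` objects already set up
as in the L6_Trainer notebook; the collection and update parameters are the illustrative values
used there, not recommended defaults.

    from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

    # Evaluation rollout: put the wrapped torch module into eval mode. By default the policy
    # object is not within a training step, so it behaves as it would at inference time.
    with torch_train_mode(policy, enabled=False):
        eval_stats = test_collector.collect(n_episode=10)

    # Collecting training data: flag the policy as being within a training step (affecting e.g.
    # whether actions are sampled stochastically, where applicable). Collector.collect() itself
    # switches the wrapped torch module to eval mode, as per the collector.py change above.
    with policy_within_training_step(policy):
        train_collector.collect(n_step=2000)

    # Gradient update outside a Trainer: the training-step flag must be set, otherwise
    # BasePolicy.update() raises the RuntimeError introduced above; the notebooks additionally
    # put the wrapped torch module into train mode for the update.
    with policy_within_training_step(policy), torch_train_mode(policy):
        policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)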