New context manager: policy_within_training_step

Adjusted notebooks, log messages, and docs accordingly. Removed the now-obsolete
in_eval_mode and the private context manager in the Trainer.
Michael Panchenko 2024-05-06 16:50:48 +02:00
parent 78ea013956
commit e94a5c04cf
9 changed files with 90 additions and 66 deletions
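
Taken together, the replacement maps roughly as follows. This is a minimal sketch grounded in the updated tests and notebooks below, not code taken from the commit; `module` is a stand-in, and the commented-out block assumes some constructed tianshou `BasePolicy` instance named `policy`.

```python
from torch import nn

from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

module = nn.Linear(3, 4)

# previously `with in_eval_mode(module):`
with torch_train_mode(module, enabled=False):
    assert not module.training

# previously `with in_train_mode(module):`
with torch_train_mode(module):
    assert module.training

# previously the Trainer used a private `_is_within_training_step_enabled` context manager;
# `policy` would be a constructed BasePolicy instance (not defined in this sketch):
# with policy_within_training_step(policy):
#     ...  # collect training data / call policy.update(...)
```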

View File

@@ -44,7 +44,7 @@
   the actions concatenated), which is essential for the case where we want
   to reuse the actor's preprocessing network #1128
 - `torch_utils` (new module)
-- Added contextmanagers `in
+- Added context managers `torch_train_mode` and `policy_within_training_step` #1123

 ### Fixes
 - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,

View File

@@ -18,9 +18,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {
-"collapsed": false
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "# !pip install tianshou gym"

View File

@@ -74,7 +74,8 @@
 ")\n",
 "from tianshou.utils import RunningMeanStd\n",
 "from tianshou.utils.net.common import Net\n",
-"from tianshou.utils.net.discrete import Actor"
+"from tianshou.utils.net.discrete import Actor\n",
+"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
 ]
 },
 {
@@ -644,6 +645,9 @@
 "source": [
 "obs, info = env.reset()\n",
 "for i in range(3, 10):\n",
-"    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
+"    # For retrieving actions to be used for training, we set the policy to training mode,\n",
+"    # but the wrapped torch module should be in eval mode.\n",
+"    with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n",
+"        act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
 "    obs_next, rew, _, truncated, info = env.step(act)\n",
 "    # pretend this episode never end\n",
@@ -695,7 +699,11 @@
 },
 "source": [
 "#### Updates\n",
-"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train."
+"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n",
+"\n",
+"However, we need to manually set the torch module to training mode prior to that, \n",
+"and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n",
+"but users need to consider it when calling `.update` outside of the trainer."
 ]
 },
 {
@@ -711,16 +719,11 @@
 "outputs": [],
 "source": [
 "# 0 means sample all data from the buffer\n",
-"policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "enqlFQLSJrQl"
-},
-"source": [
-"Not that difficult, right?"
+"\n",
+"# For updating the policy, the policy should be in training mode\n",
+"# and the wrapped torch module should also be in training mode (unlike when collecting data).\n",
+"with policy_within_training_step(policy), torch_train_mode(policy):\n",
+"    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
 ]
 },
 {

View File

@@ -54,7 +54,6 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
 "editable": true,
 "id": "do-xZ-8B7nVH",
@@ -64,9 +63,12 @@
 "tags": [
 "hide-cell",
 "remove-output"
-]
+],
+"ExecuteTime": {
+"end_time": "2024-05-06T15:34:02.969675Z",
+"start_time": "2024-05-06T15:34:00.747309Z"
+}
 },
-"outputs": [],
 "source": [
 "%%capture\n",
 "\n",
@@ -78,14 +80,20 @@
 "from tianshou.policy import PGPolicy\n",
 "from tianshou.trainer import OnpolicyTrainer\n",
 "from tianshou.utils.net.common import Net\n",
-"from tianshou.utils.net.discrete import Actor"
-]
+"from tianshou.utils.net.discrete import Actor\n",
+"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
+],
+"outputs": [],
+"execution_count": 1
 },
 {
 "cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
+"metadata": {
+"ExecuteTime": {
+"end_time": "2024-05-06T15:34:07.536452Z",
+"start_time": "2024-05-06T15:34:03.636670Z"
+}
+},
 "source": [
 "train_env_num = 4\n",
 "buffer_size = (\n",
@@ -123,7 +131,9 @@
 "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
 "test_collector = Collector(policy, test_envs)\n",
 "train_collector = Collector(policy, train_envs, replayBuffer)"
-]
+],
+"outputs": [],
+"execution_count": 2
 },
 {
 "cell_type": "markdown",
@@ -154,10 +164,18 @@
 "\n",
 "n_episode = 10\n",
 "for _i in range(n_episode):\n",
+"    # for test collector, we set the wrapped torch module to evaluation mode\n",
+"    # by default, the policy object itself is not within the training step\n",
+"    with torch_train_mode(policy, enabled=False):\n",
 "        evaluation_result = test_collector.collect(n_episode=n_episode)\n",
 "        print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
+"    # for collecting data for training, the policy object should be within the training step\n",
+"    # (affecting e.g. whether the policy is stochastic or deterministic)\n",
+"    with policy_within_training_step(policy):\n",
 "        train_collector.collect(n_step=2000)\n",
 "        # 0 means taking all data stored in train_collector.buffer\n",
+"        # for updating the policy, the wrapped torch module should be in training mode\n",
+"        with torch_train_mode(policy):\n",
 "            policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
 "        train_collector.reset_buffer(keep_statistics=True)"
 ]

View File

@@ -6,7 +6,7 @@ from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
 from tianshou.utils.net.common import MLP, Net
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
-from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+from tianshou.utils.torch_utils import torch_train_mode


 def test_noise() -> None:
@@ -137,7 +137,7 @@ def test_lr_schedulers() -> None:
 def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
-    with in_eval_mode(module):
+    with torch_train_mode(module, False):
         assert not module.training
     assert module.training

@@ -145,6 +145,6 @@ def test_in_eval_mode() -> None:
 def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
-    with in_train_mode(module):
+    with torch_train_mode(module):
         assert module.training
     assert not module.training

View File

@@ -27,7 +27,7 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_eval_mode
+from tianshou.utils.torch_utils import torch_train_mode


 log = logging.getLogger(__name__)
@@ -300,7 +300,7 @@ class BaseCollector(ABC):
         if reset_before_collect:
             self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

-        with in_eval_mode(self.policy):  # safety precaution only
+        with torch_train_mode(self.policy, False):
             return self._collect(
                 n_step=n_step,
                 n_episode=n_episode,

View File

@@ -25,7 +25,7 @@ from tianshou.data.types import (
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_train_mode
+from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode


 logger = logging.getLogger(__name__)
@@ -532,7 +532,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
             raise RuntimeError(
                 f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
                 f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
-                f"flag yourself.",
+                f"flag yourself. You can do this, e.g., by using the contextmanager {policy_within_training_step.__name__}.",
             )

         if buffer is None:
@@ -541,7 +541,7 @@
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        with in_train_mode(self):
+        with torch_train_mode(self):
             training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
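
To illustrate the guard above: calling `update()` outside of a training step now raises a `RuntimeError`, and wrapping the call in the new context manager is the suggested fix. A hedged sketch, assuming `policy` is a constructed `BasePolicy` and `dummy_buffer` a filled replay buffer as in the notebook diff above:

```python
from tianshou.utils.torch_utils import policy_within_training_step

# is_within_training_step is False by default outside of a Trainer, so this refuses to run
try:
    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6)
except RuntimeError as err:
    print(err)

# declaring the training step makes update() proceed; the switch of the torch module
# into train mode for learn() is handled inside update() via torch_train_mode(self)
with policy_within_training_step(policy):
    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6)
```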

View File

@@ -2,8 +2,7 @@ import logging
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
-from collections.abc import Callable, Iterator
-from contextlib import contextmanager
+from collections.abc import Callable
 from dataclasses import asdict

 import numpy as np
@@ -29,6 +28,7 @@ from tianshou.utils import (
     tqdm_config,
 )
 from tianshou.utils.logging import set_numerical_fields_to_precision
+from tianshou.utils.torch_utils import policy_within_training_step


 log = logging.getLogger(__name__)
@@ -404,15 +404,6 @@
         return test_stat, stop_fn_flag

-    @contextmanager
-    def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
-        old_value = self.policy.is_within_training_step
-        try:
-            self.policy.is_within_training_step = is_within_training_step
-            yield
-        finally:
-            self.policy.is_within_training_step = old_value
-
     def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
         """Perform one training iteration.
@@ -422,7 +413,7 @@
         :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
             If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
         """
-        with self._is_within_training_step_enabled(True):
+        with policy_within_training_step(self.policy):
             should_stop_training = False
             collect_stats: CollectStatsBase | CollectStats
@@ -474,6 +465,7 @@
         return collect_stats

+    # TODO (maybe): separate out side effect, simplify name?
     def _update_best_reward_and_return_should_stop_training(
         self,
         collect_stats: CollectStats,
@@ -492,7 +484,7 @@
         should_stop_training = False

         # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics
-        with self._is_within_training_step_enabled(False):
+        with policy_within_training_step(self.policy, enabled=False):
             if (
                 collect_stats.n_collected_episodes > 0
                 and self.test_in_train

View File

@@ -1,26 +1,39 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import TYPE_CHECKING

 from torch import nn

+if TYPE_CHECKING:
+    from tianshou.policy import BasePolicy
+

 @contextmanager
-def in_eval_mode(module: nn.Module) -> Iterator[None]:
-    """Temporarily switch to evaluation mode."""
-    train = module.training
-    try:
-        module.eval()
-        yield
-    finally:
-        module.train(train)
-
-
-@contextmanager
-def in_train_mode(module: nn.Module) -> Iterator[None]:
-    """Temporarily switch to training mode."""
-    train = module.training
+def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
+    original_mode = module.training
     try:
-        module.train()
+        module.train(enabled)
         yield
     finally:
-        module.train(train)
+        module.train(original_mode)
+
+
+@contextmanager
+def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `policy.is_within_training_step=enabled`.
+
+    Enabling this ensures that the policy is able to adapt its behavior,
+    allowing it to differentiate between training and inference/evaluation,
+    e.g., to sample actions instead of using the most probable action (where applicable).
+    Note that for rollout, which also happens within a training step, one would usually want
+    the wrapped torch module to be in evaluation mode, which can be achieved using
+    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both
+    within training step and in torch train mode.
+    """
+    original_mode = policy.is_within_training_step
+    try:
+        policy.is_within_training_step = enabled
+        yield
+    finally:
+        policy.is_within_training_step = original_mode
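
Putting the pieces together, the intended usage outside of a Trainer looks roughly like this; a sketch assuming `policy`, `test_collector`, and `train_collector` have been set up as in the notebook cells above:

```python
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

# evaluation: not within a training step, and the wrapped torch module in eval mode
with torch_train_mode(policy, enabled=False):
    test_collector.collect(n_episode=10)

# rollout for training: within a training step, but the torch module stays in eval mode
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    train_collector.collect(n_step=2000)

# gradient update: within a training step and with the torch module in train mode
with policy_within_training_step(policy), torch_train_mode(policy):
    policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)
```

When a Trainer drives the loop, these switches are applied automatically: the trainer wraps `training_step()` in `policy_within_training_step`, the collectors enter `torch_train_mode(policy, False)` internally, and `update()` wraps `learn()` in `torch_train_mode(self)`, as the diffs above show.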