New context manager: policy_within_training_step

Adjusted notebooks, log messages and docs accordingly. Removed the now
obsolete in_eval_mode and the private context manager in the Trainer.
This commit is contained in:
Michael Panchenko 2024-05-06 16:50:48 +02:00
parent 78ea013956
commit e94a5c04cf
9 changed files with 90 additions and 66 deletions

View File

@ -44,7 +44,7 @@
the actions concatenated), which is essential for the case where we want
to reuse the actor's preprocessing network #1128
- `torch_utils` (new module)
- Added contextmanagers `in`
- Added context managers `torch_train_mode` and `policy_within_training_step` #1123
### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,

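The changelog entry above covers the new API; as a rough usage sketch (assuming an existing `policy` instance and collected data), the two context managers are typically combined as follows:

from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

# rollout within a training step: training-step semantics on, torch module in eval mode
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    ...  # e.g., collect transitions with a Collector

# gradient update: training-step semantics on and torch module in train mode
with policy_within_training_step(policy), torch_train_mode(policy):
    ...  # e.g., policy.update(...)
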
View File

@ -18,9 +18,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# !pip install tianshou gym"

View File

@ -74,7 +74,8 @@
")\n",
"from tianshou.utils import RunningMeanStd\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
]
},
{
@ -644,7 +645,10 @@
"source": [
"obs, info = env.reset()\n",
"for i in range(3, 10):\n",
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
" # For retrieving actions to be used for training, we set the policy to training mode,\n",
" # but the wrapped torch module should be in eval mode.\n",
" with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n",
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
" obs_next, rew, _, truncated, info = env.step(act)\n",
" # pretend this episode never end\n",
" terminated = False\n",
@ -695,7 +699,11 @@
},
"source": [
"#### Updates\n",
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train."
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n",
"\n",
"However, we need to manually set the torch module to training mode prior to that, \n",
"and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n",
"but users need to consider it when calling `.update` outside of the trainer."
]
},
{
@ -711,16 +719,11 @@
"outputs": [],
"source": [
"# 0 means sample all data from the buffer\n",
"policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "enqlFQLSJrQl"
},
"source": [
"Not that difficult, right?"
"\n",
"# For updating the policy, the policy should be in training mode\n",
"# and the wrapped torch module should also be in training mode (unlike when collecting data).\n",
"with policy_within_training_step(policy), torch_train_mode(policy):\n",
" policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
]
},
{

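A hedged sketch of the flag states implied by the two cells above (`policy` and `dummy_buffer` are the notebook's objects; the assertions are illustrative):

# data collection for training: within the training step, torch module in eval mode
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    assert policy.is_within_training_step
    assert not policy.training  # the nn.Module training flag stays off

# gradient update: both the training-step flag and torch train mode are enabled
with policy_within_training_step(policy), torch_train_mode(policy):
    assert policy.is_within_training_step
    assert policy.training
    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6)
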
View File

@ -54,7 +54,6 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"id": "do-xZ-8B7nVH",
@ -64,9 +63,12 @@
"tags": [
"hide-cell",
"remove-output"
]
],
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
}
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
@ -78,14 +80,20 @@
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
]
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:07.536452Z",
"start_time": "2024-05-06T15:34:03.636670Z"
}
},
"source": [
"train_env_num = 4\n",
"buffer_size = (\n",
@ -123,7 +131,9 @@
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
]
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
@ -154,11 +164,19 @@
"\n",
"n_episode = 10\n",
"for _i in range(n_episode):\n",
" evaluation_result = test_collector.collect(n_episode=n_episode)\n",
" # for test collector, we set the wrapped torch module to evaluation mode\n",
" # by default, the policy object itself is not within the training step\n",
" with torch_train_mode(policy, enabled=False):\n",
" evaluation_result = test_collector.collect(n_episode=n_episode)\n",
" print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
" # for collecting data for training, the policy object should be within the training step\n",
" # (affecting e.g. whether the policy is stochastic or deterministic)\n",
" with policy_within_training_step(policy):\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" # for updating the policy, the wrapped torch module should be in training mode\n",
" with torch_train_mode(policy):\n",
" policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
" train_collector.reset_buffer(keep_statistics=True)"
]
},

View File

@ -6,7 +6,7 @@ from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
from tianshou.utils.torch_utils import torch_train_mode
def test_noise() -> None:
@ -137,7 +137,7 @@ def test_lr_schedulers() -> None:
def test_in_eval_mode() -> None:
module = nn.Linear(3, 4)
module.train()
with in_eval_mode(module):
with torch_train_mode(module, False):
assert not module.training
assert module.training
@ -145,6 +145,6 @@ def test_in_eval_mode() -> None:
def test_in_train_mode() -> None:
module = nn.Linear(3, 4)
module.eval()
with in_train_mode(module):
with torch_train_mode(module):
assert module.training
assert not module.training

View File

@ -27,7 +27,7 @@ from tianshou.data.types import (
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.torch_utils import in_eval_mode
from tianshou.utils.torch_utils import torch_train_mode
log = logging.getLogger(__name__)
@ -300,7 +300,7 @@ class BaseCollector(ABC):
if reset_before_collect:
self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
with in_eval_mode(self.policy): # safety precaution only
with torch_train_mode(self.policy, False):
return self._collect(
n_step=n_step,
n_episode=n_episode,

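Wrapping `_collect` in `torch_train_mode(self.policy, False)` puts modules such as Dropout and BatchNorm into their deterministic inference behavior for the duration of the rollout and restores the previous mode afterwards; a small self-contained illustration (the network below is illustrative, not part of the collector):

import torch
from torch import nn

from tianshou.utils.torch_utils import torch_train_mode

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
net.train()  # e.g., left in train mode by a preceding gradient update

with torch_train_mode(net, enabled=False):
    assert not net.training  # dropout is disabled during the rollout
    net(torch.zeros(1, 4))

assert net.training  # the original mode is restored on exit
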
View File

@ -25,7 +25,7 @@ from tianshou.data.types import (
)
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.torch_utils import in_train_mode
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode
logger = logging.getLogger(__name__)
@ -532,7 +532,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
raise RuntimeError(
f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
f"flag yourself.",
f"flag yourself. You can to this e.g., by using the contextmanager {policy_within_training_step.__name__}.",
)
if buffer is None:
@ -541,7 +541,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
batch, indices = buffer.sample(sample_size)
self.updating = True
batch = self.process_fn(batch, buffer, indices)
with in_train_mode(self):
with torch_train_mode(self):
training_stat = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indices)
if self.lr_scheduler is not None:

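For users calling `update` without a Trainer, the guard above means the call has to be wrapped explicitly; a hedged sketch (the buffer and extra arguments are placeholders):

from tianshou.utils.torch_utils import policy_within_training_step

# policy.update(sample_size=0, buffer=buffer)  # would raise the RuntimeError above

with policy_within_training_step(policy):
    # update() already wraps learn() in torch_train_mode(self), so only the
    # training-step flag has to be managed by the caller
    stats = policy.update(sample_size=0, buffer=buffer)  # plus e.g. batch_size/repeat for on-policy algorithms
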
View File

@ -2,8 +2,7 @@ import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from collections.abc import Callable
from dataclasses import asdict
import numpy as np
@ -29,6 +28,7 @@ from tianshou.utils import (
tqdm_config,
)
from tianshou.utils.logging import set_numerical_fields_to_precision
from tianshou.utils.torch_utils import policy_within_training_step
log = logging.getLogger(__name__)
@ -404,15 +404,6 @@ class BaseTrainer(ABC):
return test_stat, stop_fn_flag
@contextmanager
def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
old_value = self.policy.is_within_training_step
try:
self.policy.is_within_training_step = is_within_training_step
yield
finally:
self.policy.is_within_training_step = old_value
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
"""Perform one training iteration.
@ -422,7 +413,7 @@ class BaseTrainer(ABC):
:return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
"""
with self._is_within_training_step_enabled(True):
with policy_within_training_step(self.policy):
should_stop_training = False
collect_stats: CollectStatsBase | CollectStats
@ -474,6 +465,7 @@ class BaseTrainer(ABC):
return collect_stats
# TODO (maybe): separate out side effect, simplify name?
def _update_best_reward_and_return_should_stop_training(
self,
collect_stats: CollectStats,
@ -492,7 +484,7 @@ class BaseTrainer(ABC):
should_stop_training = False
# Because we need to evaluate the policy, we temporarily leave the "is_within_training_step" semantics
with self._is_within_training_step_enabled(False):
with policy_within_training_step(self.policy, enabled=False):
if (
collect_stats.n_collected_episodes > 0
and self.test_in_train

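The removed private helper `_is_within_training_step_enabled` is replaced one-to-one by the public context manager; the test-in-train evaluation above relies on the nesting behavior sketched here (names are illustrative):

from tianshou.utils.torch_utils import policy_within_training_step

with policy_within_training_step(policy):
    ...  # collect training data
    # temporarily leave the training-step semantics to evaluate the policy
    with policy_within_training_step(policy, enabled=False):
        ...  # evaluation rollout
    assert policy.is_within_training_step  # restored when the nested block exits
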
View File

@ -1,26 +1,39 @@
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
from torch import nn
@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
"""Temporarily switch to evaluation mode."""
train = module.training
try:
module.eval()
yield
finally:
module.train(train)
if TYPE_CHECKING:
from tianshou.policy import BasePolicy
@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
"""Temporarily switch to training mode."""
train = module.training
def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
"""Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
original_mode = module.training
try:
module.train()
module.train(enabled)
yield
finally:
module.train(train)
module.train(original_mode)
@contextmanager
def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
"""Temporarily switch to `policy.is_within_training_step=enabled`.
Enabling this allows the policy to adapt its behavior, differentiating between training
and inference/evaluation, e.g., by sampling actions instead of using the most probable
action (where applicable).
Note that for rollout, which also happens within a training step, one would usually want
the wrapped torch module to be in evaluation mode, which can be achieved using
`with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should
both be within the training step and have its torch module in train mode.
"""
original_mode = policy.is_within_training_step
try:
policy.is_within_training_step = enabled
yield
finally:
policy.is_within_training_step = original_mode
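
Both context managers use try/finally, so the previous state is restored even if the wrapped block raises; a minimal check for `torch_train_mode` (the same pattern applies to `policy_within_training_step`):

from torch import nn

from tianshou.utils.torch_utils import torch_train_mode

module = nn.Linear(3, 4)
module.eval()
try:
    with torch_train_mode(module):
        assert module.training
        raise RuntimeError("failure in the middle of an update")
except RuntimeError:
    pass
assert not module.training  # eval mode was restored despite the exception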