New context manager: policy_within_training_step
Adjusted notebooks, log messages, and docs accordingly. Removed the now-obsolete `in_eval_mode` and `in_train_mode` helpers and the private context manager in `Trainer`.
parent 78ea013956
commit e94a5c04cf
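The hunks below touch the changelog, two tutorial notebooks, the tests, the collector, the policy base class, the trainer, and `torch_utils`. As a minimal usage sketch (not itself part of this commit, mirroring the updated tests further down), the module-level replacement works as follows:

import torch.nn as nn

from tianshou.utils.torch_utils import torch_train_mode

module = nn.Linear(3, 4)

module.train()
# previously: with in_eval_mode(module): ...
with torch_train_mode(module, enabled=False):
    assert not module.training
assert module.training  # the original mode is restored on exit

module.eval()
# previously: with in_train_mode(module): ...
with torch_train_mode(module):
    assert module.training
assert not module.training

The trainer-private `_is_within_training_step_enabled` is likewise replaced by the public `policy_within_training_step(policy, enabled=...)` context manager.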
@@ -44,7 +44,7 @@
   the actions concatenated), which is essential for the case where we want
   to reuse the actor's preprocessing network #1128
 - `torch_utils` (new module)
-  - Added contextmanagers `in`
+  - Added context managers `torch_train_mode` and `policy_within_training_step` #1123
 
 ### Fixes
 - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
@@ -18,9 +18,7 @@
 {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {
-  "collapsed": false
- },
+ "metadata": {},
  "outputs": [],
  "source": [
   "# !pip install tianshou gym"
@@ -74,7 +74,8 @@
   ")\n",
   "from tianshou.utils import RunningMeanStd\n",
   "from tianshou.utils.net.common import Net\n",
-  "from tianshou.utils.net.discrete import Actor"
+  "from tianshou.utils.net.discrete import Actor\n",
+  "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
  ]
 },
 {
@@ -644,7 +645,10 @@
  "source": [
   "obs, info = env.reset()\n",
   "for i in range(3, 10):\n",
-  "    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
+  "    # For retrieving actions to be used for training, we set the policy to training mode,\n",
+  "    # but the wrapped torch module should be in eval mode.\n",
+  "    with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n",
+  "        act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
   "    obs_next, rew, _, truncated, info = env.step(act)\n",
   "    # pretend this episode never end\n",
   "    terminated = False\n",
@@ -695,7 +699,11 @@
 },
 "source": [
  "#### Updates\n",
- "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train."
+ "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n",
+ "\n",
+ "However, we need to manually set the torch module to training mode prior to that, \n",
+ "and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n",
+ "but users need to consider it when calling `.update` outside of the trainer."
 ]
},
{
@@ -711,16 +719,11 @@
 "outputs": [],
 "source": [
  "# 0 means sample all data from the buffer\n",
- "policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
- ]
-},
-{
- "cell_type": "markdown",
- "metadata": {
-  "id": "enqlFQLSJrQl"
- },
- "source": [
-  "Not that difficult, right?"
+ "\n",
+ "# For updating the policy, the policy should be in training mode\n",
+ "# and the wrapped torch module should also be in training mode (unlike when collecting data).\n",
+ "with policy_within_training_step(policy), torch_train_mode(policy):\n",
+ "    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()"
 ]
},
{
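For readability, the changed cells of this notebook amount to the following plain Python. This is only a sketch: `policy`, `obs`, and `dummy_buffer` are assumed to be the objects defined earlier in that notebook.

import numpy as np

from tianshou.data import Batch
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

# Retrieving an action that will be used for training: the policy is placed within a
# training step, while the wrapped torch module stays in eval mode.
with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):
    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()

# Updating the policy: both the training-step flag and torch train mode are enabled.
with policy_within_training_step(policy), torch_train_mode(policy):
    policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()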
@@ -54,7 +54,6 @@
 },
 {
  "cell_type": "code",
- "execution_count": null,
  "metadata": {
   "editable": true,
   "id": "do-xZ-8B7nVH",
@@ -64,9 +63,12 @@
   "tags": [
    "hide-cell",
    "remove-output"
-  ]
+  ],
+  "ExecuteTime": {
+   "end_time": "2024-05-06T15:34:02.969675Z",
+   "start_time": "2024-05-06T15:34:00.747309Z"
+  }
  },
- "outputs": [],
  "source": [
   "%%capture\n",
   "\n",
@@ -78,14 +80,20 @@
   "from tianshou.policy import PGPolicy\n",
   "from tianshou.trainer import OnpolicyTrainer\n",
   "from tianshou.utils.net.common import Net\n",
-  "from tianshou.utils.net.discrete import Actor"
- ]
+  "from tianshou.utils.net.discrete import Actor\n",
+  "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
+ ],
+ "outputs": [],
+ "execution_count": 1
 },
 {
  "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+  "ExecuteTime": {
+   "end_time": "2024-05-06T15:34:07.536452Z",
+   "start_time": "2024-05-06T15:34:03.636670Z"
+  }
+ },
  "source": [
   "train_env_num = 4\n",
   "buffer_size = (\n",
@@ -123,7 +131,9 @@
   "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
   "test_collector = Collector(policy, test_envs)\n",
   "train_collector = Collector(policy, train_envs, replayBuffer)"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 2
 },
 {
  "cell_type": "markdown",
@@ -154,11 +164,19 @@
   "\n",
   "n_episode = 10\n",
   "for _i in range(n_episode):\n",
-  "    evaluation_result = test_collector.collect(n_episode=n_episode)\n",
+  "    # for test collector, we set the wrapped torch module to evaluation mode\n",
+  "    # by default, the policy object itself is not within the training step\n",
+  "    with torch_train_mode(policy, enabled=False):\n",
+  "        evaluation_result = test_collector.collect(n_episode=n_episode)\n",
   "    print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
-  "    train_collector.collect(n_step=2000)\n",
-  "    # 0 means taking all data stored in train_collector.buffer\n",
-  "    policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
+  "    # for collecting data for training, the policy object should be within the training step\n",
+  "    # (affecting e.g. whether the policy is stochastic or deterministic)\n",
+  "    with policy_within_training_step(policy):\n",
+  "        train_collector.collect(n_step=2000)\n",
+  "        # 0 means taking all data stored in train_collector.buffer\n",
+  "        # for updating the policy, the wrapped torch module should be in training mode\n",
+  "        with torch_train_mode(policy):\n",
+  "            policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
   "    train_collector.reset_buffer(keep_statistics=True)"
  ]
 },
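Spelled out as plain Python, the new manual training loop in the cell above reads as follows. This is a readability sketch only; `policy`, `test_collector`, and `train_collector` are assumed to be the objects built earlier in that notebook.

from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode

n_episode = 10
for _i in range(n_episode):
    # For test collection, only the wrapped torch module is switched to eval mode;
    # the policy object itself is not within a training step by default.
    with torch_train_mode(policy, enabled=False):
        evaluation_result = test_collector.collect(n_episode=n_episode)
    print(f"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}")
    # Collecting data for training happens within a training step
    # (affecting e.g. whether the policy is stochastic or deterministic).
    with policy_within_training_step(policy):
        train_collector.collect(n_step=2000)
        # sample_size=None means taking all data stored in train_collector.buffer;
        # for the gradient update, the wrapped torch module must be in training mode.
        with torch_train_mode(policy):
            policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)
    train_collector.reset_buffer(keep_statistics=True)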
@@ -6,7 +6,7 @@ from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
 from tianshou.utils.net.common import MLP, Net
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
-from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+from tianshou.utils.torch_utils import torch_train_mode
 
 
 def test_noise() -> None:
@@ -137,7 +137,7 @@ def test_lr_schedulers() -> None:
 def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
-    with in_eval_mode(module):
+    with torch_train_mode(module, False):
         assert not module.training
     assert module.training
 
@@ -145,6 +145,6 @@ def test_in_eval_mode() -> None:
 def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
-    with in_train_mode(module):
+    with torch_train_mode(module):
         assert module.training
     assert not module.training
@@ -27,7 +27,7 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_eval_mode
+from tianshou.utils.torch_utils import torch_train_mode
 
 log = logging.getLogger(__name__)
 
@@ -300,7 +300,7 @@ class BaseCollector(ABC):
        if reset_before_collect:
            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

-        with in_eval_mode(self.policy):  # safety precaution only
+        with torch_train_mode(self.policy, False):
            return self._collect(
                n_step=n_step,
                n_episode=n_episode,
@@ -25,7 +25,7 @@ from tianshou.data.types import (
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_train_mode
+from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode
 
 logger = logging.getLogger(__name__)
 
@@ -532,7 +532,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
            raise RuntimeError(
                f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
                f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
-                f"flag yourself.",
+                f"flag yourself. You can do this, e.g., by using the context manager {policy_within_training_step.__name__}.",
            )

        if buffer is None:
@@ -541,7 +541,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
        batch, indices = buffer.sample(sample_size)
        self.updating = True
        batch = self.process_fn(batch, buffer, indices)
-        with in_train_mode(self):
+        with torch_train_mode(self):
            training_stat = self.learn(batch, **kwargs)
        self.post_process_fn(batch, buffer, indices)
        if self.lr_scheduler is not None:
@@ -2,8 +2,7 @@ import logging
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
-from collections.abc import Callable, Iterator
-from contextlib import contextmanager
+from collections.abc import Callable
 from dataclasses import asdict
 
 import numpy as np
@@ -29,6 +28,7 @@ from tianshou.utils import (
     tqdm_config,
 )
 from tianshou.utils.logging import set_numerical_fields_to_precision
+from tianshou.utils.torch_utils import policy_within_training_step
 
 log = logging.getLogger(__name__)
 
@@ -404,15 +404,6 @@
 
        return test_stat, stop_fn_flag

-    @contextmanager
-    def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
-        old_value = self.policy.is_within_training_step
-        try:
-            self.policy.is_within_training_step = is_within_training_step
-            yield
-        finally:
-            self.policy.is_within_training_step = old_value
-
    def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
        """Perform one training iteration.

@@ -422,7 +413,7 @@
        :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
            If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
        """
-        with self._is_within_training_step_enabled(True):
+        with policy_within_training_step(self.policy):
            should_stop_training = False

            collect_stats: CollectStatsBase | CollectStats
@@ -474,6 +465,7 @@
 
        return collect_stats

+    # TODO (maybe): separate out side effect, simplify name?
    def _update_best_reward_and_return_should_stop_training(
        self,
        collect_stats: CollectStats,
@@ -492,7 +484,7 @@
        should_stop_training = False

        # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics
-        with self._is_within_training_step_enabled(False):
+        with policy_within_training_step(self.policy, enabled=False):
            if (
                collect_stats.n_collected_episodes > 0
                and self.test_in_train
@@ -1,26 +1,39 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import TYPE_CHECKING
 
 from torch import nn
 
+if TYPE_CHECKING:
+    from tianshou.policy import BasePolicy
+
 
 @contextmanager
-def in_eval_mode(module: nn.Module) -> Iterator[None]:
-    """Temporarily switch to evaluation mode."""
-    train = module.training
-    try:
-        module.eval()
-        yield
-    finally:
-        module.train(train)
-
-
-@contextmanager
-def in_train_mode(module: nn.Module) -> Iterator[None]:
-    """Temporarily switch to training mode."""
-    train = module.training
+def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
+    original_mode = module.training
     try:
-        module.train()
+        module.train(enabled)
         yield
     finally:
-        module.train(train)
+        module.train(original_mode)
+
+
+@contextmanager
+def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
+    """Temporarily switch to `policy.is_within_training_step=enabled`.
+
+    Enabling this ensures that the policy is able to adapt its behavior,
+    allowing it to differentiate between training and inference/evaluation,
+    e.g., to sample actions instead of using the most probable action (where applicable).
+    Note that for rollout, which also happens within a training step, one would usually want
+    the wrapped torch module to be in evaluation mode, which can be achieved using
+    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both
+    within training step and in torch train mode.
+    """
+    original_mode = policy.is_within_training_step
+    try:
+        policy.is_within_training_step = enabled
+        yield
+    finally:
+        policy.is_within_training_step = original_mode
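The docstring above explains why the `is_within_training_step` flag exists, but this diff does not show a consumer of it. As a purely hypothetical illustration (the function below is not part of tianshou; only the flag name comes from this commit), a policy's action selection might branch on it like this:

import torch
from torch.distributions import Categorical


def select_action(logits: torch.Tensor, is_within_training_step: bool) -> torch.Tensor:
    """Sample during training-step rollouts, act greedily otherwise (illustrative sketch)."""
    if is_within_training_step:
        # keep exploration stochastic while collecting training data
        return Categorical(logits=logits).sample()
    # inference/evaluation: use the most probable action
    return logits.argmax(dim=-1)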