Adjusted notebooks, log messages, and docs accordingly. Removed the now-obsolete in_eval_mode and the private context manager in Trainer.
40 lines · 1.4 KiB · Python
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from torch import nn

if TYPE_CHECKING:
    from tianshou.policy import BasePolicy


@contextmanager
def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
    original_mode = module.training
    try:
        module.train(enabled)
        yield
    finally:
        module.train(original_mode)
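
A minimal usage sketch (not part of the file above; the toy network is made up purely for illustration): the previous mode is restored on exit, even if the body raises.

net = nn.Sequential(nn.Linear(4, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
net.train()  # start in training mode
with torch_train_mode(net, enabled=False):
    assert not net.training  # BatchNorm now uses running statistics
assert net.training  # original training mode restored on exit
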
@contextmanager
def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `policy.is_within_training_step=enabled`.

    Enabling this ensures that the policy is able to adapt its behavior,
    allowing it to differentiate between training and inference/evaluation,
    e.g., to sample actions instead of using the most probable action (where applicable).

    Note that for rollout, which also happens within a training step, one would usually want
    the wrapped torch module to be in evaluation mode, which can be achieved using
    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both
    within training step and in torch train mode.
    """
    original_mode = policy.is_within_training_step
    try:
        policy.is_within_training_step = enabled
        yield
    finally:
        policy.is_within_training_step = original_mode
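
A hedged sketch of the interplay described in the docstring, assuming `policy` is a tianshou `BasePolicy` (and hence an `nn.Module`) created elsewhere; the ellipses stand in for whatever rollout and update logic the surrounding trainer actually runs:

with policy_within_training_step(policy):
    # rollout within the training step: the policy may sample actions,
    # but the wrapped torch module runs in evaluation mode
    with torch_train_mode(policy, enabled=False):
        ...  # collect transitions from the environment
    # subsequent gradient updates: within training step and in torch train mode
    with torch_train_mode(policy):
        ...  # perform gradient updates on the collected data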