diff --git a/test/base/test_utils.py b/test/base/test_utils.py index bd14ffe..1d8e37c 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,10 +1,12 @@ import numpy as np import torch +from torch import nn from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic +from tianshou.utils.torch_utils import in_eval_mode, in_train_mode def test_noise() -> None: @@ -132,9 +134,17 @@ def test_lr_schedulers() -> None: ) -if __name__ == "__main__": - test_noise() - test_moving_average() - test_rms() - test_net() - test_lr_schedulers() +def test_in_eval_mode(): + module = nn.Linear(3, 4) + module.train() + with in_eval_mode(module): + assert not module.training + assert module.training + + +def test_in_train_mode(): + module = nn.Linear(3, 4) + module.eval() + with in_train_mode(module): + assert module.training + assert not module.training \ No newline at end of file diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 4d990db..b790fdd 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -27,7 +27,7 @@ def test_episode( collector.reset(reset_stats=False) if test_fn: test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode, is_eval=True) + result = collector.collect(n_episode=n_episode, eval_mode=True) if reward_metric: # TODO: move into collector rew = reward_metric(result.returns) result.returns = rew diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py new file mode 100644 index 0000000..1676b52 --- /dev/null +++ b/tianshou/utils/torch_utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager + +from torch import nn + + +@contextmanager +def in_eval_mode(module: nn.Module) -> None: + """Temporarily switch to evaluation mode.""" + train = module.training + try: + module.eval() + yield + finally: + module.train(train) + + +@contextmanager +def in_train_mode(module: nn.Module) -> None: + """Temporarily switch to training mode.""" + train = module.training + try: + module.train() + yield + finally: + module.train(train)