Added in_eval/in_train mode contextmanager

This commit is contained in:
Michael Panchenko 2024-04-26 14:45:02 +02:00
parent 829fd9c7a5
commit 7d59302095
3 changed files with 42 additions and 7 deletions

View File

@ -1,10 +1,12 @@
import numpy as np
import torch
from torch import nn
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
def test_noise() -> None:
@ -132,9 +134,17 @@ def test_lr_schedulers() -> None:
)
if __name__ == "__main__":
test_noise()
test_moving_average()
test_rms()
test_net()
test_lr_schedulers()
def test_in_eval_mode():
module = nn.Linear(3, 4)
module.train()
with in_eval_mode(module):
assert not module.training
assert module.training
def test_in_train_mode():
module = nn.Linear(3, 4)
module.eval()
with in_train_mode(module):
assert module.training
assert not module.training

View File

@ -27,7 +27,7 @@ def test_episode(
collector.reset(reset_stats=False)
if test_fn:
test_fn(epoch, global_step)
result = collector.collect(n_episode=n_episode, is_eval=True)
result = collector.collect(n_episode=n_episode, eval_mode=True)
if reward_metric: # TODO: move into collector
rew = reward_metric(result.returns)
result.returns = rew

View File

@ -0,0 +1,25 @@
from contextlib import contextmanager
from torch import nn
@contextmanager
def in_eval_mode(module: nn.Module) -> None:
"""Temporarily switch to evaluation mode."""
train = module.training
try:
module.eval()
yield
finally:
module.train(train)
@contextmanager
def in_train_mode(module: nn.Module) -> None:
"""Temporarily switch to training mode."""
train = module.training
try:
module.train()
yield
finally:
module.train(train)