Added in_eval/in_train mode contextmanager
This commit is contained in:
parent
829fd9c7a5
commit
7d59302095
@ -1,10 +1,12 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tianshou.exploration import GaussianNoise, OUNoise
|
||||
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
|
||||
from tianshou.utils.net.common import MLP, Net
|
||||
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
||||
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
|
||||
|
||||
|
||||
def test_noise() -> None:
|
||||
@ -132,9 +134,17 @@ def test_lr_schedulers() -> None:
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_noise()
|
||||
test_moving_average()
|
||||
test_rms()
|
||||
test_net()
|
||||
test_lr_schedulers()
|
||||
def test_in_eval_mode():
|
||||
module = nn.Linear(3, 4)
|
||||
module.train()
|
||||
with in_eval_mode(module):
|
||||
assert not module.training
|
||||
assert module.training
|
||||
|
||||
|
||||
def test_in_train_mode():
|
||||
module = nn.Linear(3, 4)
|
||||
module.eval()
|
||||
with in_train_mode(module):
|
||||
assert module.training
|
||||
assert not module.training
|
@ -27,7 +27,7 @@ def test_episode(
|
||||
collector.reset(reset_stats=False)
|
||||
if test_fn:
|
||||
test_fn(epoch, global_step)
|
||||
result = collector.collect(n_episode=n_episode, is_eval=True)
|
||||
result = collector.collect(n_episode=n_episode, eval_mode=True)
|
||||
if reward_metric: # TODO: move into collector
|
||||
rew = reward_metric(result.returns)
|
||||
result.returns = rew
|
||||
|
25
tianshou/utils/torch_utils.py
Normal file
25
tianshou/utils/torch_utils.py
Normal file
@ -0,0 +1,25 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_eval_mode(module: nn.Module) -> None:
|
||||
"""Temporarily switch to evaluation mode."""
|
||||
train = module.training
|
||||
try:
|
||||
module.eval()
|
||||
yield
|
||||
finally:
|
||||
module.train(train)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_train_mode(module: nn.Module) -> None:
|
||||
"""Temporarily switch to training mode."""
|
||||
train = module.training
|
||||
try:
|
||||
module.train()
|
||||
yield
|
||||
finally:
|
||||
module.train(train)
|
Loading…
x
Reference in New Issue
Block a user