Added in_eval/in_train mode contextmanager
This commit is contained in:
parent
829fd9c7a5
commit
7d59302095
@ -1,10 +1,12 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.exploration import GaussianNoise, OUNoise
|
from tianshou.exploration import GaussianNoise, OUNoise
|
||||||
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
|
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
|
||||||
from tianshou.utils.net.common import MLP, Net
|
from tianshou.utils.net.common import MLP, Net
|
||||||
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
||||||
|
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
|
||||||
|
|
||||||
|
|
||||||
def test_noise() -> None:
|
def test_noise() -> None:
|
||||||
@ -132,9 +134,17 @@ def test_lr_schedulers() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_in_eval_mode():
|
||||||
test_noise()
|
module = nn.Linear(3, 4)
|
||||||
test_moving_average()
|
module.train()
|
||||||
test_rms()
|
with in_eval_mode(module):
|
||||||
test_net()
|
assert not module.training
|
||||||
test_lr_schedulers()
|
assert module.training
|
||||||
|
|
||||||
|
|
||||||
|
def test_in_train_mode():
|
||||||
|
module = nn.Linear(3, 4)
|
||||||
|
module.eval()
|
||||||
|
with in_train_mode(module):
|
||||||
|
assert module.training
|
||||||
|
assert not module.training
|
@ -27,7 +27,7 @@ def test_episode(
|
|||||||
collector.reset(reset_stats=False)
|
collector.reset(reset_stats=False)
|
||||||
if test_fn:
|
if test_fn:
|
||||||
test_fn(epoch, global_step)
|
test_fn(epoch, global_step)
|
||||||
result = collector.collect(n_episode=n_episode, is_eval=True)
|
result = collector.collect(n_episode=n_episode, eval_mode=True)
|
||||||
if reward_metric: # TODO: move into collector
|
if reward_metric: # TODO: move into collector
|
||||||
rew = reward_metric(result.returns)
|
rew = reward_metric(result.returns)
|
||||||
result.returns = rew
|
result.returns = rew
|
||||||
|
25
tianshou/utils/torch_utils.py
Normal file
25
tianshou/utils/torch_utils.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def in_eval_mode(module: nn.Module) -> None:
|
||||||
|
"""Temporarily switch to evaluation mode."""
|
||||||
|
train = module.training
|
||||||
|
try:
|
||||||
|
module.eval()
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
module.train(train)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def in_train_mode(module: nn.Module) -> None:
|
||||||
|
"""Temporarily switch to training mode."""
|
||||||
|
train = module.training
|
||||||
|
try:
|
||||||
|
module.train()
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
module.train(train)
|
Loading…
x
Reference in New Issue
Block a user