Tianshou/tianshou/utils/torch_utils.py

26 lines
512 B
Python
Raw Normal View History

from contextlib import contextmanager
from torch import nn
@contextmanager
def in_eval_mode(module: nn.Module) -> None:
"""Temporarily switch to evaluation mode."""
train = module.training
try:
module.eval()
yield
finally:
module.train(train)
@contextmanager
def in_train_mode(module: nn.Module) -> None:
"""Temporarily switch to training mode."""
train = module.training
try:
module.train()
yield
finally:
module.train(train)