Tianshou/tianshou/utils/torch_utils.py

27 lines
569 B
Python

from collections.abc import Iterator
from contextlib import contextmanager
from torch import nn
@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch ``module`` (and all of its submodules) to evaluation mode.

    On exit, including when the body of the ``with`` statement raises, the
    ``training`` flag of every submodule is restored to the value it had on
    entry. Restoring only the root's flag via ``module.train(flag)`` is not
    enough: that call propagates one flag to all submodules. It would
    overwrite submodules that were deliberately in a different mode, e.g. a
    frozen normalization layer kept in eval mode inside a training model.

    :param module: the module to put into evaluation mode for the duration
        of the context.
    :return: a context manager that yields ``None``.
    """
    was_training = module.training
    # Snapshot each submodule's own flag, not just the root's, so that mixed
    # train/eval configurations survive the round trip.
    saved_modes = [(sub, sub.training) for sub in module.modules()]
    try:
        module.eval()
        yield
    finally:
        # Make the same call the previous implementation made. This keeps
        # custom ``train`` overrides in effect and covers submodules that were
        # added while inside the context.
        module.train(was_training)
        # Then reinstate the exact per-submodule flags captured on entry.
        for sub, sub_was_training in saved_modes:
            sub.training = sub_was_training
@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch ``module`` (and all of its submodules) to training mode.

    On exit, including when the body of the ``with`` statement raises, the
    ``training`` flag of every submodule is restored to the value it had on
    entry. Restoring only the root's flag via ``module.train(flag)`` is not
    enough: that call propagates one flag to all submodules. It would
    overwrite submodules that were deliberately in a different mode from the
    root.

    :param module: the module to put into training mode for the duration of
        the context.
    :return: a context manager that yields ``None``.
    """
    was_training = module.training
    # Snapshot each submodule's own flag, not just the root's, so that mixed
    # train/eval configurations survive the round trip.
    saved_modes = [(sub, sub.training) for sub in module.modules()]
    try:
        module.train()
        yield
    finally:
        # Make the same call the previous implementation made. This keeps
        # custom ``train`` overrides in effect and covers submodules that were
        # added while inside the context.
        module.train(was_training)
        # Then reinstate the exact per-submodule flags captured on entry.
        for sub, sub_was_training in saved_modes:
            sub.training = sub_was_training