27 lines
569 B
Python
27 lines
569 B
Python
from collections.abc import Iterator
|
|
from contextlib import contextmanager
|
|
|
|
from torch import nn
|
|
|
|
|
|
@contextmanager
|
|
def in_eval_mode(module: nn.Module) -> Iterator[None]:
|
|
"""Temporarily switch to evaluation mode."""
|
|
train = module.training
|
|
try:
|
|
module.eval()
|
|
yield
|
|
finally:
|
|
module.train(train)
|
|
|
|
|
|
@contextmanager
|
|
def in_train_mode(module: nn.Module) -> Iterator[None]:
|
|
"""Temporarily switch to training mode."""
|
|
train = module.training
|
|
try:
|
|
module.train()
|
|
yield
|
|
finally:
|
|
module.train(train)
|