2024-04-26 17:37:12 +02:00
|
|
|
from collections.abc import Iterator
|
2024-04-26 14:45:02 +02:00
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch ``module`` and all of its submodules to evaluation mode.

    On exit, whether normal or via an exception, every submodule's
    ``training`` flag is restored to the value it had on entry. Restoring
    per submodule, rather than only calling ``module.train(previous)``,
    preserves mixed configurations. An example is a model in training mode
    with some layers deliberately frozen in eval mode.

    Args:
        module: The module to put in evaluation mode for the duration of the
            ``with`` block.

    Yields:
        None. The body of the ``with`` block runs with ``module`` in eval mode.
    """
    was_training = module.training
    # ``module.training`` only describes the root, while ``eval()``/``train()``
    # overwrite every descendant recursively, so snapshot each flag.
    saved_modes = [(sub, sub.training) for sub in module.modules()]
    try:
        module.eval()
        yield
    finally:
        # Restore the root through ``train()`` as before, then put back the
        # exact per-submodule flags that the recursive call overwrote.
        module.train(was_training)
        for sub, sub_was_training in saved_modes:
            sub.training = sub_was_training
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch ``module`` and all of its submodules to training mode.

    On exit, whether normal or via an exception, every submodule's
    ``training`` flag is restored to the value it had on entry. Restoring
    per submodule, rather than only calling ``module.train(previous)``,
    preserves mixed configurations. An example is an eval-mode model with
    some layers deliberately left in training mode.

    Args:
        module: The module to put in training mode for the duration of the
            ``with`` block.

    Yields:
        None. The body of the ``with`` block runs with ``module`` in training
        mode.
    """
    was_training = module.training
    # ``module.training`` only describes the root, while ``train()`` overwrites
    # every descendant recursively, so snapshot each flag.
    saved_modes = [(sub, sub.training) for sub in module.modules()]
    try:
        module.train()
        yield
    finally:
        # Restore the root through ``train()`` as before, then put back the
        # exact per-submodule flags that the recursive call overwrote.
        module.train(was_training)
        for sub, sub_was_training in saved_modes:
            sub.training = sub_was_training
|