Dominik Jain dae4000cd2 Revert "Depend on sensAI instead of copying its utils (logging, string)"
This reverts commit fdb0eba93d81fa5e698770b4f7088c87fc1238da.
2023-11-08 19:11:39 +01:00

73 lines
2.4 KiB
Python

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.utils.string import ToStringMixin
# Type variable for policy classes; `bound=BasePolicy` restricts it to BasePolicy and its subclasses.
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class TrainingContext:
    """Plain holder for the objects handed to the trainer callbacks in this module.

    :param policy: the policy object; stored unchanged as ``self.policy``
    :param envs: the environments object; stored unchanged as ``self.envs``
    :param logger: the logger object; stored unchanged as ``self.logger``
    """

    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
        # No validation or copying is done; the references are kept as given.
        self.policy, self.envs, self.logger = policy, envs, logger
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Hook to be implemented by subclasses.

        :param epoch: the epoch value supplied by the caller
        :param env_step: the environment step value supplied by the caller
        :param context: the training context (policy, environments, logger)
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        """Bind ``context`` to :meth:`callback`.

        :param context: the training context to pass on every invocation
        :return: a function taking ``(epoch, env_step)`` which forwards both values,
            together with ``context``, to :meth:`callback`
        """

        def bound_callback(epoch: int, env_step: int) -> None:
            # `context` is captured from the enclosing call; the result of callback is passed through.
            return self.callback(epoch, env_step, context)

        return bound_callback
class TrainerEpochCallbackTest(ToStringMixin, ABC):
    """Callback which is called at the beginning of the test phase of each epoch."""

    # NOTE(review): the "test phase" timing is inferred from the class name and from
    # tianshou's `test_fn` trainer hook; the wiring to the trainer is not visible in
    # this module - confirm against the code that consumes `get_trainer_fn`.

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Hook to be implemented by subclasses.

        :param epoch: the epoch value supplied by the caller
        :param env_step: the environment step value supplied by the caller; may be None
        :param context: the training context (policy, environments, logger)
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
        """Bind ``context`` to :meth:`callback`.

        :param context: the training context to pass on every invocation
        :return: a function taking ``(epoch, env_step)`` which forwards both values,
            together with ``context``, to :meth:`callback`
        """

        def fn(epoch: int, env_step: int | None) -> None:
            # `context` is captured from the enclosing call; the result of callback is passed through.
            return self.callback(epoch, env_step, context)

        return fn
class TrainerStopCallback(ToStringMixin, ABC):
    """Callback indicating whether training should stop."""

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decide whether training is finished.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        """Bind ``context`` to :meth:`should_stop`.

        :param context: the training context to pass on every invocation
        :return: a function taking ``mean_rewards`` and returning the result of
            :meth:`should_stop` for that value and ``context``
        """

        def stop_check(mean_rewards: float) -> bool:
            # `context` is captured from the enclosing call.
            return self.should_stop(mean_rewards, context)

        return stop_check
@dataclass
class TrainerCallbacks:
    """Container for callbacks used during training.

    Every field is optional and defaults to None, i.e. no callback of that kind is set.
    """

    # Epoch callback of the "train" kind (see TrainerEpochCallbackTrain).
    epoch_callback_train: TrainerEpochCallbackTrain | None = None
    # Epoch callback of the "test" kind (see TrainerEpochCallbackTest).
    epoch_callback_test: TrainerEpochCallbackTest | None = None
    # Callback deciding whether training should stop (see TrainerStopCallback).
    stop_callback: TrainerStopCallback | None = None