2023-10-05 15:39:32 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from collections.abc import Callable
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import TypeVar
|
|
|
|
|
|
|
|
from tianshou.highlevel.env import Environments
|
2023-10-12 17:40:16 +02:00
|
|
|
from tianshou.highlevel.logger import TLogger
|
2023-10-05 15:39:32 +02:00
|
|
|
from tianshou.policy import BasePolicy
|
|
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
|
|
|
|
# Type variable for policy objects; its upper bound is BasePolicy, so it only
# matches BasePolicy or subclasses of it.
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
|
|
|
|
|
|
|
|
|
|
|
class TrainingContext:
    """Bundle of objects handed to the trainer callbacks as their ``context`` argument.

    The policy, the environments and the logger are stored as plain attributes.
    """

    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
        """Store the given objects unchanged.

        :param policy: the policy, made available as ``self.policy``
        :param envs: the environments, made available as ``self.envs``
        :param logger: the logger, made available as ``self.logger``
        """
        self.policy, self.envs, self.logger = policy, envs, logger
|
|
|
|
|
|
|
|
|
2023-10-09 17:22:52 +02:00
|
|
|
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch.

    Subclasses implement :meth:`callback`. :meth:`get_trainer_fn` binds a
    :class:`TrainingContext` to it and returns a two-argument function.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Perform the action of this callback.

        :param epoch: the ``epoch`` value received by the trainer function
        :param env_step: the ``env_step`` value received by the trainer function
        :param context: the training context bound via :meth:`get_trainer_fn`
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        """Create a function of ``(epoch, env_step)`` which forwards to :meth:`callback`.

        :param context: the context passed along on every :meth:`callback` invocation
        :return: a closure which calls ``self.callback(epoch, env_step, context)``
        """

        def trainer_fn(epoch: int, env_step: int) -> None:
            # The context is captured from the enclosing scope; the caller only
            # supplies the two positional values.
            return self.callback(epoch, env_step, context)

        return trainer_fn
|
|
|
|
|
|
|
|
|
|
|
|
class TrainerEpochCallbackTest(ToStringMixin, ABC):
    """Callback which is called at the beginning of the test phase of each epoch.

    Subclasses implement :meth:`callback`. :meth:`get_trainer_fn` binds a
    :class:`TrainingContext` to it and returns a two-argument function.
    Unlike the train-phase callback, ``env_step`` may be None here.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Perform the action of this callback.

        :param epoch: the ``epoch`` value received by the trainer function
        :param env_step: the ``env_step`` value received by the trainer function;
            may be None, so implementations must handle that case
        :param context: the training context bound via :meth:`get_trainer_fn`
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
        """Create a function of ``(epoch, env_step)`` which forwards to :meth:`callback`.

        :param context: the context passed along on every :meth:`callback` invocation
        :return: a closure which calls ``self.callback(epoch, env_step, context)``
        """

        def trainer_fn(epoch: int, env_step: int | None) -> None:
            # The context is captured from the enclosing scope; the caller only
            # supplies the two positional values.
            return self.callback(epoch, env_step, context)

        return trainer_fn
|
|
|
|
|
|
|
|
|
|
|
|
class TrainerStopCallback(ToStringMixin, ABC):
    """Callback indicating whether training should stop.

    Subclasses implement :meth:`should_stop`. :meth:`get_trainer_fn` binds a
    :class:`TrainingContext` to it and returns a one-argument predicate.
    """

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decide whether training should stop.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context bound via :meth:`get_trainer_fn`
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        """Create a function of ``mean_rewards`` which forwards to :meth:`should_stop`.

        :param context: the context passed along on every :meth:`should_stop` invocation
        :return: a closure which calls ``self.should_stop(mean_rewards, context)``
        """

        def stop_fn(mean_rewards: float) -> bool:
            # The context is captured from the enclosing scope.
            return self.should_stop(mean_rewards, context)

        return stop_fn
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainerCallbacks:
    """Container for the optional trainer callbacks.

    Every field defaults to None, meaning that no callback of that kind is set.
    """

    # Optional TrainerEpochCallbackTrain instance.
    epoch_callback_train: TrainerEpochCallbackTrain | None = None
    # Optional TrainerEpochCallbackTest instance.
    epoch_callback_test: TrainerEpochCallbackTest | None = None
    # Optional TrainerStopCallback instance.
    stop_callback: TrainerStopCallback | None = None
|