from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    """Test-epoch callback that pins a DQN policy's epsilon to a fixed test-time value.

    Applied before each test epoch so evaluation runs with a constant
    exploration rate instead of the (possibly decayed) training epsilon.
    """

    def __init__(self, eps_test: float):
        # Exploration rate to use during test rollouts.
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        dqn_policy: DQNPolicy = context.policy
        dqn_policy.set_eps(self.eps_test)


class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    """Train-epoch callback implementing the Nature-DQN epsilon schedule.

    Epsilon is interpolated linearly from ``eps_train`` down to
    ``eps_train_final`` over the first 1M environment steps, then held
    constant at ``eps_train_final``.
    """

    def __init__(self, eps_train: float, eps_train_final: float):
        # Initial and final exploration rates of the linear schedule.
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        dqn_policy: DQNPolicy = context.policy
        # Nature DQN schedule: interpolate during the first 1M steps,
        # then hold at the final value.
        if env_step > 1e6:
            eps = self.eps_train_final
        else:
            fraction = env_step / 1e6
            eps = self.eps_train + fraction * (self.eps_train_final - self.eps_train)
        dqn_policy.set_eps(eps)
        # Log the current epsilon periodically (every 1000 env steps).
        if env_step % 1000 == 0:
            context.logger.logger.write("train/env_step", env_step, {"train/eps": eps})