* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improve module factory interface design, achieving higher generality
from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    """Sets the policy's epsilon to a fixed exploration value for test epochs."""

    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        policy.set_eps(self.eps_test)


class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    """Decays epsilon linearly from eps_train to eps_train_final over the first
    1M environment steps, following the Nature DQN schedule."""

    def __init__(self, eps_train: float, eps_train_final: float):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        logger = context.logger.logger
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
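
For concreteness, here is a standalone sanity check of the decay arithmetic implemented above (no Tianshou dependency). The eps_train and eps_train_final values are illustrative assumptions, not taken from the example:

# Standalone check of the linear decay schedule above; values are illustrative.
eps_train, eps_train_final = 1.0, 0.05  # assumed for demonstration only
for env_step in (0, 250_000, 500_000, 1_000_000, 2_000_000):
    if env_step <= 1e6:
        eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
    else:
        eps = eps_train_final
    print(f"{env_step:>9} -> eps = {eps:.4f}")
# prints eps = 1.0000, 0.7625, 0.5250, 0.0500, 0.0500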
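
A minimal sketch of how these callbacks might be registered with the high-level experiment builder mentioned in the commit notes. The builder import path and the with_trainer_epoch_callback_* methods are assumptions chosen to match the callback class names, not confirmed API; consult tianshou.highlevel for the actual interface:

# Hypothetical wiring sketch: the import path and the with_* registration
# methods below are assumptions, not confirmed tianshou.highlevel API.
from tianshou.highlevel.experiment import DQNExperimentBuilder  # assumed path

builder = (
    DQNExperimentBuilder(...)  # env/sampling configuration elided
    .with_trainer_epoch_callback_train(
        TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train=1.0, eps_train_final=0.05)
    )
    .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test=0.005))
)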