Add DQN support in high-level API
* Allow to specify trainer callbacks (train_fn, test_fn, stop_fn) in high-level API, adding the necessary abstractions and pass-on mechanisms * Add example atari_dqn_hl
This commit is contained in:
parent
358978c65d
commit
1cba589bd4
128
examples/atari/atari_dqn_hl.py
Normal file
128
examples/atari/atari_dqn_hl.py
Normal file
@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
||||
from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_network import (
|
||||
CriticFactoryAtariDQN,
|
||||
FeatureNetFactoryDQN,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
DQNExperimentBuilder,
|
||||
RLExperimentConfig,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_params import DQNParams
|
||||
from tianshou.highlevel.params.policy_wrapper import (
|
||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||
)
|
||||
from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
experiment_config: RLExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
eps_test: float = 0.005,
|
||||
eps_train: float = 1.0,
|
||||
eps_train_final: float = 0.05,
|
||||
buffer_size: int = 100000,
|
||||
lr: float = 0.0001,
|
||||
gamma: float = 0.99,
|
||||
n_step: int = 3,
|
||||
target_update_freq: int = 500,
|
||||
epoch: int = 100,
|
||||
step_per_epoch: int = 100000,
|
||||
step_per_collect: int = 10,
|
||||
update_per_step: float = 0.1,
|
||||
batch_size: int = 32,
|
||||
training_num: int = 10,
|
||||
test_num: int = 10,
|
||||
frames_stack: int = 4,
|
||||
save_buffer_name: str | None = None, # TODO support?
|
||||
icm_lr_scale: float = 0.0,
|
||||
icm_reward_scale: float = 0.01,
|
||||
icm_forward_loss_weight: float = 0.2,
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||
|
||||
sampling_config = RLSamplingConfig(
|
||||
num_epochs=epoch,
|
||||
step_per_epoch=step_per_epoch,
|
||||
batch_size=batch_size,
|
||||
num_train_envs=training_num,
|
||||
num_test_envs=test_num,
|
||||
buffer_size=buffer_size,
|
||||
step_per_collect=step_per_collect,
|
||||
update_per_step=update_per_step,
|
||||
repeat_per_collect=None,
|
||||
replay_buffer_stack_num=frames_stack,
|
||||
replay_buffer_ignore_obs_next=True,
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(
|
||||
task,
|
||||
experiment_config.seed,
|
||||
sampling_config,
|
||||
frames_stack,
|
||||
scale=scale_obs,
|
||||
)
|
||||
|
||||
class TrainEpochCallback(TrainerEpochCallback):
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
policy: DQNPolicy = context.policy
|
||||
logger = context.logger.logger
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
if env_step <= 1e6:
|
||||
eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
|
||||
else:
|
||||
eps = eps_train_final
|
||||
policy.set_eps(eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
class TestEpochCallback(TrainerEpochCallback):
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
policy: DQNPolicy = context.policy
|
||||
policy.set_eps(eps_test)
|
||||
|
||||
builder = (
|
||||
DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
.with_dqn_params(
|
||||
DQNParams(
|
||||
discount_factor=gamma,
|
||||
estimation_step=n_step,
|
||||
lr=lr,
|
||||
target_update_freq=target_update_freq,
|
||||
),
|
||||
)
|
||||
.with_critic_factory(CriticFactoryAtariDQN())
|
||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
FeatureNetFactoryDQN(),
|
||||
[512],
|
||||
lr,
|
||||
icm_lr_scale,
|
||||
icm_reward_scale,
|
||||
icm_forward_loss_weight,
|
||||
),
|
||||
)
|
||||
|
||||
experiment = builder.build()
|
||||
experiment.run(log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.run_main(lambda: CLI(main))
|
@ -8,6 +8,7 @@ from torch import nn
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.actor import ActorFactory
|
||||
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
|
||||
from tianshou.highlevel.module.critic import CriticFactory
|
||||
from tianshou.utils.net.common import BaseActor
|
||||
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
||||
|
||||
@ -226,6 +227,21 @@ class QRDQN(DQN):
|
||||
return obs, state
|
||||
|
||||
|
||||
class CriticFactoryAtariDQN(CriticFactory):
|
||||
def create_module(
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
use_action: bool,
|
||||
) -> torch.nn.Module:
|
||||
assert use_action
|
||||
return DQN(
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
device=device,
|
||||
).to(device)
|
||||
|
||||
|
||||
class ActorFactoryAtariDQN(ActorFactory):
|
||||
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -10,7 +10,7 @@ from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
FeatureNetFactoryDQN,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
PPOExperimentBuilder,
|
||||
@ -98,6 +98,7 @@ def main(
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
|
||||
.with_critic_factory_use_actor()
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
@ -11,6 +11,7 @@ import numpy as np
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||
|
||||
try:
|
||||
import envpool
|
||||
@ -374,11 +375,19 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
|
||||
|
||||
class AtariEnvFactory(EnvFactory):
|
||||
def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
seed: int,
|
||||
sampling_config: RLSamplingConfig,
|
||||
frame_stack: int,
|
||||
scale: int = 0,
|
||||
):
|
||||
self.task = task
|
||||
self.sampling_config = sampling_config
|
||||
self.seed = seed
|
||||
self.frame_stack = frame_stack
|
||||
self.scale = scale
|
||||
|
||||
def create_envs(self, config=None) -> DiscreteEnvironments:
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
@ -386,7 +395,20 @@ class AtariEnvFactory(EnvFactory):
|
||||
seed=self.seed,
|
||||
training_num=self.sampling_config.num_train_envs,
|
||||
test_num=self.sampling_config.num_test_envs,
|
||||
scale=0,
|
||||
scale=self.scale,
|
||||
frame_stack=self.frame_stack,
|
||||
)
|
||||
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
|
||||
|
||||
class AtariStopCallback(TrainerStopCallback):
|
||||
def __init__(self, task: str):
|
||||
self.task = task
|
||||
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
env = context.envs.env
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
if "Pong" in self.task:
|
||||
return mean_rewards >= 20
|
||||
return False
|
||||
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import gymnasium
|
||||
import torch
|
||||
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
@ -23,6 +24,7 @@ from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.highlevel.params.policy_params import (
|
||||
A2CParams,
|
||||
DDPGParams,
|
||||
DQNParams,
|
||||
Params,
|
||||
ParamTransformerData,
|
||||
PPOParams,
|
||||
@ -30,10 +32,12 @@ from tianshou.highlevel.params.policy_params import (
|
||||
TD3Params,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||
from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
|
||||
from tianshou.policy import (
|
||||
A2CPolicy,
|
||||
BasePolicy,
|
||||
DDPGPolicy,
|
||||
DQNPolicy,
|
||||
PPOPolicy,
|
||||
SACPolicy,
|
||||
TD3Policy,
|
||||
@ -54,6 +58,7 @@ class AgentFactory(ABC, ToStringMixin):
|
||||
self.sampling_config = sampling_config
|
||||
self.optim_factory = optim_factory
|
||||
self.policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||
self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||
|
||||
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
|
||||
buffer_size = self.sampling_config.buffer_size
|
||||
@ -85,6 +90,9 @@ class AgentFactory(ABC, ToStringMixin):
|
||||
) -> None:
|
||||
self.policy_wrapper_factory = policy_wrapper_factory
|
||||
|
||||
def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
|
||||
self.trainer_callbacks = callbacks
|
||||
|
||||
@abstractmethod
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
pass
|
||||
@ -145,6 +153,21 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
logger: Logger,
|
||||
) -> OnpolicyTrainer:
|
||||
sampling_config = self.sampling_config
|
||||
callbacks = self.trainer_callbacks
|
||||
context = TrainingContext(policy, envs, logger)
|
||||
train_fn = (
|
||||
callbacks.epoch_callback_train.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_train
|
||||
else None
|
||||
)
|
||||
test_fn = (
|
||||
callbacks.epoch_callback_test.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_test
|
||||
else None
|
||||
)
|
||||
stop_fn = (
|
||||
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
|
||||
)
|
||||
return OnpolicyTrainer(
|
||||
policy=policy,
|
||||
train_collector=train_collector,
|
||||
@ -158,6 +181,9 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
save_best_fn=self._create_save_best_fn(envs, logger.log_path),
|
||||
logger=logger.logger,
|
||||
test_in_train=False,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
)
|
||||
|
||||
|
||||
@ -171,6 +197,21 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
logger: Logger,
|
||||
) -> OffpolicyTrainer:
|
||||
sampling_config = self.sampling_config
|
||||
callbacks = self.trainer_callbacks
|
||||
context = TrainingContext(policy, envs, logger)
|
||||
train_fn = (
|
||||
callbacks.epoch_callback_train.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_train
|
||||
else None
|
||||
)
|
||||
test_fn = (
|
||||
callbacks.epoch_callback_test.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_test
|
||||
else None
|
||||
)
|
||||
stop_fn = (
|
||||
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
|
||||
)
|
||||
return OffpolicyTrainer(
|
||||
policy=policy,
|
||||
train_collector=train_collector,
|
||||
@ -184,6 +225,9 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
logger=logger.logger,
|
||||
update_per_step=sampling_config.update_per_step,
|
||||
test_in_train=False,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
)
|
||||
|
||||
|
||||
@ -195,6 +239,23 @@ class _ActorMixin:
|
||||
return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
|
||||
|
||||
|
||||
class _CriticMixin:
|
||||
def __init__(
|
||||
self,
|
||||
critic_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
critic_use_action: bool,
|
||||
):
|
||||
self.critic_module_opt_factory = CriticModuleOptFactory(
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
critic_use_action,
|
||||
)
|
||||
|
||||
def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||
return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
|
||||
|
||||
|
||||
class _ActorCriticMixin:
|
||||
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
||||
|
||||
@ -241,7 +302,7 @@ class _ActorCriticMixin:
|
||||
return ActorCriticModuleOpt(actor_critic, optim)
|
||||
|
||||
|
||||
class _ActorAndCriticMixin(_ActorMixin):
|
||||
class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
|
||||
def __init__(
|
||||
self,
|
||||
actor_factory: ActorFactory,
|
||||
@ -249,15 +310,8 @@ class _ActorAndCriticMixin(_ActorMixin):
|
||||
optim_factory: OptimizerFactory,
|
||||
critic_use_action: bool,
|
||||
):
|
||||
super().__init__(actor_factory, optim_factory)
|
||||
self.critic_module_opt_factory = CriticModuleOptFactory(
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
critic_use_action,
|
||||
)
|
||||
|
||||
def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||
return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
|
||||
_ActorMixin.__init__(self, actor_factory, optim_factory)
|
||||
_CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)
|
||||
|
||||
|
||||
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
|
||||
@ -385,6 +439,42 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
|
||||
|
||||
class DQNAgentFactory(OffpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: DQNParams,
|
||||
sampling_config: RLSamplingConfig,
|
||||
critic_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
self.params = params
|
||||
self.critic_factory = critic_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
critic = self.critic_factory.create_module(envs, device, use_action=True)
|
||||
optim = self.optim_factory.create_optimizer(critic, self.params.lr)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
device=device,
|
||||
optim=optim,
|
||||
optim_factory=self.optim_factory,
|
||||
),
|
||||
)
|
||||
envs.get_type().assert_discrete(self)
|
||||
# noinspection PyTypeChecker
|
||||
action_space: gymnasium.spaces.Discrete = envs.get_action_space()
|
||||
return DQNPolicy(
|
||||
model=critic,
|
||||
optim=optim,
|
||||
action_space=action_space,
|
||||
observation_space=envs.get_observation_space(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -14,10 +14,14 @@ class RLSamplingConfig:
|
||||
buffer_size: int = 4096
|
||||
step_per_collect: int = 2048
|
||||
repeat_per_collect: int | None = 10
|
||||
update_per_step: int = 1
|
||||
update_per_step: float = 1.0
|
||||
"""
|
||||
Only used in off-policy algorithms.
|
||||
How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
|
||||
"""
|
||||
start_timesteps: int = 0
|
||||
start_timesteps_random: bool = False
|
||||
# TODO can we set the parameters below more intelligently? Perhaps based on env. representation?
|
||||
# TODO can we set the parameters below intelligently? Perhaps based on env. representation?
|
||||
replay_buffer_ignore_obs_next: bool = False
|
||||
replay_buffer_save_only_last_obs: bool = False
|
||||
replay_buffer_stack_num: int = 1
|
||||
|
@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
|
||||
A2CAgentFactory,
|
||||
AgentFactory,
|
||||
DDPGAgentFactory,
|
||||
DQNAgentFactory,
|
||||
PPOAgentFactory,
|
||||
SACAgentFactory,
|
||||
TD3AgentFactory,
|
||||
@ -30,12 +31,18 @@ from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||
from tianshou.highlevel.params.policy_params import (
|
||||
A2CParams,
|
||||
DDPGParams,
|
||||
DQNParams,
|
||||
PPOParams,
|
||||
SACParams,
|
||||
TD3Params,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
from tianshou.highlevel.trainer import (
|
||||
TrainerCallbacks,
|
||||
TrainerEpochCallback,
|
||||
TrainerStopCallback,
|
||||
)
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.trainer import BaseTrainer
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
@ -160,6 +167,7 @@ class RLExperimentBuilder:
|
||||
self._optim_factory: OptimizerFactory | None = None
|
||||
self._env_config: PersistableConfigProtocol | None = None
|
||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||
|
||||
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
|
||||
self._env_config = config
|
||||
@ -193,6 +201,18 @@ class RLExperimentBuilder:
|
||||
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
return self
|
||||
|
||||
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
|
||||
self._trainer_callbacks.epoch_callback_train = callback
|
||||
return self
|
||||
|
||||
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
|
||||
self._trainer_callbacks.epoch_callback_test = callback
|
||||
return self
|
||||
|
||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||
self._trainer_callbacks.stop_callback = callback
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def _create_agent_factory(self) -> AgentFactory:
|
||||
pass
|
||||
@ -205,6 +225,7 @@ class RLExperimentBuilder:
|
||||
|
||||
def build(self) -> RLExperiment:
|
||||
agent_factory = self._create_agent_factory()
|
||||
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
||||
if self._policy_wrapper_factory:
|
||||
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
||||
experiment = RLExperiment(
|
||||
@ -302,21 +323,24 @@ class _BuilderMixinCriticsFactory:
|
||||
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||
def __init__(self):
|
||||
super().__init__(1)
|
||||
self._critic_use_actor_module = False
|
||||
|
||||
def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
||||
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||
self._with_critic_factory(0, critic_factory)
|
||||
return self
|
||||
|
||||
def with_critic_factory_default(
|
||||
self: TBuilder,
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
||||
) -> Self:
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
return self
|
||||
|
||||
|
||||
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._critic_use_actor_module = False
|
||||
|
||||
def with_critic_factory_use_actor(self) -> Self:
|
||||
"""Makes the critic use the same network as the actor."""
|
||||
self._critic_use_actor_module = True
|
||||
@ -372,7 +396,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
class A2CExperimentBuilder(
|
||||
RLExperimentBuilder,
|
||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||
_BuilderMixinSingleCriticFactory,
|
||||
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@ -383,7 +407,7 @@ class A2CExperimentBuilder(
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||
self._params: A2CParams = A2CParams()
|
||||
self._env_config = env_config
|
||||
|
||||
@ -406,7 +430,7 @@ class A2CExperimentBuilder(
|
||||
class PPOExperimentBuilder(
|
||||
RLExperimentBuilder,
|
||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||
_BuilderMixinSingleCriticFactory,
|
||||
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@ -416,7 +440,7 @@ class PPOExperimentBuilder(
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||
self._params: PPOParams = PPOParams()
|
||||
|
||||
def with_ppo_params(self, params: PPOParams) -> Self:
|
||||
@ -435,9 +459,8 @@ class PPOExperimentBuilder(
|
||||
)
|
||||
|
||||
|
||||
class DDPGExperimentBuilder(
|
||||
class DQNExperimentBuilder(
|
||||
RLExperimentBuilder,
|
||||
_BuilderMixinActorFactory_ContinuousDeterministic,
|
||||
_BuilderMixinSingleCriticFactory,
|
||||
):
|
||||
def __init__(
|
||||
@ -445,13 +468,40 @@ class DDPGExperimentBuilder(
|
||||
experiment_config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
sampling_config: RLSamplingConfig,
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
self._params: DQNParams = DQNParams()
|
||||
|
||||
def with_dqn_params(self, params: DQNParams) -> Self:
|
||||
self._params = params
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def _create_agent_factory(self) -> AgentFactory:
|
||||
return DQNAgentFactory(
|
||||
self._params,
|
||||
self._sampling_config,
|
||||
self._get_critic_factory(0),
|
||||
self._get_optim_factory(),
|
||||
)
|
||||
|
||||
|
||||
class DDPGExperimentBuilder(
|
||||
RLExperimentBuilder,
|
||||
_BuilderMixinActorFactory_ContinuousDeterministic,
|
||||
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
sampling_config: RLSamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||
self._params: DDPGParams = DDPGParams()
|
||||
self._env_config = env_config
|
||||
|
||||
def with_ddpg_params(self, params: DDPGParams) -> Self:
|
||||
self._params = params
|
||||
|
@ -36,7 +36,9 @@ class ModuleFactory(ToStringMixin, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class ModuleFactoryNet(ModuleFactory):
|
||||
class ModuleFactoryNet(
|
||||
ModuleFactory,
|
||||
): # TODO This is unused and broken; use it in ActorFactory* and so on?
|
||||
def __init__(self, hidden_sizes: int | Sequence[int]):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
|
@ -356,6 +356,21 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
|
||||
return transformers
|
||||
|
||||
|
||||
@dataclass
|
||||
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
|
||||
discount_factor: float = 0.99
|
||||
estimation_step: int = 1
|
||||
target_update_freq: int = 0
|
||||
reward_normalization: bool = False
|
||||
is_double: bool = True
|
||||
clip_loss_grad: bool = False
|
||||
|
||||
def _get_param_transformers(self):
|
||||
transformers = super()._get_param_transformers()
|
||||
transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
|
||||
return transformers
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPGParams(Params, ParamsMixinActorAndCritic):
|
||||
tau: float = 0.005
|
||||
|
55
tianshou/highlevel/trainer.py
Normal file
55
tianshou/highlevel/trainer.py
Normal file
@ -0,0 +1,55 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.logger import Logger
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||
|
||||
|
||||
class TrainingContext:
|
||||
def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
|
||||
self.policy = policy
|
||||
self.envs = envs
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class TrainerEpochCallback(ToStringMixin, ABC):
|
||||
"""Callback which is called at the beginning of each epoch."""
|
||||
|
||||
@abstractmethod
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
pass
|
||||
|
||||
def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
|
||||
def fn(epoch, env_step):
|
||||
return self.callback(epoch, env_step, context)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
class TrainerStopCallback(ToStringMixin, ABC):
|
||||
"""Callback indicating whether training should stop."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
""":param mean_rewards: the average undiscounted returns of the testing result
|
||||
:return: True if the goal has been reached and training should stop, False otherwise
|
||||
"""
|
||||
|
||||
def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
|
||||
def fn(mean_rewards: float):
|
||||
return self.should_stop(mean_rewards, context)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerCallbacks:
|
||||
epoch_callback_train: TrainerEpochCallback | None = None
|
||||
epoch_callback_test: TrainerEpochCallback | None = None
|
||||
stop_callback: TrainerStopCallback | None = None
|
Loading…
x
Reference in New Issue
Block a user