Add DQN support in high-level API

* Allow specifying trainer callbacks (train_fn, test_fn, stop_fn)
  in the high-level API, adding the necessary abstractions and
  pass-on mechanisms (see the usage sketch below)
* Add example atari_dqn_hl
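
  A condensed sketch of the new callback hooks, using the builder methods and callback
  types introduced in this commit. The class names SetTrainEps and StopOnReward, the
  epsilon value, the reward threshold, the log name, and the pre-built experiment_config,
  env_factory and sampling_config objects are illustrative only; see the full
  atari_dqn_hl example below for the complete setup.

      from examples.atari.atari_network import CriticFactoryAtariDQN
      from tianshou.highlevel.experiment import DQNExperimentBuilder
      from tianshou.highlevel.params.policy_params import DQNParams
      from tianshou.highlevel.trainer import (
          TrainerEpochCallback,
          TrainerStopCallback,
          TrainingContext,
      )

      class SetTrainEps(TrainerEpochCallback):
          def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
              # exploration rate applied to the policy while collecting training data
              context.policy.set_eps(0.05)

      class StopOnReward(TrainerStopCallback):
          def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
              # stop training once the mean test reward reaches the target
              return mean_rewards >= 20

      experiment = (
          DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
          .with_dqn_params(DQNParams(discount_factor=0.99, lr=1e-4))
          .with_critic_factory(CriticFactoryAtariDQN())
          .with_trainer_epoch_callback_train(SetTrainEps())
          .with_trainer_stop_callback(StopOnReward())
          .build()
      )
      experiment.run("dqn-callbacks-demo")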
Dominik Jain 2023-10-05 15:39:32 +02:00
parent 358978c65d
commit 1cba589bd4
10 changed files with 414 additions and 31 deletions

View File

@@ -0,0 +1,128 @@
#!/usr/bin/env python3

import datetime
import os

from jsonargparse import CLI

from examples.atari.atari_network import (
    CriticFactoryAtariDQN,
    FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    DQNExperimentBuilder,
    RLExperimentConfig,
)
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
    PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
from tianshou.policy import DQNPolicy
from tianshou.utils import logging


def main(
    experiment_config: RLExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "dqn", str(experiment_config.seed), now)
    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )
    env_factory = AtariEnvFactory(
        task,
        experiment_config.seed,
        sampling_config,
        frames_stack,
        scale=scale_obs,
    )

    class TrainEpochCallback(TrainerEpochCallback):
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
            logger = context.logger.logger
            # nature DQN setting, linear decay in the first 1M steps
            if env_step <= 1e6:
                eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
            else:
                eps = eps_train_final
            policy.set_eps(eps)
            if env_step % 1000 == 0:
                logger.write("train/env_step", env_step, {"train/eps": eps})

    class TestEpochCallback(TrainerEpochCallback):
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
            policy.set_eps(eps_test)

    builder = (
        DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_dqn_params(
            DQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                target_update_freq=target_update_freq,
            ),
        )
        .with_critic_factory(CriticFactoryAtariDQN())
        .with_trainer_epoch_callback_train(TrainEpochCallback())
        .with_trainer_epoch_callback_test(TestEpochCallback())
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
                FeatureNetFactoryDQN(),
                [512],
                lr,
                icm_lr_scale,
                icm_reward_scale,
                icm_forward_loss_weight,
            ),
        )
    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))

View File

@@ -8,6 +8,7 @@ from torch import nn
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.utils.net.common import BaseActor
from tianshou.utils.net.discrete import Actor, NoisyLinear
@@ -226,6 +227,21 @@ class QRDQN(DQN):
        return obs, state


class CriticFactoryAtariDQN(CriticFactory):
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
    ) -> torch.nn.Module:
        assert use_action
        return DQN(
            *envs.get_observation_shape(),
            envs.get_action_shape(),
            device=device,
        ).to(device)


class ActorFactoryAtariDQN(ActorFactory):
    def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
        self.hidden_size = hidden_size

View File

@@ -10,7 +10,7 @@ from examples.atari.atari_network import (
    ActorFactoryAtariDQN,
    FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    PPOExperimentBuilder,
@@ -98,6 +98,7 @@ def main(
        )
        .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
        .with_critic_factory_use_actor()
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(

View File

@@ -11,6 +11,7 @@ import numpy as np
from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext

try:
    import envpool
@@ -374,11 +375,19 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
class AtariEnvFactory(EnvFactory):
    def __init__(
        self,
        task: str,
        seed: int,
        sampling_config: RLSamplingConfig,
        frame_stack: int,
        scale: int = 0,
    ):
        self.task = task
        self.sampling_config = sampling_config
        self.seed = seed
        self.frame_stack = frame_stack
        self.scale = scale

    def create_envs(self, config=None) -> DiscreteEnvironments:
        env, train_envs, test_envs = make_atari_env(
@@ -386,7 +395,20 @@ class AtariEnvFactory(EnvFactory):
            seed=self.seed,
            training_num=self.sampling_config.num_train_envs,
            test_num=self.sampling_config.num_test_envs,
            scale=self.scale,
            frame_stack=self.frame_stack,
        )
        return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)


class AtariStopCallback(TrainerStopCallback):
    def __init__(self, task: str):
        self.task = task

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        env = context.envs.env
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        if "Pong" in self.task:
            return mean_rewards >= 20
        return False

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Generic, TypeVar

import gymnasium
import torch

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -23,6 +24,7 @@ from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
    A2CParams,
    DDPGParams,
    DQNParams,
    Params,
    ParamTransformerData,
    PPOParams,
@@ -30,10 +32,12 @@ from tianshou.highlevel.params.policy_params import (
    TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
from tianshou.policy import (
    A2CPolicy,
    BasePolicy,
    DDPGPolicy,
    DQNPolicy,
    PPOPolicy,
    SACPolicy,
    TD3Policy,
@@ -54,6 +58,7 @@ class AgentFactory(ABC, ToStringMixin):
        self.sampling_config = sampling_config
        self.optim_factory = optim_factory
        self.policy_wrapper_factory: PolicyWrapperFactory | None = None
        self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        buffer_size = self.sampling_config.buffer_size
@@ -85,6 +90,9 @@
    ) -> None:
        self.policy_wrapper_factory = policy_wrapper_factory

    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
        self.trainer_callbacks = callbacks

    @abstractmethod
    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass
@@ -145,6 +153,21 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
        logger: Logger,
    ) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(policy, envs, logger)
        train_fn = (
            callbacks.epoch_callback_train.get_trainer_fn(context)
            if callbacks.epoch_callback_train
            else None
        )
        test_fn = (
            callbacks.epoch_callback_test.get_trainer_fn(context)
            if callbacks.epoch_callback_test
            else None
        )
        stop_fn = (
            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
        )
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
@@ -158,6 +181,9 @@
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
        )
@@ -171,6 +197,21 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
        logger: Logger,
    ) -> OffpolicyTrainer:
        sampling_config = self.sampling_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(policy, envs, logger)
        train_fn = (
            callbacks.epoch_callback_train.get_trainer_fn(context)
            if callbacks.epoch_callback_train
            else None
        )
        test_fn = (
            callbacks.epoch_callback_test.get_trainer_fn(context)
            if callbacks.epoch_callback_test
            else None
        )
        stop_fn = (
            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
        )
        return OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
@@ -184,6 +225,9 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
            logger=logger.logger,
            update_per_step=sampling_config.update_per_step,
            test_in_train=False,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
        )
@@ -195,6 +239,23 @@ class _ActorMixin:
        return self.actor_module_opt_factory.create_module_opt(envs, device, lr)


class _CriticMixin:
    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        self.critic_module_opt_factory = CriticModuleOptFactory(
            critic_factory,
            optim_factory,
            critic_use_action,
        )

    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        return self.critic_module_opt_factory.create_module_opt(envs, device, lr)


class _ActorCriticMixin:
    """Mixin for agents that use an ActorCritic module with a single optimizer."""
@@ -241,7 +302,7 @@ class _ActorCriticMixin:
        return ActorCriticModuleOpt(actor_critic, optim)


class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
    def __init__(
        self,
        actor_factory: ActorFactory,
@@ -249,15 +310,8 @@ class _ActorAndCriticMixin(_ActorMixin):
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        _ActorMixin.__init__(self, actor_factory, optim_factory)
        _CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)


class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
@@ -385,6 +439,42 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
        return self.create_actor_critic_module_opt(envs, device, self.params.lr)


class DQNAgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,
        params: DQNParams,
        sampling_config: RLSamplingConfig,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config, optim_factory)
        self.params = params
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        critic = self.critic_factory.create_module(envs, device, use_action=True)
        optim = self.optim_factory.create_optimizer(critic, self.params.lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim=optim,
                optim_factory=self.optim_factory,
            ),
        )
        envs.get_type().assert_discrete(self)
        # noinspection PyTypeChecker
        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
        return DQNPolicy(
            model=critic,
            optim=optim,
            action_space=action_space,
            observation_space=envs.get_observation_space(),
            **kwargs,
        )


class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
    def __init__(
        self,

View File

@@ -14,10 +14,14 @@ class RLSamplingConfig:
    buffer_size: int = 4096
    step_per_collect: int = 2048
    repeat_per_collect: int | None = 10
    update_per_step: float = 1.0
    """
    Only used in off-policy algorithms.
    How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
    """
    start_timesteps: int = 0
    start_timesteps_random: bool = False
    # TODO can we set the parameters below intelligently? Perhaps based on env. representation?
    replay_buffer_ignore_obs_next: bool = False
    replay_buffer_save_only_last_obs: bool = False
    replay_buffer_stack_num: int = 1

View File

@@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
    A2CAgentFactory,
    AgentFactory,
    DDPGAgentFactory,
    DQNAgentFactory,
    PPOAgentFactory,
    SACAgentFactory,
    TD3AgentFactory,
@@ -30,12 +31,18 @@ from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
    A2CParams,
    DDPGParams,
    DQNParams,
    PPOParams,
    SACParams,
    TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.highlevel.trainer import (
    TrainerCallbacks,
    TrainerEpochCallback,
    TrainerStopCallback,
)
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
from tianshou.utils.string import ToStringMixin
@@ -160,6 +167,7 @@ class RLExperimentBuilder:
        self._optim_factory: OptimizerFactory | None = None
        self._env_config: PersistableConfigProtocol | None = None
        self._policy_wrapper_factory: PolicyWrapperFactory | None = None
        self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

    def with_env_config(self, config: PersistableConfigProtocol) -> Self:
        self._env_config = config
@@ -193,6 +201,18 @@ class RLExperimentBuilder:
        self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
        return self

    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
        self._trainer_callbacks.epoch_callback_train = callback
        return self

    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
        self._trainer_callbacks.epoch_callback_test = callback
        return self

    def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
        self._trainer_callbacks.stop_callback = callback
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        pass
@@ -205,6 +225,7 @@ class RLExperimentBuilder:
    def build(self) -> RLExperiment:
        agent_factory = self._create_agent_factory()
        agent_factory.set_trainer_callbacks(self._trainer_callbacks)
        if self._policy_wrapper_factory:
            agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
        experiment = RLExperiment(
@@ -302,21 +323,24 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
    def __init__(self):
        super().__init__(1)

    def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
        self._with_critic_factory(0, critic_factory)
        return self

    def with_critic_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
    ) -> Self:
        self._with_critic_factory_default(0, hidden_sizes)
        return self


class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
    def __init__(self):
        super().__init__()
        self._critic_use_actor_module = False

    def with_critic_factory_use_actor(self) -> Self:
        """Makes the critic use the same network as the actor."""
        self._critic_use_actor_module = True
@@ -372,7 +396,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
class A2CExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    def __init__(
        self,
@@ -383,7 +407,7 @@ class A2CExperimentBuilder(
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        self._params: A2CParams = A2CParams()
        self._env_config = env_config
@@ -406,7 +430,7 @@ class A2CExperimentBuilder(
class PPOExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    def __init__(
        self,
@@ -416,7 +440,7 @@ class PPOExperimentBuilder(
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        self._params: PPOParams = PPOParams()

    def with_ppo_params(self, params: PPOParams) -> Self:
@@ -435,9 +459,8 @@ class PPOExperimentBuilder(
        )


class DQNExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinSingleCriticFactory,
):
    def __init__(
@@ -445,13 +468,40 @@
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinSingleCriticFactory.__init__(self)
        self._params: DQNParams = DQNParams()

    def with_dqn_params(self, params: DQNParams) -> Self:
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        return DQNAgentFactory(
            self._params,
            self._sampling_config,
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )


class DDPGExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousDeterministic,
    _BuilderMixinSingleCriticCanUseActorFactory,
):
    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        self._params: DDPGParams = DDPGParams()

    def with_ddpg_params(self, params: DDPGParams) -> Self:
        self._params = params

View File

@@ -36,7 +36,9 @@ class ModuleFactory(ToStringMixin, ABC):
        pass


class ModuleFactoryNet(
    ModuleFactory,
):  # TODO This is unused and broken; use it in ActorFactory* and so on?
    def __init__(self, hidden_sizes: int | Sequence[int]):
        self.hidden_sizes = hidden_sizes

View File

@@ -356,6 +356,21 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
        return transformers


@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
    discount_factor: float = 0.99
    estimation_step: int = 1
    target_update_freq: int = 0
    reward_normalization: bool = False
    is_double: bool = True
    clip_loss_grad: bool = False

    def _get_param_transformers(self):
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
        return transformers


@dataclass
class DDPGParams(Params, ParamsMixinActorAndCritic):
    tau: float = 0.005

View File

@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.policy import BasePolicy
from tianshou.utils.string import ToStringMixin

TPolicy = TypeVar("TPolicy", bound=BasePolicy)


class TrainingContext:
    def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
        self.policy = policy
        self.envs = envs
        self.logger = logger


class TrainerEpochCallback(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        pass

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        def fn(epoch, env_step):
            return self.callback(epoch, env_step, context)

        return fn


class TrainerStopCallback(ToStringMixin, ABC):
    """Callback indicating whether training should stop."""

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """:param mean_rewards: the average undiscounted returns of the testing result
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        def fn(mean_rewards: float):
            return self.should_stop(mean_rewards, context)

        return fn


@dataclass
class TrainerCallbacks:
    epoch_callback_train: TrainerEpochCallback | None = None
    epoch_callback_test: TrainerEpochCallback | None = None
    stop_callback: TrainerStopCallback | None = None
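
For reference, the agent factories convert these callbacks into the plain train_fn, test_fn and
stop_fn functions that Tianshou's trainers accept by binding a TrainingContext via get_trainer_fn
(see the OnpolicyAgentFactory/OffpolicyAgentFactory hunks above). A minimal sketch of that
translation, assuming policy, envs and logger already exist and reusing the illustrative
SetTrainEps and StopOnReward callbacks from the sketch at the top of this page:

    context = TrainingContext(policy, envs, logger)
    callbacks = TrainerCallbacks(
        epoch_callback_train=SetTrainEps(),  # hypothetical TrainerEpochCallback
        stop_callback=StopOnReward(),  # hypothetical TrainerStopCallback
    )
    # get_trainer_fn closes over the context and yields the signatures the trainer expects:
    # train_fn(epoch, env_step) and stop_fn(mean_rewards)
    train_fn = callbacks.epoch_callback_train.get_trainer_fn(context)
    stop_fn = callbacks.stop_callback.get_trainer_fn(context)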