Add DQN support in high-level API
* Allow specification of trainer callbacks (train_fn, test_fn, stop_fn) in the high-level API, adding the necessary abstractions and pass-on mechanisms
* Add example atari_dqn_hl
parent 358978c65d
commit 1cba589bd4

examples/atari/atari_dqn_hl.py (new file, 128 lines)
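In short, the commit lets an experiment be assembled declaratively and lets epoch/stop callbacks be attached to the builder before build(). The following is a condensed sketch of the usage pattern using only names introduced in the diff below; it abbreviates the full example in examples/atari/atari_dqn_hl.py, and the argument values shown are that example's, not library defaults.

# Condensed sketch of the new builder surface (see examples/atari/atari_dqn_hl.py
# below for the complete, runnable version with all imports and callbacks defined).
experiment = (
    DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_dqn_params(
        DQNParams(discount_factor=0.99, estimation_step=3, lr=0.0001, target_update_freq=500),
    )
    .with_critic_factory(CriticFactoryAtariDQN())
    .with_trainer_epoch_callback_train(TrainEpochCallback())  # epsilon schedule for training
    .with_trainer_epoch_callback_test(TestEpochCallback())    # fixed epsilon for evaluation
    .with_trainer_stop_callback(AtariStopCallback(task))      # stop at the env's reward threshold
    .build()
)
experiment.run(log_name)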
@@ -0,0 +1,128 @@
#!/usr/bin/env python3

import datetime
import os

from jsonargparse import CLI

from examples.atari.atari_network import (
    CriticFactoryAtariDQN,
    FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    DQNExperimentBuilder,
    RLExperimentConfig,
)
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
    PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
from tianshou.policy import DQNPolicy
from tianshou.utils import logging


def main(
    experiment_config: RLExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "dqn", str(experiment_config.seed), now)

    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(
        task,
        experiment_config.seed,
        sampling_config,
        frames_stack,
        scale=scale_obs,
    )

    class TrainEpochCallback(TrainerEpochCallback):
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
            logger = context.logger.logger
            # nature DQN setting, linear decay in the first 1M steps
            if env_step <= 1e6:
                eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
            else:
                eps = eps_train_final
            policy.set_eps(eps)
            if env_step % 1000 == 0:
                logger.write("train/env_step", env_step, {"train/eps": eps})

    class TestEpochCallback(TrainerEpochCallback):
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
            policy.set_eps(eps_test)

    builder = (
        DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_dqn_params(
            DQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                target_update_freq=target_update_freq,
            ),
        )
        .with_critic_factory(CriticFactoryAtariDQN())
        .with_trainer_epoch_callback_train(TrainEpochCallback())
        .with_trainer_epoch_callback_test(TestEpochCallback())
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
                FeatureNetFactoryDQN(),
                [512],
                lr,
                icm_lr_scale,
                icm_reward_scale,
                icm_forward_loss_weight,
            ),
        )

    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))
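For orientation, TrainEpochCallback above reproduces the Nature-DQN exploration schedule: epsilon decays linearly from eps_train to eps_train_final over the first million environment steps and stays constant afterwards. A small self-contained check with the example's defaults (a sketch, independent of tianshou):

# Epsilon schedule used by TrainEpochCallback (example defaults: 1.0 -> 0.05 over 1e6 steps).
def eps_at(env_step: int, eps_train: float = 1.0, eps_train_final: float = 0.05) -> float:
    if env_step <= 1e6:
        return eps_train - env_step / 1e6 * (eps_train - eps_train_final)
    return eps_train_final


assert eps_at(0) == 1.0
assert abs(eps_at(500_000) - 0.525) < 1e-9  # halfway through the decay
assert eps_at(2_000_000) == 0.05            # constant after 1M steps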
examples/atari/atari_network.py

@@ -8,6 +8,7 @@ from torch import nn
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
 from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.highlevel.module.critic import CriticFactory
 from tianshou.utils.net.common import BaseActor
 from tianshou.utils.net.discrete import Actor, NoisyLinear
 
@@ -226,6 +227,21 @@ class QRDQN(DQN):
         return obs, state
 
 
+class CriticFactoryAtariDQN(CriticFactory):
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        use_action: bool,
+    ) -> torch.nn.Module:
+        assert use_action
+        return DQN(
+            *envs.get_observation_shape(),
+            envs.get_action_shape(),
+            device=device,
+        ).to(device)
+
+
 class ActorFactoryAtariDQN(ActorFactory):
     def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
         self.hidden_size = hidden_size
examples/atari/atari_ppo_hl.py

@@ -10,7 +10,7 @@ from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
     FeatureNetFactoryDQN,
 )
-from examples.atari.atari_wrapper import AtariEnvFactory
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
@@ -98,6 +98,7 @@ def main(
         )
         .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
         .with_critic_factory_use_actor()
+        .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
examples/atari/atari_wrapper.py

@@ -11,6 +11,7 @@ import numpy as np
 from tianshou.env import ShmemVectorEnv
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
+from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
 
 try:
     import envpool
@@ -374,11 +375,19 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
 
 
 class AtariEnvFactory(EnvFactory):
-    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
+    def __init__(
+        self,
+        task: str,
+        seed: int,
+        sampling_config: RLSamplingConfig,
+        frame_stack: int,
+        scale: int = 0,
+    ):
         self.task = task
         self.sampling_config = sampling_config
         self.seed = seed
         self.frame_stack = frame_stack
+        self.scale = scale
 
     def create_envs(self, config=None) -> DiscreteEnvironments:
         env, train_envs, test_envs = make_atari_env(
@@ -386,7 +395,20 @@ class AtariEnvFactory(EnvFactory):
             seed=self.seed,
             training_num=self.sampling_config.num_train_envs,
             test_num=self.sampling_config.num_test_envs,
-            scale=0,
+            scale=self.scale,
             frame_stack=self.frame_stack,
         )
         return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
+
+
+class AtariStopCallback(TrainerStopCallback):
+    def __init__(self, task: str):
+        self.task = task
+
+    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
+        env = context.envs.env
+        if env.spec.reward_threshold:
+            return mean_rewards >= env.spec.reward_threshold
+        if "Pong" in self.task:
+            return mean_rewards >= 20
+        return False
tianshou/highlevel/agent.py

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from typing import Generic, TypeVar
 
+import gymnasium
 import torch
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -23,6 +24,7 @@ from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
     DDPGParams,
+    DQNParams,
     Params,
     ParamTransformerData,
     PPOParams,
@@ -30,10 +32,12 @@ from tianshou.highlevel.params.policy_params import (
     TD3Params,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
+from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
 from tianshou.policy import (
     A2CPolicy,
     BasePolicy,
     DDPGPolicy,
+    DQNPolicy,
     PPOPolicy,
     SACPolicy,
     TD3Policy,
@@ -54,6 +58,7 @@ class AgentFactory(ABC, ToStringMixin):
         self.sampling_config = sampling_config
         self.optim_factory = optim_factory
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
+        self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
 
     def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
         buffer_size = self.sampling_config.buffer_size
@@ -85,6 +90,9 @@ class AgentFactory(ABC, ToStringMixin):
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory
 
+    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
+        self.trainer_callbacks = callbacks
+
     @abstractmethod
     def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         pass
@@ -145,6 +153,21 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
         logger: Logger,
     ) -> OnpolicyTrainer:
         sampling_config = self.sampling_config
+        callbacks = self.trainer_callbacks
+        context = TrainingContext(policy, envs, logger)
+        train_fn = (
+            callbacks.epoch_callback_train.get_trainer_fn(context)
+            if callbacks.epoch_callback_train
+            else None
+        )
+        test_fn = (
+            callbacks.epoch_callback_test.get_trainer_fn(context)
+            if callbacks.epoch_callback_test
+            else None
+        )
+        stop_fn = (
+            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
+        )
         return OnpolicyTrainer(
             policy=policy,
             train_collector=train_collector,
@@ -158,6 +181,9 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             save_best_fn=self._create_save_best_fn(envs, logger.log_path),
             logger=logger.logger,
             test_in_train=False,
+            train_fn=train_fn,
+            test_fn=test_fn,
+            stop_fn=stop_fn,
         )
 
 
@@ -171,6 +197,21 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         logger: Logger,
     ) -> OffpolicyTrainer:
         sampling_config = self.sampling_config
+        callbacks = self.trainer_callbacks
+        context = TrainingContext(policy, envs, logger)
+        train_fn = (
+            callbacks.epoch_callback_train.get_trainer_fn(context)
+            if callbacks.epoch_callback_train
+            else None
+        )
+        test_fn = (
+            callbacks.epoch_callback_test.get_trainer_fn(context)
+            if callbacks.epoch_callback_test
+            else None
+        )
+        stop_fn = (
+            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
+        )
         return OffpolicyTrainer(
             policy=policy,
             train_collector=train_collector,
@@ -184,6 +225,9 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
             logger=logger.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
+            train_fn=train_fn,
+            test_fn=test_fn,
+            stop_fn=stop_fn,
         )
 
 
@@ -195,6 +239,23 @@ class _ActorMixin:
         return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
 
 
+class _CriticMixin:
+    def __init__(
+        self,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+        critic_use_action: bool,
+    ):
+        self.critic_module_opt_factory = CriticModuleOptFactory(
+            critic_factory,
+            optim_factory,
+            critic_use_action,
+        )
+
+    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
+        return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
+
+
 class _ActorCriticMixin:
     """Mixin for agents that use an ActorCritic module with a single optimizer."""
 
@@ -241,7 +302,7 @@ class _ActorCriticMixin:
         return ActorCriticModuleOpt(actor_critic, optim)
 
 
-class _ActorAndCriticMixin(_ActorMixin):
+class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
     def __init__(
         self,
         actor_factory: ActorFactory,
@@ -249,15 +310,8 @@ class _ActorAndCriticMixin(_ActorMixin):
         optim_factory: OptimizerFactory,
         critic_use_action: bool,
     ):
-        super().__init__(actor_factory, optim_factory)
-        self.critic_module_opt_factory = CriticModuleOptFactory(
-            critic_factory,
-            optim_factory,
-            critic_use_action,
-        )
-
-    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-        return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
+        _ActorMixin.__init__(self, actor_factory, optim_factory)
+        _CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)
 
 
 class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
@@ -385,6 +439,42 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
         return self.create_actor_critic_module_opt(envs, device, self.params.lr)
 
 
+class DQNAgentFactory(OffpolicyAgentFactory):
+    def __init__(
+        self,
+        params: DQNParams,
+        sampling_config: RLSamplingConfig,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        self.params = params
+        self.critic_factory = critic_factory
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        critic = self.critic_factory.create_module(envs, device, use_action=True)
+        optim = self.optim_factory.create_optimizer(critic, self.params.lr)
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim=optim,
+                optim_factory=self.optim_factory,
+            ),
+        )
+        envs.get_type().assert_discrete(self)
+        # noinspection PyTypeChecker
+        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
+        return DQNPolicy(
+            model=critic,
+            optim=optim,
+            action_space=action_space,
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
     def __init__(
         self,
tianshou/highlevel/config.py

@@ -14,10 +14,14 @@ class RLSamplingConfig:
     buffer_size: int = 4096
     step_per_collect: int = 2048
     repeat_per_collect: int | None = 10
-    update_per_step: int = 1
+    update_per_step: float = 1.0
+    """
+    Only used in off-policy algorithms.
+    How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
+    """
     start_timesteps: int = 0
     start_timesteps_random: bool = False
-    # TODO can we set the parameters below more intelligently? Perhaps based on env. representation?
+    # TODO can we set the parameters below intelligently? Perhaps based on env. representation?
     replay_buffer_ignore_obs_next: bool = False
     replay_buffer_save_only_last_obs: bool = False
     replay_buffer_stack_num: int = 1
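To make the new docstring concrete: for off-policy trainers the number of gradient updates is proportional to the number of collected environment steps. With the settings used by the DQN example (step_per_collect=10, update_per_step=0.1, step_per_epoch=100000), each collect triggers a single gradient update and an epoch performs 10,000 updates. A small sanity check, independent of tianshou:

# update_per_step semantics (off-policy): gradient steps = env steps * update_per_step.
step_per_collect = 10
update_per_step = 0.1
step_per_epoch = 100_000

assert round(step_per_collect * update_per_step, 6) == 1       # one update per collect
assert round(step_per_epoch * update_per_step, 6) == 10_000    # updates per epoch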
tianshou/highlevel/experiment.py

@@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
     A2CAgentFactory,
     AgentFactory,
     DDPGAgentFactory,
+    DQNAgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
@@ -30,12 +31,18 @@ from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
     DDPGParams,
+    DQNParams,
     PPOParams,
     SACParams,
     TD3Params,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.highlevel.trainer import (
+    TrainerCallbacks,
+    TrainerEpochCallback,
+    TrainerStopCallback,
+)
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
 from tianshou.utils.string import ToStringMixin
@@ -160,6 +167,7 @@ class RLExperimentBuilder:
         self._optim_factory: OptimizerFactory | None = None
         self._env_config: PersistableConfigProtocol | None = None
         self._policy_wrapper_factory: PolicyWrapperFactory | None = None
+        self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
 
     def with_env_config(self, config: PersistableConfigProtocol) -> Self:
         self._env_config = config
@@ -193,6 +201,18 @@ class RLExperimentBuilder:
         self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
         return self
 
+    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
+        self._trainer_callbacks.epoch_callback_train = callback
+        return self
+
+    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
+        self._trainer_callbacks.epoch_callback_test = callback
+        return self
+
+    def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
+        self._trainer_callbacks.stop_callback = callback
+        return self
+
     @abstractmethod
     def _create_agent_factory(self) -> AgentFactory:
         pass
@@ -205,6 +225,7 @@ class RLExperimentBuilder:
 
     def build(self) -> RLExperiment:
         agent_factory = self._create_agent_factory()
+        agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
             agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
         experiment = RLExperiment(
@@ -302,21 +323,24 @@ class _BuilderMixinCriticsFactory:
 class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
     def __init__(self):
         super().__init__(1)
-        self._critic_use_actor_module = False
 
-    def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinSingleCriticFactory"
+    def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(0, critic_factory)
         return self
 
     def with_critic_factory_default(
-        self: TBuilder,
+        self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinSingleCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self
 
+
+class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
+    def __init__(self):
+        super().__init__()
+        self._critic_use_actor_module = False
+
     def with_critic_factory_use_actor(self) -> Self:
         """Makes the critic use the same network as the actor."""
         self._critic_use_actor_module = True
@@ -372,7 +396,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
 class A2CExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
-    _BuilderMixinSingleCriticFactory,
+    _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
@@ -383,7 +407,7 @@ class A2CExperimentBuilder(
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
-        _BuilderMixinSingleCriticFactory.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
         self._params: A2CParams = A2CParams()
         self._env_config = env_config
 
@@ -406,7 +430,7 @@ class A2CExperimentBuilder(
 class PPOExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
-    _BuilderMixinSingleCriticFactory,
+    _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
@@ -416,7 +440,7 @@ class PPOExperimentBuilder(
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
-        _BuilderMixinSingleCriticFactory.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
         self._params: PPOParams = PPOParams()
 
     def with_ppo_params(self, params: PPOParams) -> Self:
@@ -435,9 +459,8 @@ class PPOExperimentBuilder(
         )
 
 
-class DDPGExperimentBuilder(
+class DQNExperimentBuilder(
     RLExperimentBuilder,
-    _BuilderMixinActorFactory_ContinuousDeterministic,
     _BuilderMixinSingleCriticFactory,
 ):
     def __init__(
@@ -445,13 +468,40 @@ class DDPGExperimentBuilder(
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
-        env_config: PersistableConfigProtocol | None = None,
     ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: DQNParams = DQNParams()
+
+    def with_dqn_params(self, params: DQNParams) -> Self:
+        self._params = params
+        return self
+
+    @abstractmethod
+    def _create_agent_factory(self) -> AgentFactory:
+        return DQNAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
+class DDPGExperimentBuilder(
+    RLExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousDeterministic,
+    _BuilderMixinSingleCriticCanUseActorFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
-        _BuilderMixinSingleCriticFactory.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
         self._params: DDPGParams = DDPGParams()
-        self._env_config = env_config
 
     def with_ddpg_params(self, params: DDPGParams) -> Self:
         self._params = params
tianshou/highlevel/module/core.py

@@ -36,7 +36,9 @@ class ModuleFactory(ToStringMixin, ABC):
         pass
 
 
-class ModuleFactoryNet(ModuleFactory):
+class ModuleFactoryNet(
+    ModuleFactory,
+):  # TODO This is unused and broken; use it in ActorFactory* and so on?
     def __init__(self, hidden_sizes: int | Sequence[int]):
         self.hidden_sizes = hidden_sizes
 
tianshou/highlevel/params/policy_params.py

@@ -356,6 +356,21 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
         return transformers
 
 
+@dataclass
+class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
+    discount_factor: float = 0.99
+    estimation_step: int = 1
+    target_update_freq: int = 0
+    reward_normalization: bool = False
+    is_double: bool = True
+    clip_loss_grad: bool = False
+
+    def _get_param_transformers(self):
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
+        return transformers
+
+
 @dataclass
 class DDPGParams(Params, ParamsMixinActorAndCritic):
     tau: float = 0.005
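As a quick illustration, the DQN example constructs this parameter object with its CLI defaults; the lr field comes from ParamsMixinLearningRateWithScheduler and is consumed by the optimizer factory rather than passed to DQNPolicy directly. A minimal sketch, assuming a tianshou checkout that contains this commit:

# Minimal sketch: the params object passed to DQNExperimentBuilder.with_dqn_params
# in examples/atari/atari_dqn_hl.py (values are that example's defaults).
from tianshou.highlevel.params.policy_params import DQNParams

params = DQNParams(discount_factor=0.99, estimation_step=3, lr=0.0001, target_update_freq=500)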
tianshou/highlevel/trainer.py (new file, 55 lines)

@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.policy import BasePolicy
from tianshou.utils.string import ToStringMixin

TPolicy = TypeVar("TPolicy", bound=BasePolicy)


class TrainingContext:
    def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
        self.policy = policy
        self.envs = envs
        self.logger = logger


class TrainerEpochCallback(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        pass

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        def fn(epoch, env_step):
            return self.callback(epoch, env_step, context)

        return fn


class TrainerStopCallback(ToStringMixin, ABC):
    """Callback indicating whether training should stop."""

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """:param mean_rewards: the average undiscounted returns of the testing result
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        def fn(mean_rewards: float):
            return self.should_stop(mean_rewards, context)

        return fn


@dataclass
class TrainerCallbacks:
    epoch_callback_train: TrainerEpochCallback | None = None
    epoch_callback_test: TrainerEpochCallback | None = None
    stop_callback: TrainerStopCallback | None = None