diff --git a/examples/atari/atari_callbacks.py b/examples/atari/atari_callbacks.py
new file mode 100644
index 0000000..41b4f83
--- /dev/null
+++ b/examples/atari/atari_callbacks.py
@@ -0,0 +1,33 @@
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
+    TrainingContext,
+)
+from tianshou.policy import DQNPolicy
+
+
+class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
+    def __init__(self, eps_test: float):
+        self.eps_test = eps_test
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy: DQNPolicy = context.policy
+        policy.set_eps(self.eps_test)
+
+
+class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
+    def __init__(self, eps_train: float, eps_train_final: float):
+        self.eps_train = eps_train
+        self.eps_train_final = eps_train_final
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy: DQNPolicy = context.policy
+        logger = context.logger.logger
+        # nature DQN setting, linear decay in the first 1M steps
+        if env_step <= 1e6:
+            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
+        else:
+            eps = self.eps_train_final
+        policy.set_eps(eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
index 63d5a51..b2722cd 100644
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -1,13 +1,12 @@
 #!/usr/bin/env python3
 
-import datetime
 import os
 
 from jsonargparse import CLI
 
 from examples.atari.atari_network import (
     ActorFactoryAtariPlainDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,6 +25,7 @@ from tianshou.highlevel.trainer import (
 )
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
 
 
 def main(
@@ -53,8 +53,7 @@ def main(
     icm_reward_scale: float = 0.01,
     icm_forward_loss_weight: float = 0.2,
 ):
-    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())
 
     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -115,7 +114,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [512],
                 lr,
                 icm_lr_scale,
diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py
new file mode 100644
index 0000000..4d1d6c1
--- /dev/null
+++ b/examples/atari/atari_iqn_hl.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+
+from jsonargparse import CLI
+
+from examples.atari.atari_callbacks import (
+    TestEpochCallbackDQNSetEps,
+    TrainEpochCallbackNatureDQNEpsLinearDecay,
+)
+from examples.atari.atari_network import (
+    IntermediateModuleFactoryAtariDQN,
+)
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    IQNExperimentBuilder,
+)
+from tianshou.highlevel.params.policy_params import IQNParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "PongNoFrameskip-v4",
+    scale_obs: int = 0,
+    eps_test: float = 0.005,
+    eps_train: float = 1.0,
+    eps_train_final: float = 0.05,
+    buffer_size: int = 100000,
+    lr: float = 0.0001,
+    gamma: float = 0.99,
+    sample_size: int = 32,
+    online_sample_size: int = 8,
+    target_sample_size: int = 8,
+    num_cosines: int = 64,
+    hidden_sizes: Sequence[int] = (512,),
+    n_step: int = 3,
+    target_update_freq: int = 500,
+    epoch: int = 100,
+    step_per_epoch: int = 100000,
+    step_per_collect: int = 10,
+    update_per_step: float = 0.1,
+    batch_size: int = 32,
+    training_num: int = 10,
+    test_num: int = 10,
+    frames_stack: int = 4,
+    save_buffer_name: str | None = None,  # TODO support?
+):
+    log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        replay_buffer_stack_num=frames_stack,
+        replay_buffer_ignore_obs_next=True,
+        replay_buffer_save_only_last_obs=True,
+    )
+
+    env_factory = AtariEnvFactory(
+        task,
+        experiment_config.seed,
+        sampling_config,
+        frames_stack,
+        scale=scale_obs,
+    )
+
+    experiment = (
+        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_iqn_params(
+            IQNParams(
+                discount_factor=gamma,
+                estimation_step=n_step,
+                lr=lr,
+                sample_size=sample_size,
+                online_sample_size=online_sample_size,
+                target_update_freq=target_update_freq,
+                target_sample_size=target_sample_size,
+                hidden_sizes=hidden_sizes,
+                num_cosines=num_cosines,
+            ),
+        )
+        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
+        .with_trainer_epoch_callback_train(
+            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+        )
+        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
+        .with_trainer_stop_callback(AtariStopCallback(task))
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 3757d70..0767eb7 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -7,7 +7,11 @@ from torch import nn
 
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
-from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    TDevice,
+)
 from tianshou.utils.net.discrete import Actor, NoisyLinear
 
 
@@ -253,12 +257,15 @@ class ActorFactoryAtariDQN(ActorFactory):
         return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
 
 
-class FeatureNetFactoryDQN(ModuleFactory):
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
+    def __init__(self, net_only: bool):
+        self.net_only = net_only
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         dqn = DQN(
             *envs.get_observation_shape(),
             envs.get_action_shape(),
-            device,
+            device=device,
             features_only=True,
         )
-        return Module(dqn.net, dqn.output_dim)
+        return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py
index c6cc321..9a5777b 100644
--- a/examples/atari/atari_ppo_hl.py
+++ b/examples/atari/atari_ppo_hl.py
@@ -8,7 +8,7 @@ from jsonargparse import CLI
 
 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -103,7 +103,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_sizes],
                 lr,
                 icm_lr_scale,
diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py
index 735cd1d..9909687 100644
--- a/examples/atari/atari_sac_hl.py
+++ b/examples/atari/atari_sac_hl.py
@@ -6,7 +6,7 @@ from jsonargparse import CLI
 
 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,7 +26,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: bool = False,
+    scale_obs: int = 0,
     buffer_size: int = 100000,
     actor_lr: float = 1e-5,
     critic_lr: float = 1e-5,
@@ -67,7 +67,9 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
+    env_factory = AtariEnvFactory(
+        task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
+    )
 
     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -82,14 +84,14 @@ def main(
                 estimation_step=n_step,
             ),
         )
-        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
+        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
         .with_common_critic_factory_use_actor()
         .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_size],
                 actor_lr,
                 icm_lr_scale,
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index a9c9d21..66f6492 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -13,7 +13,10 @@ from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
-from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.core import (
+    ModuleFactory,
+    TDevice,
+)
 from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
@@ -24,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     Params,
     ParamsMixinActorAndDualCritics,
@@ -44,6 +48,7 @@ from tianshou.policy import (
     DDPGPolicy,
     DiscreteSACPolicy,
     DQNPolicy,
+    IQNPolicy,
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
@@ -61,6 +66,9 @@ CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
 TParams = TypeVar("TParams", bound=Params)
 TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
 TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
+TDiscreteCriticOnlyParams = TypeVar(
+    "TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
+)
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 
 
@@ -394,21 +402,27 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
         return TRPOPolicy
 
 
-class DQNAgentFactory(OffpolicyAgentFactory):
+class DiscreteCriticOnlyAgentFactory(
+    OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
+):
     def __init__(
         self,
-        params: DQNParams,
+        params: TDiscreteCriticOnlyParams,
         sampling_config: SamplingConfig,
-        actor_factory: ActorFactory,
+        model_factory: ModuleFactory,
         optim_factory: OptimizerFactory,
     ):
         super().__init__(sampling_config, optim_factory)
         self.params = params
-        self.actor_factory = actor_factory
+        self.model_factory = model_factory
         self.optim_factory = optim_factory
 
-    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
-        model = self.actor_factory.create_module(envs, device)
+    @abstractmethod
+    def _get_policy_class(self) -> type[TPolicy]:
+        pass
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+        model = self.model_factory.create_module(envs, device)
         optim = self.optim_factory.create_optimizer(model, self.params.lr)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
@@ -420,7 +434,8 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )
         envs.get_type().assert_discrete(self)
         action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
-        return DQNPolicy(
+        policy_class = self._get_policy_class()
+        return policy_class(
             model=model,
             optim=optim,
             action_space=action_space,
@@ -429,6 +444,16 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )
 
 
+class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]):
+    def _get_policy_class(self) -> type[DQNPolicy]:
+        return DQNPolicy
+
+
+class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
+    def _get_policy_class(self) -> type[IQNPolicy]:
+        return IQNPolicy
+
+
 class DDPGAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 520cff5..a8a8889 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -15,6 +15,7 @@ from tianshou.highlevel.agent import (
     DDPGAgentFactory,
     DiscreteSACAgentFactory,
     DQNAgentFactory,
+    IQNAgentFactory,
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
@@ -33,6 +34,11 @@ from tianshou.highlevel.module.actor import (
     ActorFuture,
     ActorFutureProviderProtocol,
     ContinuousActorType,
+    IntermediateModuleFactoryFromActorFactory,
+)
+from tianshou.highlevel.module.core import (
+    ImplicitQuantileNetworkFactory,
+    IntermediateModuleFactory,
 )
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -47,6 +53,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     PGParams,
     PPOParams,
@@ -641,6 +648,41 @@ class DQNExperimentBuilder(
         )
 
 
+class IQNExperimentBuilder(ExperimentBuilder):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        self._params: IQNParams = IQNParams()
+        self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
+            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
+        )
+
+    def with_iqn_params(self, params: IQNParams) -> Self:
+        self._params = params
+        return self
+
+    def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self:
+        self._preprocess_network_factory = module_factory
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        model_factory = ImplicitQuantileNetworkFactory(
+            self._preprocess_network_factory,
+            hidden_sizes=self._params.hidden_sizes,
+            num_cosines=self._params.num_cosines,
+        )
+        return IQNAgentFactory(
+            self._params,
+            self._sampling_config,
+            model_factory,
+            self._get_optim_factory(),
+        )
+
+
 class DDPGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py
index 1b161da..0d3d86f 100644
--- a/tianshou/highlevel/module/actor.py
+++ b/tianshou/highlevel/module/actor.py
@@ -8,7 +8,13 @@ import torch
 from torch import nn
 
 from tianshou.highlevel.env import Environments, EnvType
-from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    ModuleFactory,
+    TDevice,
+    init_linear_orthogonal,
+)
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
@@ -34,7 +40,7 @@ class ActorFutureProviderProtocol(Protocol):
         pass
 
 
-class ActorFactory(ToStringMixin, ABC):
+class ActorFactory(ModuleFactory, ToStringMixin, ABC):
     @abstractmethod
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
         pass
@@ -212,3 +218,13 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
         module = self.actor_factory.create_module(envs, device)
         self._actor_future.actor = module
         return module
+
+
+class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
+    def __init__(self, actor_factory: ActorFactory):
+        self.actor_factory = actor_factory
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
+        actor = self.actor_factory.create_module(envs, device)
+        assert isinstance(actor, BaseActor)
+        return IntermediateModule(actor, actor.get_output_dim())
diff --git a/tianshou/highlevel/module/core.py b/tianshou/highlevel/module/core.py
index 80ca012..08fc88e 100644
--- a/tianshou/highlevel/module/core.py
+++ b/tianshou/highlevel/module/core.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 
 from tianshou.highlevel.env import Environments
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.discrete import ImplicitQuantileNetwork
 from tianshou.utils.string import ToStringMixin
 
 TDevice: TypeAlias = str | torch.device
@@ -24,24 +24,45 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
             torch.nn.init.zeros_(m.bias)
 
 
+class ModuleFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        pass
+
+
 @dataclass
-class Module:
+class IntermediateModule:
     module: torch.nn.Module
     output_dim: int
 
 
-class ModuleFactory(ToStringMixin, ABC):
+class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
     @abstractmethod
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         pass
 
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        return self.create_intermediate_module(envs, device).module
 
-class ModuleFactoryNet(
-    ModuleFactory,
-):  # TODO This is unused and broken; use it in ActorFactory* and so on?
-    def __init__(self, hidden_sizes: int | Sequence[int]):
+
+class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
+    def __init__(
+        self,
+        preprocess_net_factory: IntermediateModuleFactory,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+    ):
+        self.preprocess_net_factory = preprocess_net_factory
         self.hidden_sizes = hidden_sizes
+        self.num_cosines = num_cosines
 
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
-        module = Net(envs.get_observation_shape())
-        return Module(module, module.output_dim)
+    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
+        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
+        return ImplicitQuantileNetwork(
+            preprocess_net=preprocess_net.module,
+            action_shape=envs.get_action_shape(),
+            hidden_sizes=self.hidden_sizes,
+            num_cosines=self.num_cosines,
+            preprocess_net_output_dim=preprocess_net.output_dim,
+            device=device,
+        ).to(device)
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 44cc20f..33c9783 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol
 
@@ -388,6 +389,23 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
         return transformers
 
 
+@dataclass
+class IQNParams(DQNParams):
+    sample_size: int = 32
+    online_sample_size: int = 8
+    target_sample_size: int = 8
+    num_quantiles: int = 200
+    hidden_sizes: Sequence[int] = ()
+    """hidden dimensions to use in the IQN network"""
+    num_cosines: int = 64
+    """number of cosines to use in the IQN network"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines"))
+        return transformers
+
+
 @dataclass
 class DDPGParams(Params, ParamsMixinActorAndCritic):
     tau: float = 0.005
diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py
index 8bf43cc..843a0f7 100644
--- a/tianshou/highlevel/params/policy_wrapper.py
+++ b/tianshou/highlevel/params/policy_wrapper.py
@@ -3,7 +3,7 @@ from collections.abc import Sequence
 from typing import Generic, TypeVar
 
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.policy import BasePolicy, ICMPolicy
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule
@@ -31,7 +31,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
 ):
     def __init__(
         self,
-        feature_net_factory: ModuleFactory,
+        feature_net_factory: IntermediateModuleFactory,
         hidden_sizes: Sequence[int],
         lr: float,
         lr_scale: float,
@@ -52,7 +52,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> ICMPolicy:
-        feature_net = self.feature_net_factory.create_module(envs, device)
+        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
         action_dim = envs.get_action_shape()
         if not isinstance(action_dim, int):
             raise ValueError(f"Environment action shape must be an integer, got {action_dim}")