diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py new file mode 100644 index 0000000..735cd1d --- /dev/null +++ b/examples/atari/atari_sac_hl.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import os + +from jsonargparse import CLI + +from examples.atari.atari_network import ( + ActorFactoryAtariDQN, + FeatureNetFactoryDQN, +) +from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + DiscreteSACExperimentBuilder, + ExperimentConfig, +) +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault +from tianshou.highlevel.params.policy_params import DiscreteSACParams +from tianshou.highlevel.params.policy_wrapper import ( + PolicyWrapperFactoryIntrinsicCuriosity, +) +from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag + + +def main( + experiment_config: ExperimentConfig, + task: str = "PongNoFrameskip-v4", + scale_obs: bool = False, + buffer_size: int = 100000, + actor_lr: float = 1e-5, + critic_lr: float = 1e-5, + gamma: float = 0.99, + n_step: int = 3, + tau: float = 0.005, + alpha: float = 0.05, + auto_alpha: bool = False, + alpha_lr: float = 3e-4, + epoch: int = 100, + step_per_epoch: int = 100000, + step_per_collect: int = 10, + update_per_step: float = 0.1, + batch_size: int = 64, + hidden_size: int = 512, + training_num: int = 10, + test_num: int = 10, + frames_stack: int = 4, + save_buffer_name: str | None = None, # TODO add support in high-level API? 
+ icm_lr_scale: float = 0.0, + icm_reward_scale: float = 0.01, + icm_forward_loss_weight: float = 0.2, +): + log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) + + sampling_config = SamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + update_per_step=update_per_step, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + repeat_per_collect=None, + replay_buffer_stack_num=frames_stack, + replay_buffer_ignore_obs_next=True, + replay_buffer_save_only_last_obs=True, + ) + + env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack) + + builder = ( + DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) + .with_sac_params( + DiscreteSACParams( + actor_lr=actor_lr, + critic1_lr=critic_lr, + critic2_lr=critic_lr, + gamma=gamma, + tau=tau, + alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, + estimation_step=n_step, + ), + ) + .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True)) + .with_common_critic_factory_use_actor() + .with_trainer_stop_callback(AtariStopCallback(task)) + ) + if icm_lr_scale > 0: + builder.with_policy_wrapper_factory( + PolicyWrapperFactoryIntrinsicCuriosity( + FeatureNetFactoryDQN(), + [hidden_size], + actor_lr, + icm_lr_scale, + icm_reward_scale, + icm_forward_loss_weight, + ), + ) + experiment = builder.build() + experiment.run(log_name) + + +if __name__ == "__main__": + logging.run_main(lambda: CLI(main)) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index f92dc66..2000c09 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -22,9 +22,11 @@ from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.policy_params import ( A2CParams, DDPGParams, + DiscreteSACParams, DQNParams, NPGParams, Params, + ParamsMixinActorAndDualCritics, 
ParamTransformerData, PGParams, PPOParams, @@ -39,6 +41,7 @@ from tianshou.policy import ( A2CPolicy, BasePolicy, DDPGPolicy, + DiscreteSACPolicy, DQNPolicy, NPGPolicy, PGPolicy, @@ -49,13 +52,13 @@ from tianshou.policy import ( TRPOPolicy, ) from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer -from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import ActorCritic, BaseActor +from tianshou.utils.net.common import ActorCritic from tianshou.utils.string import ToStringMixin CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" TParams = TypeVar("TParams", bound=Params) +TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) TPolicy = TypeVar("TPolicy", bound=BasePolicy) @@ -247,7 +250,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC): ) -class _ActorCriticMixin: +class _ActorCriticMixin: # TODO merge """Mixin for agents that use an ActorCritic module with a single optimizer.""" def __init__( @@ -256,13 +259,11 @@ class _ActorCriticMixin: critic_factory: CriticFactory, optim_factory: OptimizerFactory, critic_use_action: bool, - critic_use_actor_module: bool, ): self.actor_factory = actor_factory self.critic_factory = critic_factory self.optim_factory = optim_factory self.critic_use_action = critic_use_action - self.critic_use_actor_module = critic_use_actor_module def create_actor_critic_module_opt( self, @@ -271,28 +272,7 @@ class _ActorCriticMixin: lr: float, ) -> ActorCriticModuleOpt: actor = self.actor_factory.create_module(envs, device) - critic: torch.nn.Module - if self.critic_use_actor_module: - if self.critic_use_action: - raise ValueError( - "The options critic_use_actor_module and critic_use_action are mutually exclusive", - ) - if not isinstance(actor, BaseActor): - raise ValueError( - f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}", - ) - if envs.get_type().is_discrete(): - critic = 
discrete.Critic(actor.get_preprocess_net(), device=device).to(device) - elif envs.get_type().is_continuous(): - critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device) - else: - raise ValueError - else: - critic = self.critic_factory.create_module( - envs, - device, - use_action=self.critic_use_action, - ) + critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) actor_critic = ActorCritic(actor, critic) optim = self.optim_factory.create_optimizer(actor_critic, lr) return ActorCriticModuleOpt(actor_critic, optim) @@ -349,7 +329,6 @@ class ActorCriticAgentFactory( critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, policy_class: type[TPolicy], - critic_use_actor_module: bool, ): super().__init__(sampling_config, optim_factory=optimizer_factory) _ActorCriticMixin.__init__( @@ -358,7 +337,6 @@ class ActorCriticAgentFactory( critic_factory, optimizer_factory, critic_use_action=False, - critic_use_actor_module=critic_use_actor_module, ) self.params = params self.policy_class = policy_class @@ -395,7 +373,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, - critic_use_actor_module: bool, ): super().__init__( params, @@ -404,7 +381,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): critic_factory, optimizer_factory, A2CPolicy, - critic_use_actor_module, ) def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: @@ -419,7 +395,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, - critic_use_actor_module: bool, ): super().__init__( params, @@ -428,7 +403,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): critic_factory, optimizer_factory, PPOPolicy, - critic_use_actor_module, ) 
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: @@ -443,7 +417,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, - critic_use_actor_module: bool, ): super().__init__( params, @@ -452,7 +425,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): critic_factory, optimizer_factory, NPGPolicy, - critic_use_actor_module, ) def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: @@ -467,7 +439,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, - critic_use_actor_module: bool, ): super().__init__( params, @@ -476,7 +447,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): critic_factory, optimizer_factory, TRPOPolicy, - critic_use_actor_module, ) def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: @@ -619,6 +589,81 @@ class REDQAgentFactory(OffpolicyAgentFactory): ) +class ActorDualCriticsAgentFactory( + OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC, +): + def __init__( + self, + params: TActorDualCriticsParams, + sampling_config: SamplingConfig, + actor_factory: ActorFactory, + critic1_factory: CriticFactory, + critic2_factory: CriticFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config, optim_factory) + self.params = params + self.actor_factory = actor_factory + self.critic1_factory = critic1_factory + self.critic2_factory = critic2_factory + self.optim_factory = optim_factory + + @abstractmethod + def _get_policy_class(self) -> type[TPolicy]: + pass + + @abstractmethod + def _get_discrete_last_size_use_action_shape(self) -> bool: + pass + + def _create_policy(self, envs: Environments, device: TDevice) -> 
TPolicy: + actor = self.actor_factory.create_module_opt( + envs, + device, + self.optim_factory, + self.params.actor_lr, + ) + use_action_shape = self._get_discrete_last_size_use_action_shape() + critic1 = self.critic1_factory.create_module_opt( + envs, + device, + envs.get_type().is_continuous(), + self.optim_factory, + self.params.critic1_lr, + discrete_last_size_use_action_shape=use_action_shape, + ) + critic2 = self.critic2_factory.create_module_opt( + envs, + device, + envs.get_type().is_continuous(), + self.optim_factory, + self.params.critic2_lr, + discrete_last_size_use_action_shape=use_action_shape, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory=self.optim_factory, + actor=actor, + critic1=critic1, + critic2=critic2, + ), + ) + policy_class = self._get_policy_class() + return policy_class( + actor=actor.module, + actor_optim=actor.optim, + critic=critic1.module, + critic_optim=critic1.optim, + critic2=critic2.module, + critic2_optim=critic2.optim, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + **kwargs, + ) + + class SACAgentFactory(OffpolicyAgentFactory): def __init__( self, @@ -680,6 +725,14 @@ class SACAgentFactory(OffpolicyAgentFactory): ) +class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]): + def _get_discrete_last_size_use_action_shape(self) -> bool: + return True + + def _get_policy_class(self) -> type[DiscreteSACPolicy]: + return DiscreteSACPolicy + + class TD3AgentFactory(OffpolicyAgentFactory): def __init__( self, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index eb8eeae..520cff5 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -13,6 +13,7 @@ from tianshou.highlevel.agent import ( A2CAgentFactory, AgentFactory, DDPGAgentFactory, + DiscreteSACAgentFactory, DQNAgentFactory, NPGAgentFactory, PGAgentFactory, @@ -28,6 +29,9 @@ from tianshou.highlevel.logger import DefaultLoggerFactory, 
LoggerFactory from tianshou.highlevel.module.actor import ( ActorFactory, ActorFactoryDefault, + ActorFactoryTransientStorageDecorator, + ActorFuture, + ActorFutureProviderProtocol, ContinuousActorType, ) from tianshou.highlevel.module.critic import ( @@ -35,11 +39,13 @@ from tianshou.highlevel.module.critic import ( CriticEnsembleFactoryDefault, CriticFactory, CriticFactoryDefault, + CriticFactoryReuseActor, ) from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam from tianshou.highlevel.params.policy_params import ( A2CParams, DDPGParams, + DiscreteSACParams, DQNParams, NPGParams, PGParams, @@ -263,9 +269,10 @@ class ExperimentBuilder: return experiment -class _BuilderMixinActorFactory: +class _BuilderMixinActorFactory(ActorFutureProviderProtocol): def __init__(self, continuous_actor_type: ContinuousActorType): self._continuous_actor_type = continuous_actor_type + self._actor_future = ActorFuture() self._actor_factory: ActorFactory | None = None def with_actor_factory(self, actor_factory: ActorFactory) -> Self: @@ -286,11 +293,16 @@ class _BuilderMixinActorFactory: ) return self + def get_actor_future(self) -> ActorFuture: + return self._actor_future + def _get_actor_factory(self) -> ActorFactory: + actor_factory: ActorFactory if self._actor_factory is None: - return ActorFactoryDefault(self._continuous_actor_type) + actor_factory = ActorFactoryDefault(self._continuous_actor_type) else: - return self._actor_factory + actor_factory = self._actor_factory + return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future) class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): @@ -325,7 +337,8 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor class _BuilderMixinCriticsFactory: - def __init__(self, num_critics: int): + def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol): + self._actor_future_provider = actor_future_provider 
self._critic_factories: list[CriticFactory | None] = [None] * num_critics def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self: @@ -336,6 +349,12 @@ class _BuilderMixinCriticsFactory: self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes) return self + def _with_critic_factory_use_actor(self, idx: int) -> Self: + self._critic_factories[idx] = CriticFactoryReuseActor( + self._actor_future_provider.get_actor_future(), + ) + return self + def _get_critic_factory(self, idx: int) -> CriticFactory: factory = self._critic_factories[idx] if factory is None: @@ -345,8 +364,8 @@ class _BuilderMixinCriticsFactory: class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): - def __init__(self) -> None: - super().__init__(1) + def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None: + super().__init__(1, actor_future_provider) def with_critic_factory(self, critic_factory: CriticFactory) -> Self: self._with_critic_factory(0, critic_factory) @@ -361,19 +380,17 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory): - def __init__(self) -> None: - super().__init__() - self._critic_use_actor_module = False + def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None: + super().__init__(actor_future_provider) def with_critic_factory_use_actor(self) -> Self: """Makes the critic use the same network as the actor.""" - self._critic_use_actor_module = True - return self + return self._with_critic_factory_use_actor(0) class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): - def __init__(self) -> None: - super().__init__(2) + def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None: + super().__init__(2, actor_future_provider) def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self: for i in range(len(self._critic_factories)): @@ -388,6 
+405,12 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): self._with_critic_factory_default(i, hidden_sizes) return self + def with_common_critic_factory_use_actor(self) -> Self: + """Makes all critics use the same network as the actor.""" + for i in range(len(self._critic_factories)): + self._with_critic_factory_use_actor(i) + return self + def with_critic1_factory(self, critic_factory: CriticFactory) -> Self: self._with_critic_factory(0, critic_factory) return self @@ -399,6 +422,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): self._with_critic_factory_default(0, hidden_sizes) return self + def with_critic1_factory_use_actor(self) -> Self: + """Makes the first critic use the same network as the actor.""" + return self._with_critic_factory_use_actor(0) + def with_critic2_factory(self, critic_factory: CriticFactory) -> Self: self._with_critic_factory(1, critic_factory) return self @@ -410,6 +437,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): self._with_critic_factory_default(0, hidden_sizes) return self + def with_critic2_factory_use_actor(self) -> Self: + """Makes the second critic use the same network as the actor.""" + return self._with_critic_factory_use_actor(1) + class _BuilderMixinCriticEnsembleFactory: def __init__(self) -> None: @@ -475,7 +506,7 @@ class A2CExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: A2CParams = A2CParams() self._env_config = None @@ -483,7 +514,6 @@ class A2CExperimentBuilder( self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return A2CAgentFactory( self._params, @@ -491,7 +521,6 @@ class A2CExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), 
self._get_optim_factory(), - self._critic_use_actor_module, ) @@ -508,14 +537,13 @@ class PPOExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: PPOParams = PPOParams() def with_ppo_params(self, params: PPOParams) -> Self: self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return PPOAgentFactory( self._params, @@ -523,7 +551,6 @@ class PPOExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), - self._critic_use_actor_module, ) @@ -540,14 +567,13 @@ class NPGExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: NPGParams = NPGParams() def with_npg_params(self, params: NPGParams) -> Self: self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return NPGAgentFactory( self._params, @@ -555,7 +581,6 @@ class NPGExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), - self._critic_use_actor_module, ) @@ -572,14 +597,13 @@ class TRPOExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: TRPOParams = TRPOParams() def with_trpo_params(self, params: TRPOParams) -> Self: self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return TRPOAgentFactory( 
self._params, @@ -587,7 +611,6 @@ class TRPOExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), - self._critic_use_actor_module, ) @@ -609,7 +632,6 @@ class DQNExperimentBuilder( self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return DQNAgentFactory( self._params, @@ -632,14 +654,13 @@ class DDPGExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) - _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: DDPGParams = DDPGParams() def with_ddpg_params(self, params: DDPGParams) -> Self: self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return DDPGAgentFactory( self._params, @@ -670,7 +691,6 @@ class REDQExperimentBuilder( self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return REDQAgentFactory( self._params, @@ -694,7 +714,7 @@ class SACExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - _BuilderMixinDualCriticFactory.__init__(self) + _BuilderMixinDualCriticFactory.__init__(self, self) self._params: SACParams = SACParams() def with_sac_params(self, params: SACParams) -> Self: @@ -712,6 +732,37 @@ class SACExperimentBuilder( ) +class DiscreteSACExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory, + _BuilderMixinDualCriticFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) + 
_BuilderMixinDualCriticFactory.__init__(self, self) + self._params: DiscreteSACParams = DiscreteSACParams() + + def with_sac_params(self, params: DiscreteSACParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return DiscreteSACAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_critic_factory(1), + self._get_optim_factory(), + ) + + class TD3ExperimentBuilder( ExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, @@ -725,7 +776,7 @@ class TD3ExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) - _BuilderMixinDualCriticFactory.__init__(self) + _BuilderMixinDualCriticFactory.__init__(self, self) self._params: TD3Params = TD3Params() def with_td3_params(self, params: TD3Params) -> Self: diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 64d18e4..1b161da 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from dataclasses import dataclass from enum import Enum +from typing import Protocol import torch from torch import nn @@ -20,6 +22,18 @@ class ContinuousActorType(Enum): UNSUPPORTED = "unsupported" +@dataclass +class ActorFuture: + """Container, which, in the future, will hold an actor instance.""" + + actor: BaseActor | nn.Module | None = None + + +class ActorFutureProviderProtocol(Protocol): + def get_actor_future(self) -> ActorFuture: + pass + + class ActorFactory(ToStringMixin, ABC): @abstractmethod def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: @@ -175,3 +189,26 @@ class ActorFactoryDiscreteNet(ActorFactory): hidden_sizes=(), device=device, ).to(device) + + +class 
ActorFactoryTransientStorageDecorator(ActorFactory): + def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture): + self.actor_factory = actor_factory + self._actor_future = actor_future + + def __getstate__(self): + d = dict(self.__dict__) + del d["_actor_future"] + return d + + def __setstate__(self, state): + self.__dict__ = state + self._actor_future = ActorFuture() + + def _tostring_excludes(self): + return [*super()._tostring_excludes(), "_actor_future"] + + def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + module = self.actor_factory.create_module(envs, device) + self._actor_future.actor = module + return module diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index e6edfc7..4af97bc 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -4,18 +4,30 @@ from collections.abc import Sequence from torch import nn from tianshou.highlevel.env import Environments, EnvType +from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.optim import OptimizerFactory from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import EnsembleLinear, Net +from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net from tianshou.utils.string import ToStringMixin class CriticFactory(ToStringMixin, ABC): @abstractmethod - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: - pass + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape: bool = False, + ) -> nn.Module: + """:param envs: the environments + :param device: the torch device + :param use_action: whether to (additionally) expect the action as input + :param discrete_last_size_use_action_shape: 
whether, for the discrete case, the output dimension shall use the action shape + :return: the module + """ def create_module_opt( self, @@ -24,8 +36,14 @@ class CriticFactory(ToStringMixin, ABC): use_action: bool, optim_factory: OptimizerFactory, lr: float, + discrete_last_size_use_action_shape: bool = False, ) -> ModuleOpt: - module = self.create_module(envs, device, use_action) + module = self.create_module( + envs, + device, + use_action, + discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, + ) opt = optim_factory.create_optimizer(module, lr) return ModuleOpt(module, opt) @@ -38,7 +56,13 @@ class CriticFactoryDefault(CriticFactory): def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape=False, + ) -> nn.Module: factory: CriticFactory env_type = envs.get_type() match env_type: @@ -48,14 +72,25 @@ class CriticFactoryDefault(CriticFactory): factory = CriticFactoryDiscreteNet(self.hidden_sizes) case _: raise ValueError(f"{env_type} not supported") - return factory.create_module(envs, device, use_action) + return factory.create_module( + envs, + device, + use_action, + discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, + ) class CriticFactoryContinuousNet(CriticFactory): def __init__(self, hidden_sizes: Sequence[int]): self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape=False, + ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( envs.get_observation_shape(), @@ -74,7 +109,13 @@ class 
CriticFactoryDiscreteNet(CriticFactory): def __init__(self, hidden_sizes: Sequence[int]): self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape=False, + ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( envs.get_observation_shape(), @@ -84,11 +125,50 @@ class CriticFactoryDiscreteNet(CriticFactory): activation=nn.Tanh, device=device, ) - critic = discrete.Critic(net_c, device=device).to(device) + last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1 + critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device) init_linear_orthogonal(critic) return critic +class CriticFactoryReuseActor(CriticFactory): + """A critic factory which reuses the actor's preprocessing component. + + This class is for internal use in experiment builders only. 
+ """ + + def __init__(self, actor_future: ActorFuture): + """:param actor_future: the object, which will hold the actor instance later when the critic is to be created""" + self.actor_future = actor_future + + def _tostring_excludes(self) -> list[str]: + return ["actor_future"] + + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape=False, + ) -> nn.Module: + actor = self.actor_future.actor + if not isinstance(actor, BaseActor): + raise ValueError( + f"Reusing the actor for the critic requires an actor of type {BaseActor.__name__}, but got {type(actor).__name__}", + ) + if envs.get_type().is_discrete(): + last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1 + return discrete.Critic( + actor.get_preprocess_net(), + device=device, + last_size=last_size, + ).to(device) + elif envs.get_type().is_continuous(): + return continuous.Critic(actor.get_preprocess_net(), device=device).to(device) + else: + raise ValueError(f"Unsupported environment type: {envs.get_type()}") + + class CriticEnsembleFactory: @abstractmethod def create_module( diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 540eec9..44cc20f 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -359,6 +359,20 @@ class SACParams(Params, ParamsMixinActorAndDualCritics): return transformers +@dataclass +class DiscreteSACParams(Params, ParamsMixinActorAndDualCritics): + tau: float = 0.005 + gamma: float = 0.99 + alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2 + estimation_step: int = 1 + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) + transformers.append(ParamTransformerAutoAlpha("alpha")) + return transformers + + @dataclass class DQNParams(Params, 
ParamsMixinLearningRateWithScheduler): discount_factor: float = 0.99