From c7d0b6b4b2debfa64b518953af59e5cb497f267e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 10 Oct 2023 20:07:30 +0200 Subject: [PATCH] Simplify agent factories by making better use of base classes --- tianshou/highlevel/agent.py | 259 ++++++------------------------------ 1 file changed, 43 insertions(+), 216 deletions(-) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 2000c09..a9c9d21 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -27,6 +27,7 @@ from tianshou.highlevel.params.policy_params import ( NPGParams, Params, ParamsMixinActorAndDualCritics, + ParamsMixinLearningRateWithScheduler, ParamTransformerData, PGParams, PPOParams, @@ -58,6 +59,7 @@ from tianshou.utils.string import ToStringMixin CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" TParams = TypeVar("TParams", bound=Params) +TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler) TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) TPolicy = TypeVar("TPolicy", bound=BasePolicy) @@ -316,37 +318,43 @@ class PGAgentFactory(OnpolicyAgentFactory): class ActorCriticAgentFactory( - Generic[TParams, TPolicy], + Generic[TActorCriticParams, TPolicy], OnpolicyAgentFactory, - _ActorCriticMixin, ABC, ): def __init__( self, - params: TParams, + params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, - policy_class: type[TPolicy], ): super().__init__(sampling_config, optim_factory=optimizer_factory) - _ActorCriticMixin.__init__( - self, - actor_factory, - critic_factory, - optimizer_factory, - critic_use_action=False, - ) self.params = params - self.policy_class = policy_class + self.actor_factory = actor_factory + self.critic_factory = critic_factory + self.optim_factory = optimizer_factory + self.critic_use_action = False @abstractmethod - def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: + def _get_policy_class(self) -> type[TPolicy]: pass + def create_actor_critic_module_opt( + self, + envs: Environments, + device: TDevice, + lr: float, + ) -> ActorCriticModuleOpt: + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) + actor_critic = ActorCritic(actor, critic) + optim = self.optim_factory.create_optimizer(actor_critic, lr) + return ActorCriticModuleOpt(actor_critic, optim) + def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: - actor_critic = self._create_actor_critic(envs, device) + actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, @@ -362,95 +370,28 @@ class ActorCriticAgentFactory( return kwargs def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: - return self.policy_class(**self._create_kwargs(envs, device)) + policy_class = self._get_policy_class() + return policy_class(**self._create_kwargs(envs, device)) class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): - def __init__( - self, - params: A2CParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, - ): - super().__init__( - params, - sampling_config, - actor_factory, - critic_factory, - optimizer_factory, - A2CPolicy, - ) - - def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: - return self.create_actor_critic_module_opt(envs, device, self.params.lr) + def _get_policy_class(self) -> type[A2CPolicy]: + return A2CPolicy class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): - def __init__( - self, - params: PPOParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, - ): - super().__init__( - params, - sampling_config, - actor_factory, - critic_factory, - optimizer_factory, - PPOPolicy, - ) - - def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: - return self.create_actor_critic_module_opt(envs, device, self.params.lr) + def _get_policy_class(self) -> type[PPOPolicy]: + return PPOPolicy class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): - def __init__( - self, - params: NPGParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, - ): - super().__init__( - params, - sampling_config, - actor_factory, - critic_factory, - optimizer_factory, - NPGPolicy, - ) - - def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: - return self.create_actor_critic_module_opt(envs, device, self.params.lr) + def _get_policy_class(self) -> type[NPGPolicy]: + return NPGPolicy class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): - def __init__( - self, - params: TRPOParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, - ): - super().__init__( - params, - sampling_config, - actor_factory, - critic_factory, - optimizer_factory, - TRPOPolicy, - ) - - def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: - return self.create_actor_critic_module_opt(envs, device, self.params.lr) + def _get_policy_class(self) -> type[TRPOPolicy]: + return TRPOPolicy class DQNAgentFactory(OffpolicyAgentFactory): @@ -590,7 +531,9 @@ class REDQAgentFactory(OffpolicyAgentFactory): class ActorDualCriticsAgentFactory( - OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC, + OffpolicyAgentFactory, + Generic[TActorDualCriticsParams, TPolicy], + ABC, ): def __init__( self, @@ -612,9 +555,8 @@ class ActorDualCriticsAgentFactory( def _get_policy_class(self) -> type[TPolicy]: pass - @abstractmethod def _get_discrete_last_size_use_action_shape(self) -> bool: - pass + return True def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: actor = self.actor_factory.create_module_opt( @@ -664,131 +606,16 @@ class ActorDualCriticsAgentFactory( ) -class SACAgentFactory(OffpolicyAgentFactory): - def __init__( - self, - params: SACParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic1_factory: CriticFactory, - critic2_factory: CriticFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.params = params - self.actor_factory = actor_factory - self.critic1_factory = critic1_factory - self.critic2_factory = critic2_factory - self.optim_factory = optim_factory - - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) - critic1 = self.critic1_factory.create_module_opt( - envs, - device, - True, - self.optim_factory, - self.params.critic1_lr, - ) - critic2 = self.critic2_factory.create_module_opt( - envs, - device, - True, - self.optim_factory, - self.params.critic2_lr, - ) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic1, - critic2=critic2, - ), - ) - return SACPolicy( - actor=actor.module, - actor_optim=actor.optim, - critic=critic1.module, - critic_optim=critic1.optim, - critic2=critic2.module, - critic2_optim=critic2.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), - **kwargs, - ) +class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]): + def _get_policy_class(self) -> type[SACPolicy]: + return SACPolicy class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]): - def _get_discrete_last_size_use_action_shape(self) -> bool: - return True - - def _get_policy_class(self) -> type[TPolicy]: + def _get_policy_class(self) -> type[DiscreteSACPolicy]: return DiscreteSACPolicy -class TD3AgentFactory(OffpolicyAgentFactory): - def __init__( - self, - params: TD3Params, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic1_factory: CriticFactory, - critic2_factory: CriticFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.params = params - self.actor_factory = actor_factory - self.critic1_factory = critic1_factory - self.critic2_factory = critic2_factory - self.optim_factory = optim_factory - - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) - critic1 = self.critic1_factory.create_module_opt( - envs, - device, - True, - self.optim_factory, - self.params.critic1_lr, - ) - critic2 = self.critic2_factory.create_module_opt( - envs, - device, - True, - self.optim_factory, - self.params.critic2_lr, - ) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic1, - critic2=critic2, - ), - ) - return TD3Policy( - actor=actor.module, - actor_optim=actor.optim, - critic=critic1.module, - critic_optim=critic1.optim, - critic2=critic2.module, - critic2_optim=critic2.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), - **kwargs, - ) +class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]): + def _get_policy_class(self) -> type[TD3Policy]: + return TD3Policy