From b54fcd12cb96e47c4775177009d5a7a125a3ac80 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 5 Oct 2023 19:21:08 +0200 Subject: [PATCH] Change high-level DQN interface to expect an actor instead of a critic, because that is what is functionally required --- examples/atari/atari_dqn_hl.py | 4 ++-- examples/atari/atari_network.py | 19 ++++++------------ examples/atari/atari_ppo_hl.py | 2 +- tianshou/highlevel/agent.py | 16 +++++++++------ tianshou/highlevel/experiment.py | 6 +++--- tianshou/highlevel/module/actor.py | 8 ++++++-- tianshou/highlevel/module/critic.py | 30 ++++++++++++++++++++++------- 7 files changed, 51 insertions(+), 34 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 42583d8..47916dc 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -6,7 +6,7 @@ import os from jsonargparse import CLI from examples.atari.atari_network import ( - CriticFactoryAtariDQN, + ActorFactoryAtariPlainDQN, FeatureNetFactoryDQN, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback @@ -103,7 +103,7 @@ def main( target_update_freq=target_update_freq, ), ) - .with_critic_factory(CriticFactoryAtariDQN()) + .with_actor_factory(ActorFactoryAtariPlainDQN()) .with_trainer_epoch_callback_train(TrainEpochCallback()) .with_trainer_epoch_callback_test(TestEpochCallback()) .with_trainer_stop_callback(AtariStopCallback(task)) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 0d72fad..3757d70 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -8,8 +8,6 @@ from torch import nn from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice -from tianshou.highlevel.module.critic import CriticFactory -from tianshou.utils.net.common import BaseActor from tianshou.utils.net.discrete import Actor, 
NoisyLinear @@ -227,14 +225,8 @@ class QRDQN(DQN): return obs, state -class CriticFactoryAtariDQN(CriticFactory): - def create_module( - self, - envs: Environments, - device: TDevice, - use_action: bool, - ) -> torch.nn.Module: - assert use_action +class ActorFactoryAtariPlainDQN(ActorFactory): + def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module: return DQN( *envs.get_observation_shape(), envs.get_action_shape(), @@ -243,17 +235,18 @@ class CriticFactoryAtariDQN(CriticFactory): class ActorFactoryAtariDQN(ActorFactory): - def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool): + def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool): self.hidden_size = hidden_size self.scale_obs = scale_obs + self.features_only = features_only - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_cls = scale_obs(DQN) if self.scale_obs else DQN net = net_cls( *envs.get_observation_shape(), envs.get_action_shape(), device=device, - features_only=True, + features_only=self.features_only, output_dim=self.hidden_size, layer_init=layer_init, ) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index b02ff07..3179934 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -96,7 +96,7 @@ def main( else None, ), ) - .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs)) + .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True)) .with_critic_factory_use_actor() .with_trainer_stop_callback(AtariStopCallback(task)) ) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 88b2fc1..af0fea9 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -44,7 +44,7 @@ from tianshou.policy import ( ) from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from 
tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.common import ActorCritic, BaseActor from tianshou.utils.string import ToStringMixin CHECKPOINT_DICT_KEY_MODEL = "model" @@ -285,6 +285,10 @@ class _ActorCriticMixin: raise ValueError( "The options critic_use_actor_module and critic_use_action are mutually exclusive", ) + if not isinstance(actor, BaseActor): + raise ValueError( + f"Option critic_use_action can only be used if actor is of type {BaseActor.__name__}", + ) if envs.get_type().is_discrete(): critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device) elif envs.get_type().is_continuous(): @@ -444,17 +448,17 @@ class DQNAgentFactory(OffpolicyAgentFactory): self, params: DQNParams, sampling_config: RLSamplingConfig, - critic_factory: CriticFactory, + actor_factory: ActorFactory, optim_factory: OptimizerFactory, ): super().__init__(sampling_config, optim_factory) self.params = params - self.critic_factory = critic_factory + self.actor_factory = actor_factory self.optim_factory = optim_factory def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - critic = self.critic_factory.create_module(envs, device, use_action=True) - optim = self.optim_factory.create_optimizer(critic, self.params.lr) + model = self.actor_factory.create_module(envs, device) + optim = self.optim_factory.create_optimizer(model, self.params.lr) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, @@ -467,7 +471,7 @@ class DQNAgentFactory(OffpolicyAgentFactory): # noinspection PyTypeChecker action_space: gymnasium.spaces.Discrete = envs.get_action_space() return DQNPolicy( - model=critic, + model=model, optim=optim, action_space=action_space, observation_space=envs.get_observation_space(), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index b93a758..aed6005 100644 --- a/tianshou/highlevel/experiment.py +++ 
b/tianshou/highlevel/experiment.py @@ -461,7 +461,7 @@ class PPOExperimentBuilder( class DQNExperimentBuilder( RLExperimentBuilder, - _BuilderMixinSingleCriticFactory, + _BuilderMixinActorFactory, ): def __init__( self, @@ -470,7 +470,7 @@ class DQNExperimentBuilder( sampling_config: RLSamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) - _BuilderMixinSingleCriticFactory.__init__(self) + _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) self._params: DQNParams = DQNParams() def with_dqn_params(self, params: DQNParams) -> Self: @@ -482,7 +482,7 @@ class DQNExperimentBuilder( return DQNAgentFactory( self._params, self._sampling_config, - self._get_critic_factory(0), + self._get_actor_factory(), self._get_optim_factory(), ) diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 8849f38..2c4fa5c 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -14,11 +14,12 @@ from tianshou.utils.string import ToStringMixin class ContinuousActorType: GAUSSIAN = "gaussian" DETERMINISTIC = "deterministic" + UNSUPPORTED = "unsupported" class ActorFactory(ToStringMixin, ABC): @abstractmethod - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: pass @staticmethod @@ -67,11 +68,14 @@ class ActorFactoryDefault(ActorFactory): ) case ContinuousActorType.DETERMINISTIC: factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes) + case ContinuousActorType.UNSUPPORTED: + raise ValueError("Continuous action spaces are not supported by the algorithm") case _: raise ValueError(self.continuous_actor_type) return factory.create_module(envs, device) elif env_type == EnvType.DISCRETE: - raise NotImplementedError + factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES) + return factory.create_module(envs, device) else: raise 
ValueError(f"{env_type} not supported") diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 092c172..83cb797 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -5,7 +5,7 @@ from torch import nn from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal -from tianshou.utils.net import continuous +from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import Net from tianshou.utils.string import ToStringMixin @@ -30,16 +30,13 @@ class CriticFactoryDefault(CriticFactory): factory = CriticFactoryContinuousNet(self.hidden_sizes) return factory.create_module(envs, device, use_action) elif env_type == EnvType.DISCRETE: - raise NotImplementedError + factory = CriticFactoryDiscreteNet(self.hidden_sizes) + return factory.create_module(envs, device, use_action) else: raise ValueError(f"{env_type} not supported") -class CriticFactoryContinuous(CriticFactory, ABC): - pass - - -class CriticFactoryContinuousNet(CriticFactoryContinuous): +class CriticFactoryContinuousNet(CriticFactory): def __init__(self, hidden_sizes: Sequence[int]): self.hidden_sizes = hidden_sizes @@ -56,3 +53,22 @@ class CriticFactoryContinuousNet(CriticFactoryContinuous): critic = continuous.Critic(net_c, device=device).to(device) init_linear_orthogonal(critic) return critic + + +class CriticFactoryDiscreteNet(CriticFactory): + def __init__(self, hidden_sizes: Sequence[int]): + self.hidden_sizes = hidden_sizes + + def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + action_shape = envs.get_action_shape() if use_action else 0 + net_c = Net( + envs.get_observation_shape(), + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + concat=use_action, + activation=nn.Tanh, + device=device, + ) + critic = discrete.Critic(net_c, device=device).to(device) + init_linear_orthogonal(critic) 
+ return critic