From 686fd555b0610bcbd3e99f9cf924c3a0e9407dca Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 11 Oct 2023 16:07:34 +0200 Subject: [PATCH] Extend tests, fixing some default behaviour --- examples/atari/atari_sac_hl.py | 6 +++++- test/highlevel/test_continuous.py | 7 +++++++ test/highlevel/test_discrete.py | 12 +++++++++++- tianshou/highlevel/agent.py | 15 +++++++++++---- tianshou/highlevel/experiment.py | 3 +-- tianshou/highlevel/module/actor.py | 10 ++++++++-- 6 files changed, 43 insertions(+), 10 deletions(-) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 9909687..ac9ceb2 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -68,7 +68,11 @@ def main( ) env_factory = AtariEnvFactory( - task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs, + task, + experiment_config.seed, + sampling_config, + frames_stack, + scale=scale_obs, ) builder = ( diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py index 884e9d8..2b7e927 100644 --- a/test/highlevel/test_continuous.py +++ b/test/highlevel/test_continuous.py @@ -7,9 +7,12 @@ from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, ExperimentConfig, + PGExperimentBuilder, PPOExperimentBuilder, + REDQExperimentBuilder, SACExperimentBuilder, TD3ExperimentBuilder, + TRPOExperimentBuilder, ) @@ -21,6 +24,10 @@ from tianshou.highlevel.experiment import ( SACExperimentBuilder, DDPGExperimentBuilder, TD3ExperimentBuilder, + # NPGExperimentBuilder, # TODO test fails non-deterministically + REDQExperimentBuilder, + TRPOExperimentBuilder, + PGExperimentBuilder, ], ) def test_experiment_builder_continuous_default_params(builder_cls): diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py index 0dd624b..52517aa 100644 --- a/test/highlevel/test_discrete.py +++ b/test/highlevel/test_discrete.py @@ -5,17 +5,27 @@ import pytest from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, + DiscreteSACExperimentBuilder, DQNExperimentBuilder, ExperimentConfig, + IQNExperimentBuilder, PPOExperimentBuilder, ) +from tianshou.utils import logging @pytest.mark.parametrize( "builder_cls", - [PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder], + [ + PPOExperimentBuilder, + A2CExperimentBuilder, + DQNExperimentBuilder, + DiscreteSACExperimentBuilder, + IQNExperimentBuilder, + ], ) def test_experiment_builder_discrete_default_params(builder_cls): + logging.configure() env_factory = DiscreteTestEnvFactory() sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) builder = builder_cls( diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 66f6492..b927cda 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -67,7 +67,8 @@ TParams = TypeVar("TParams", bound=Params) TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler) TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) TDiscreteCriticOnlyParams = TypeVar( - "TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler, + "TDiscreteCriticOnlyParams", + bound=ParamsMixinLearningRateWithScheduler, ) TPolicy = TypeVar("TPolicy", bound=BasePolicy) @@ -403,7 +404,8 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): class DiscreteCriticOnlyAgentFactory( - OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy], + OffpolicyAgentFactory, + Generic[TDiscreteCriticOnlyParams, TPolicy], ): def __init__( self, @@ -583,6 +585,10 @@ class ActorDualCriticsAgentFactory( def _get_discrete_last_size_use_action_shape(self) -> bool: return True + @staticmethod + def _get_critic_use_action(envs: Environments) -> bool: + return envs.get_type().is_continuous() + def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: actor = self.actor_factory.create_module_opt( envs, @@ -591,10 +597,11 @@ class ActorDualCriticsAgentFactory( self.params.actor_lr, ) use_action_shape = self._get_discrete_last_size_use_action_shape() + critic_use_action = self._get_critic_use_action(envs) critic1 = self.critic1_factory.create_module_opt( envs, device, - True, + critic_use_action, self.optim_factory, self.params.critic1_lr, discrete_last_size_use_action_shape=use_action_shape, @@ -602,7 +609,7 @@ class ActorDualCriticsAgentFactory( critic2 = self.critic2_factory.create_module_opt( envs, device, - True, + critic_use_action, self.optim_factory, self.params.critic2_lr, discrete_last_size_use_action_shape=use_action_shape, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index a8a8889..ad242d6 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -483,14 +483,13 @@ class PGExperimentBuilder( ): super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - self._params: A2CParams = A2CParams() + self._params: PGParams = PGParams() self._env_config = None def with_pg_params(self, params: PGParams) -> Self: self._params = params return self - @abstractmethod def _create_agent_factory(self) -> AgentFactory: return PGAgentFactory( self._params, diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 0d3d86f..8bd7b5f 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -92,11 +92,13 @@ class ActorFactoryDefault(ActorFactory): hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, continuous_unbounded: bool = False, continuous_conditioned_sigma: bool = False, + discrete_softmax: bool = True, ): self.continuous_actor_type = continuous_actor_type self.continuous_unbounded = continuous_unbounded self.continuous_conditioned_sigma = continuous_conditioned_sigma self.hidden_sizes = hidden_sizes + self.discrete_softmax = discrete_softmax def create_module(self, envs: Environments, device: TDevice) -> BaseActor: env_type = envs.get_type() @@ -117,7 +119,9 @@ class ActorFactoryDefault(ActorFactory): raise ValueError(self.continuous_actor_type) return factory.create_module(envs, device) elif env_type == EnvType.DISCRETE: - factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES) + factory = ActorFactoryDiscreteNet( + self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax, + ) return factory.create_module(envs, device) else: raise ValueError(f"{env_type} not supported") @@ -180,8 +184,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): class ActorFactoryDiscreteNet(ActorFactory): - def __init__(self, hidden_sizes: Sequence[int]): + def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True): self.hidden_sizes = hidden_sizes + self.softmax_output = softmax_output def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( @@ -194,6 +199,7 @@ class ActorFactoryDiscreteNet(ActorFactory): envs.get_action_shape(), hidden_sizes=(), device=device, + softmax_output=self.softmax_output, ).to(device)