Extend tests, fixing some default behaviour

Dominik Jain 2023-10-11 16:07:34 +02:00
parent a8a367c42d
commit 686fd555b0
6 changed files with 43 additions and 10 deletions

View File

@@ -68,7 +68,11 @@ def main(
     )
     env_factory = AtariEnvFactory(
-        task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
+        task,
+        experiment_config.seed,
+        sampling_config,
+        frames_stack,
+        scale=scale_obs,
     )
     builder = (

View File

@@ -7,9 +7,12 @@ from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
     DDPGExperimentBuilder,
     ExperimentConfig,
+    PGExperimentBuilder,
     PPOExperimentBuilder,
+    REDQExperimentBuilder,
     SACExperimentBuilder,
     TD3ExperimentBuilder,
+    TRPOExperimentBuilder,
 )
@@ -21,6 +24,10 @@ from tianshou.highlevel.experiment import (
         SACExperimentBuilder,
         DDPGExperimentBuilder,
         TD3ExperimentBuilder,
+        # NPGExperimentBuilder,  # TODO test fails non-deterministically
+        REDQExperimentBuilder,
+        TRPOExperimentBuilder,
+        PGExperimentBuilder,
     ],
 )
 def test_experiment_builder_continuous_default_params(builder_cls):
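
Note on the excluded NPG case: the commit drops it from the parametrization by commenting it out because its test fails non-deterministically. An alternative (not what this commit does) would be to keep the flaky builder in the list but mark it skipped, so it stays visible in test reports. A minimal sketch, assuming NPGExperimentBuilder is exported from the same module as the other builders:

import pytest

from tianshou.highlevel.experiment import (
    DDPGExperimentBuilder,
    NPGExperimentBuilder,
    PGExperimentBuilder,
    REDQExperimentBuilder,
    SACExperimentBuilder,
    TD3ExperimentBuilder,
    TRPOExperimentBuilder,
)

# Equivalent parametrization that records the flaky case as skipped instead of
# hiding it behind a comment.
builder_params = [
    SACExperimentBuilder,
    DDPGExperimentBuilder,
    TD3ExperimentBuilder,
    pytest.param(
        NPGExperimentBuilder,
        marks=pytest.mark.skip(reason="test fails non-deterministically"),
    ),
    REDQExperimentBuilder,
    TRPOExperimentBuilder,
    PGExperimentBuilder,
]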

View File

@@ -5,17 +5,27 @@ import pytest
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
+    DiscreteSACExperimentBuilder,
     DQNExperimentBuilder,
     ExperimentConfig,
+    IQNExperimentBuilder,
     PPOExperimentBuilder,
 )
+from tianshou.utils import logging
 @pytest.mark.parametrize(
     "builder_cls",
-    [PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder],
+    [
+        PPOExperimentBuilder,
+        A2CExperimentBuilder,
+        DQNExperimentBuilder,
+        DiscreteSACExperimentBuilder,
+        IQNExperimentBuilder,
+    ],
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
+    logging.configure()
     env_factory = DiscreteTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
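
The rest of the test body is cut off by the diff context. For orientation only, a plausible completion is sketched below; the keyword argument names and the build()/run() calls are assumptions about the high-level builder API, and DiscreteTestEnvFactory is the test helper already used in the file above.

# Hypothetical continuation of the truncated test body (not shown in the diff);
# argument names and the build()/run() calls are assumptions.
def test_experiment_builder_discrete_default_params(builder_cls):
    logging.configure()
    env_factory = DiscreteTestEnvFactory()
    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
    builder = builder_cls(
        experiment_config=ExperimentConfig(),
        env_factory=env_factory,
        sampling_config=sampling_config,
    )
    experiment = builder.build()      # assembles envs, agent factory and trainer
    experiment.run("test_discrete")   # one short epoch, as configured above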

View File

@ -67,7 +67,8 @@ TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler) TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TDiscreteCriticOnlyParams = TypeVar( TDiscreteCriticOnlyParams = TypeVar(
"TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler, "TDiscreteCriticOnlyParams",
bound=ParamsMixinLearningRateWithScheduler,
) )
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -403,7 +404,8 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
class DiscreteCriticOnlyAgentFactory( class DiscreteCriticOnlyAgentFactory(
OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy], OffpolicyAgentFactory,
Generic[TDiscreteCriticOnlyParams, TPolicy],
): ):
def __init__( def __init__(
self, self,
@ -583,6 +585,10 @@ class ActorDualCriticsAgentFactory(
def _get_discrete_last_size_use_action_shape(self) -> bool: def _get_discrete_last_size_use_action_shape(self) -> bool:
return True return True
@staticmethod
def _get_critic_use_action(envs: Environments) -> bool:
return envs.get_type().is_continuous()
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt( actor = self.actor_factory.create_module_opt(
envs, envs,
@ -591,10 +597,11 @@ class ActorDualCriticsAgentFactory(
self.params.actor_lr, self.params.actor_lr,
) )
use_action_shape = self._get_discrete_last_size_use_action_shape() use_action_shape = self._get_discrete_last_size_use_action_shape()
critic_use_action = self._get_critic_use_action(envs)
critic1 = self.critic1_factory.create_module_opt( critic1 = self.critic1_factory.create_module_opt(
envs, envs,
device, device,
True, critic_use_action,
self.optim_factory, self.optim_factory,
self.params.critic1_lr, self.params.critic1_lr,
discrete_last_size_use_action_shape=use_action_shape, discrete_last_size_use_action_shape=use_action_shape,
@ -602,7 +609,7 @@ class ActorDualCriticsAgentFactory(
critic2 = self.critic2_factory.create_module_opt( critic2 = self.critic2_factory.create_module_opt(
envs, envs,
device, device,
True, critic_use_action,
self.optim_factory, self.optim_factory,
self.params.critic2_lr, self.params.critic2_lr,
discrete_last_size_use_action_shape=use_action_shape, discrete_last_size_use_action_shape=use_action_shape,
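
The new _get_critic_use_action helper is the default-behaviour fix referenced in the commit title: dual-critic algorithms with continuous actions (SAC, TD3, REDQ) build critics of the form Q(s, a) that take the action as a network input, whereas their discrete counterparts (e.g. DiscreteSAC) build critics that map the observation to one Q-value per action; the previously hard-coded True forced the continuous form in both regimes. A minimal, self-contained sketch of the rule — the EnvType enum here is a stand-in for illustration, not the tianshou class:

from enum import Enum


class EnvType(Enum):  # stand-in for the high-level API's env type, illustration only
    CONTINUOUS = "continuous"
    DISCRETE = "discrete"


def critic_use_action(env_type: EnvType) -> bool:
    # Mirrors the new _get_critic_use_action: only continuous-action critics take the
    # action as an input (Q(s, a)); discrete critics instead output one Q-value per
    # action, selected via discrete_last_size_use_action_shape=True.
    return env_type is EnvType.CONTINUOUS


assert critic_use_action(EnvType.CONTINUOUS) is True
assert critic_use_action(EnvType.DISCRETE) is False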

View File

@@ -483,14 +483,13 @@ class PGExperimentBuilder(
     ):
         super().__init__(env_factory, experiment_config, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
-        self._params: A2CParams = A2CParams()
+        self._params: PGParams = PGParams()
         self._env_config = None
 
     def with_pg_params(self, params: PGParams) -> Self:
         self._params = params
         return self
 
-    @abstractmethod
     def _create_agent_factory(self) -> AgentFactory:
         return PGAgentFactory(
             self._params,
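
With this change the default params object matches the type expected by with_pg_params, and removing the stray @abstractmethod makes _create_agent_factory concrete, so PGExperimentBuilder can be instantiated and built like the other builders (which is what the extended tests exercise). A hedged usage sketch; the keyword names, the PGParams import path and the build() call are assumptions, and the env factory is a placeholder:

# Hypothetical usage sketch; argument names and import paths are assumptions.
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, PGExperimentBuilder
from tianshou.highlevel.params.policy_params import PGParams

env_factory = ...  # any continuous EnvFactory, e.g. the test suite's ContinuousTestEnvFactory

builder = PGExperimentBuilder(
    env_factory=env_factory,
    experiment_config=ExperimentConfig(),
    sampling_config=SamplingConfig(num_epochs=1, step_per_epoch=100),
)
builder.with_pg_params(PGParams())  # same type as the corrected default
experiment = builder.build()        # no longer blocked by the removed @abstractmethod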

View File

@@ -92,11 +92,13 @@ class ActorFactoryDefault(ActorFactory):
         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
+        discrete_softmax: bool = True,
     ):
         self.continuous_actor_type = continuous_actor_type
         self.continuous_unbounded = continuous_unbounded
         self.continuous_conditioned_sigma = continuous_conditioned_sigma
         self.hidden_sizes = hidden_sizes
+        self.discrete_softmax = discrete_softmax
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         env_type = envs.get_type()
@@ -117,7 +119,9 @@ class ActorFactoryDefault(ActorFactory):
                 raise ValueError(self.continuous_actor_type)
             return factory.create_module(envs, device)
         elif env_type == EnvType.DISCRETE:
-            factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
+            factory = ActorFactoryDiscreteNet(
+                self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
+            )
             return factory.create_module(envs, device)
         else:
             raise ValueError(f"{env_type} not supported")
@@ -180,8 +184,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
 class ActorFactoryDiscreteNet(ActorFactory):
-    def __init__(self, hidden_sizes: Sequence[int]):
+    def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True):
         self.hidden_sizes = hidden_sizes
+        self.softmax_output = softmax_output
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
@@ -194,6 +199,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
             envs.get_action_shape(),
             hidden_sizes=(),
             device=device,
+            softmax_output=self.softmax_output,
         ).to(device)
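
These changes thread a softmax_output switch from the default actor factory down to the discrete actor network, so callers can request raw action logits instead of probabilities while the previous behaviour (softmax applied) stays the default. A hedged construction sketch; the import path and the ContinuousActorType member name are assumptions about the surrounding module, which is not shown in this diff:

# Hedged usage sketch; import path and ContinuousActorType member are assumptions.
from tianshou.highlevel.module.actor import ActorFactoryDefault, ContinuousActorType

# Unchanged default: discrete actors apply a softmax over the action logits.
factory_probs = ActorFactoryDefault(ContinuousActorType.GAUSSIAN)

# New: opt out of the softmax for algorithms that expect raw logits.
factory_logits = ActorFactoryDefault(
    ContinuousActorType.GAUSSIAN,
    discrete_softmax=False,
)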