Extend tests, fixing some default behaviour

This commit is contained in:
Dominik Jain 2023-10-11 16:07:34 +02:00
parent a8a367c42d
commit 686fd555b0
6 changed files with 43 additions and 10 deletions

View File

@ -68,7 +68,11 @@ def main(
)
env_factory = AtariEnvFactory(
task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
builder = (

View File

@ -7,9 +7,12 @@ from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DDPGExperimentBuilder,
ExperimentConfig,
PGExperimentBuilder,
PPOExperimentBuilder,
REDQExperimentBuilder,
SACExperimentBuilder,
TD3ExperimentBuilder,
TRPOExperimentBuilder,
)
@ -21,6 +24,10 @@ from tianshou.highlevel.experiment import (
SACExperimentBuilder,
DDPGExperimentBuilder,
TD3ExperimentBuilder,
# NPGExperimentBuilder, # TODO test fails non-deterministically
REDQExperimentBuilder,
TRPOExperimentBuilder,
PGExperimentBuilder,
],
)
def test_experiment_builder_continuous_default_params(builder_cls):

View File

@ -5,17 +5,27 @@ import pytest
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DiscreteSACExperimentBuilder,
DQNExperimentBuilder,
ExperimentConfig,
IQNExperimentBuilder,
PPOExperimentBuilder,
)
from tianshou.utils import logging
@pytest.mark.parametrize(
"builder_cls",
[PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder],
[
PPOExperimentBuilder,
A2CExperimentBuilder,
DQNExperimentBuilder,
DiscreteSACExperimentBuilder,
IQNExperimentBuilder,
],
)
def test_experiment_builder_discrete_default_params(builder_cls):
logging.configure()
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
builder = builder_cls(

View File

@ -67,7 +67,8 @@ TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TDiscreteCriticOnlyParams = TypeVar(
"TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
"TDiscreteCriticOnlyParams",
bound=ParamsMixinLearningRateWithScheduler,
)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -403,7 +404,8 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
class DiscreteCriticOnlyAgentFactory(
OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
OffpolicyAgentFactory,
Generic[TDiscreteCriticOnlyParams, TPolicy],
):
def __init__(
self,
@ -583,6 +585,10 @@ class ActorDualCriticsAgentFactory(
def _get_discrete_last_size_use_action_shape(self) -> bool:
return True
@staticmethod
def _get_critic_use_action(envs: Environments) -> bool:
return envs.get_type().is_continuous()
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt(
envs,
@ -591,10 +597,11 @@ class ActorDualCriticsAgentFactory(
self.params.actor_lr,
)
use_action_shape = self._get_discrete_last_size_use_action_shape()
critic_use_action = self._get_critic_use_action(envs)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
critic_use_action,
self.optim_factory,
self.params.critic1_lr,
discrete_last_size_use_action_shape=use_action_shape,
@ -602,7 +609,7 @@ class ActorDualCriticsAgentFactory(
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
critic_use_action,
self.optim_factory,
self.params.critic2_lr,
discrete_last_size_use_action_shape=use_action_shape,

View File

@ -483,14 +483,13 @@ class PGExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
self._params: A2CParams = A2CParams()
self._params: PGParams = PGParams()
self._env_config = None
def with_pg_params(self, params: PGParams) -> Self:
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return PGAgentFactory(
self._params,

View File

@ -92,11 +92,13 @@ class ActorFactoryDefault(ActorFactory):
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
discrete_softmax: bool = True,
):
self.continuous_actor_type = continuous_actor_type
self.continuous_unbounded = continuous_unbounded
self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes
self.discrete_softmax = discrete_softmax
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
env_type = envs.get_type()
@ -117,7 +119,9 @@ class ActorFactoryDefault(ActorFactory):
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
factory = ActorFactoryDiscreteNet(
self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
)
return factory.create_module(envs, device)
else:
raise ValueError(f"{env_type} not supported")
@ -180,8 +184,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
class ActorFactoryDiscreteNet(ActorFactory):
def __init__(self, hidden_sizes: Sequence[int]):
def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True):
self.hidden_sizes = hidden_sizes
self.softmax_output = softmax_output
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net(
@ -194,6 +199,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
envs.get_action_shape(),
hidden_sizes=(),
device=device,
softmax_output=self.softmax_output,
).to(device)