Extend tests, fixing some default behaviour
This commit is contained in:
parent
a8a367c42d
commit
686fd555b0
@ -68,7 +68,11 @@ def main(
|
|||||||
)
|
)
|
||||||
|
|
||||||
env_factory = AtariEnvFactory(
|
env_factory = AtariEnvFactory(
|
||||||
task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
|
task,
|
||||||
|
experiment_config.seed,
|
||||||
|
sampling_config,
|
||||||
|
frames_stack,
|
||||||
|
scale=scale_obs,
|
||||||
)
|
)
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
|
@ -7,9 +7,12 @@ from tianshou.highlevel.experiment import (
|
|||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
DDPGExperimentBuilder,
|
DDPGExperimentBuilder,
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
PGExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
|
REDQExperimentBuilder,
|
||||||
SACExperimentBuilder,
|
SACExperimentBuilder,
|
||||||
TD3ExperimentBuilder,
|
TD3ExperimentBuilder,
|
||||||
|
TRPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -21,6 +24,10 @@ from tianshou.highlevel.experiment import (
|
|||||||
SACExperimentBuilder,
|
SACExperimentBuilder,
|
||||||
DDPGExperimentBuilder,
|
DDPGExperimentBuilder,
|
||||||
TD3ExperimentBuilder,
|
TD3ExperimentBuilder,
|
||||||
|
# NPGExperimentBuilder, # TODO test fails non-deterministically
|
||||||
|
REDQExperimentBuilder,
|
||||||
|
TRPOExperimentBuilder,
|
||||||
|
PGExperimentBuilder,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_experiment_builder_continuous_default_params(builder_cls):
|
def test_experiment_builder_continuous_default_params(builder_cls):
|
||||||
|
@ -5,17 +5,27 @@ import pytest
|
|||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
|
DiscreteSACExperimentBuilder,
|
||||||
DQNExperimentBuilder,
|
DQNExperimentBuilder,
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
IQNExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"builder_cls",
|
"builder_cls",
|
||||||
[PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder],
|
[
|
||||||
|
PPOExperimentBuilder,
|
||||||
|
A2CExperimentBuilder,
|
||||||
|
DQNExperimentBuilder,
|
||||||
|
DiscreteSACExperimentBuilder,
|
||||||
|
IQNExperimentBuilder,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||||
|
logging.configure()
|
||||||
env_factory = DiscreteTestEnvFactory()
|
env_factory = DiscreteTestEnvFactory()
|
||||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
|
@ -67,7 +67,8 @@ TParams = TypeVar("TParams", bound=Params)
|
|||||||
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
|
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
|
||||||
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
||||||
TDiscreteCriticOnlyParams = TypeVar(
|
TDiscreteCriticOnlyParams = TypeVar(
|
||||||
"TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
|
"TDiscreteCriticOnlyParams",
|
||||||
|
bound=ParamsMixinLearningRateWithScheduler,
|
||||||
)
|
)
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
|
|
||||||
@ -403,7 +404,8 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
|||||||
|
|
||||||
|
|
||||||
class DiscreteCriticOnlyAgentFactory(
|
class DiscreteCriticOnlyAgentFactory(
|
||||||
OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
|
OffpolicyAgentFactory,
|
||||||
|
Generic[TDiscreteCriticOnlyParams, TPolicy],
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -583,6 +585,10 @@ class ActorDualCriticsAgentFactory(
|
|||||||
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_critic_use_action(envs: Environments) -> bool:
|
||||||
|
return envs.get_type().is_continuous()
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
actor = self.actor_factory.create_module_opt(
|
actor = self.actor_factory.create_module_opt(
|
||||||
envs,
|
envs,
|
||||||
@ -591,10 +597,11 @@ class ActorDualCriticsAgentFactory(
|
|||||||
self.params.actor_lr,
|
self.params.actor_lr,
|
||||||
)
|
)
|
||||||
use_action_shape = self._get_discrete_last_size_use_action_shape()
|
use_action_shape = self._get_discrete_last_size_use_action_shape()
|
||||||
|
critic_use_action = self._get_critic_use_action(envs)
|
||||||
critic1 = self.critic1_factory.create_module_opt(
|
critic1 = self.critic1_factory.create_module_opt(
|
||||||
envs,
|
envs,
|
||||||
device,
|
device,
|
||||||
True,
|
critic_use_action,
|
||||||
self.optim_factory,
|
self.optim_factory,
|
||||||
self.params.critic1_lr,
|
self.params.critic1_lr,
|
||||||
discrete_last_size_use_action_shape=use_action_shape,
|
discrete_last_size_use_action_shape=use_action_shape,
|
||||||
@ -602,7 +609,7 @@ class ActorDualCriticsAgentFactory(
|
|||||||
critic2 = self.critic2_factory.create_module_opt(
|
critic2 = self.critic2_factory.create_module_opt(
|
||||||
envs,
|
envs,
|
||||||
device,
|
device,
|
||||||
True,
|
critic_use_action,
|
||||||
self.optim_factory,
|
self.optim_factory,
|
||||||
self.params.critic2_lr,
|
self.params.critic2_lr,
|
||||||
discrete_last_size_use_action_shape=use_action_shape,
|
discrete_last_size_use_action_shape=use_action_shape,
|
||||||
|
@ -483,14 +483,13 @@ class PGExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
self._params: A2CParams = A2CParams()
|
self._params: PGParams = PGParams()
|
||||||
self._env_config = None
|
self._env_config = None
|
||||||
|
|
||||||
def with_pg_params(self, params: PGParams) -> Self:
|
def with_pg_params(self, params: PGParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return PGAgentFactory(
|
return PGAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
|
@ -92,11 +92,13 @@ class ActorFactoryDefault(ActorFactory):
|
|||||||
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
|
discrete_softmax: bool = True,
|
||||||
):
|
):
|
||||||
self.continuous_actor_type = continuous_actor_type
|
self.continuous_actor_type = continuous_actor_type
|
||||||
self.continuous_unbounded = continuous_unbounded
|
self.continuous_unbounded = continuous_unbounded
|
||||||
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.discrete_softmax = discrete_softmax
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
env_type = envs.get_type()
|
env_type = envs.get_type()
|
||||||
@ -117,7 +119,9 @@ class ActorFactoryDefault(ActorFactory):
|
|||||||
raise ValueError(self.continuous_actor_type)
|
raise ValueError(self.continuous_actor_type)
|
||||||
return factory.create_module(envs, device)
|
return factory.create_module(envs, device)
|
||||||
elif env_type == EnvType.DISCRETE:
|
elif env_type == EnvType.DISCRETE:
|
||||||
factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
|
factory = ActorFactoryDiscreteNet(
|
||||||
|
self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
|
||||||
|
)
|
||||||
return factory.create_module(envs, device)
|
return factory.create_module(envs, device)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{env_type} not supported")
|
raise ValueError(f"{env_type} not supported")
|
||||||
@ -180,8 +184,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
|||||||
|
|
||||||
|
|
||||||
class ActorFactoryDiscreteNet(ActorFactory):
|
class ActorFactoryDiscreteNet(ActorFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.softmax_output = softmax_output
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
net_a = Net(
|
net_a = Net(
|
||||||
@ -194,6 +199,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
|||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
hidden_sizes=(),
|
hidden_sizes=(),
|
||||||
device=device,
|
device=device,
|
||||||
|
softmax_output=self.softmax_output,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user