Extend tests, fixing some default behaviour
This commit is contained in:
parent a8a367c42d
commit 686fd555b0
@@ -68,7 +68,11 @@ def main(
     )

     env_factory = AtariEnvFactory(
-        task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
+        task,
+        experiment_config.seed,
+        sampling_config,
+        frames_stack,
+        scale=scale_obs,
     )

     builder = (
@@ -7,9 +7,12 @@ from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
     DDPGExperimentBuilder,
     ExperimentConfig,
+    PGExperimentBuilder,
     PPOExperimentBuilder,
+    REDQExperimentBuilder,
     SACExperimentBuilder,
     TD3ExperimentBuilder,
+    TRPOExperimentBuilder,
 )


@@ -21,6 +24,10 @@ from tianshou.highlevel.experiment import (
         SACExperimentBuilder,
         DDPGExperimentBuilder,
         TD3ExperimentBuilder,
+        # NPGExperimentBuilder,  # TODO test fails non-deterministically
+        REDQExperimentBuilder,
+        TRPOExperimentBuilder,
+        PGExperimentBuilder,
     ],
 )
 def test_experiment_builder_continuous_default_params(builder_cls):
@@ -5,17 +5,27 @@ import pytest
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
+    DiscreteSACExperimentBuilder,
     DQNExperimentBuilder,
     ExperimentConfig,
+    IQNExperimentBuilder,
     PPOExperimentBuilder,
 )
 from tianshou.utils import logging


 @pytest.mark.parametrize(
     "builder_cls",
-    [PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder],
+    [
+        PPOExperimentBuilder,
+        A2CExperimentBuilder,
+        DQNExperimentBuilder,
+        DiscreteSACExperimentBuilder,
+        IQNExperimentBuilder,
+    ],
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
     logging.configure()
     env_factory = DiscreteTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
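Note: for orientation, a minimal sketch of the smoke-test pattern these hunks extend. The constructor call and the final build step are assumptions inferred from this diff (DiscreteTestEnvFactory is a test helper whose import path is not shown here):

import pytest

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
from tianshou.utils import logging

@pytest.mark.parametrize("builder_cls", [DQNExperimentBuilder])
def test_experiment_builder_discrete_default_params(builder_cls):
    logging.configure()
    env_factory = DiscreteTestEnvFactory()  # test helper; import not shown in this diff
    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
    # Assumed signature, mirroring super().__init__ in PGExperimentBuilder below:
    builder = builder_cls(env_factory, ExperimentConfig(), sampling_config)
    experiment = builder.build()  # assumed: the builder produces a runnable Experiment
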
@@ -67,7 +67,8 @@ TParams = TypeVar("TParams", bound=Params)
 TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
 TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
 TDiscreteCriticOnlyParams = TypeVar(
-    "TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
+    "TDiscreteCriticOnlyParams",
+    bound=ParamsMixinLearningRateWithScheduler,
 )
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -403,7 +404,8 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):


 class DiscreteCriticOnlyAgentFactory(
-    OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
+    OffpolicyAgentFactory,
+    Generic[TDiscreteCriticOnlyParams, TPolicy],
 ):
     def __init__(
         self,
@@ -583,6 +585,10 @@ class ActorDualCriticsAgentFactory(
     def _get_discrete_last_size_use_action_shape(self) -> bool:
         return True

+    @staticmethod
+    def _get_critic_use_action(envs: Environments) -> bool:
+        return envs.get_type().is_continuous()
+
     def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
         actor = self.actor_factory.create_module_opt(
             envs,
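Note: the new _get_critic_use_action default encodes the usual Q-network convention. A continuous-action critic consumes the action and returns a single Q(s, a); a discrete-action critic sees only the observation and returns one Q-value per action, so it must not receive the action as input. The previous hard-coded True in the two hunks below was wrong for discrete dual-critic algorithms, which the newly tested DiscreteSACExperimentBuilder exercises. A self-contained plain-torch shape sketch (all names hypothetical):

import torch
from torch import nn

obs_dim, act_dim, n_actions = 8, 2, 4

# Continuous case: Q(s, a) -> scalar, so the action is part of the input.
continuous_critic = nn.Linear(obs_dim + act_dim, 1)
# Discrete case: Q(s, .) -> one value per action, observation-only input.
discrete_critic = nn.Linear(obs_dim, n_actions)

obs, act = torch.zeros(1, obs_dim), torch.zeros(1, act_dim)
assert continuous_critic(torch.cat([obs, act], dim=-1)).shape == (1, 1)
assert discrete_critic(obs).shape == (1, n_actions)
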
@@ -591,10 +597,11 @@ class ActorDualCriticsAgentFactory(
             self.params.actor_lr,
         )
         use_action_shape = self._get_discrete_last_size_use_action_shape()
+        critic_use_action = self._get_critic_use_action(envs)
         critic1 = self.critic1_factory.create_module_opt(
             envs,
             device,
-            True,
+            critic_use_action,
             self.optim_factory,
             self.params.critic1_lr,
             discrete_last_size_use_action_shape=use_action_shape,
@@ -602,7 +609,7 @@ class ActorDualCriticsAgentFactory(
         critic2 = self.critic2_factory.create_module_opt(
             envs,
             device,
-            True,
+            critic_use_action,
             self.optim_factory,
             self.params.critic2_lr,
             discrete_last_size_use_action_shape=use_action_shape,
@@ -483,14 +483,13 @@ class PGExperimentBuilder(
     ):
         super().__init__(env_factory, experiment_config, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
-        self._params: A2CParams = A2CParams()
+        self._params: PGParams = PGParams()
         self._env_config = None

     def with_pg_params(self, params: PGParams) -> Self:
         self._params = params
         return self

-    @abstractmethod
     def _create_agent_factory(self) -> AgentFactory:
         return PGAgentFactory(
             self._params,
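Note: two defaults are fixed in the hunk above. PGExperimentBuilder previously initialised its params as A2CParams() and left _create_agent_factory marked @abstractmethod, which (assuming ABCMeta in the base class) kept the builder abstract and uninstantiable. A hedged usage sketch of the fixed builder (env_factory and sampling_config as in the tests above; the PGParams import path is not shown in this diff):

from tianshou.highlevel.experiment import ExperimentConfig, PGExperimentBuilder

builder = PGExperimentBuilder(env_factory, ExperimentConfig(), sampling_config)
builder = builder.with_pg_params(PGParams())  # optional: the corrected default is already PGParams()
experiment = builder.build()  # assumed: returns a runnable Experiment
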
@@ -92,11 +92,13 @@ class ActorFactoryDefault(ActorFactory):
         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
+        discrete_softmax: bool = True,
     ):
         self.continuous_actor_type = continuous_actor_type
         self.continuous_unbounded = continuous_unbounded
         self.continuous_conditioned_sigma = continuous_conditioned_sigma
         self.hidden_sizes = hidden_sizes
+        self.discrete_softmax = discrete_softmax

     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         env_type = envs.get_type()
@@ -117,7 +119,9 @@ class ActorFactoryDefault(ActorFactory):
                 raise ValueError(self.continuous_actor_type)
             return factory.create_module(envs, device)
         elif env_type == EnvType.DISCRETE:
-            factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
+            factory = ActorFactoryDiscreteNet(
+                self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
+            )
             return factory.create_module(envs, device)
         else:
             raise ValueError(f"{env_type} not supported")
@@ -180,8 +184,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):


 class ActorFactoryDiscreteNet(ActorFactory):
-    def __init__(self, hidden_sizes: Sequence[int]):
+    def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True):
         self.hidden_sizes = hidden_sizes
+        self.softmax_output = softmax_output

     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
@@ -194,6 +199,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
             envs.get_action_shape(),
             hidden_sizes=(),
             device=device,
+            softmax_output=self.softmax_output,
         ).to(device)
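Note: the last three hunks thread one new switch, discrete_softmax, from ActorFactoryDefault through ActorFactoryDiscreteNet down to the final network head, letting discrete actors emit raw logits instead of a softmax distribution, which some of the newly tested discrete algorithms require. A hedged usage sketch (envs and device stand for an Environments instance and a torch device; the module path is an assumption):

from tianshou.highlevel.module.actor import ActorFactoryDiscreteNet  # path assumed

# A discrete actor head that outputs raw logits rather than action probabilities:
factory = ActorFactoryDiscreteNet(hidden_sizes=(64, 64), softmax_output=False)
actor = factory.create_module(envs, device)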