Experiment builders for DQN and IQN:

* Fix: Disable softmax in default models
  * Add method with_model_factory_default (for DQN)
This commit is contained in:
Dominik Jain 2024-01-08 18:01:04 +01:00
parent f77d95da04
commit d4e4f4ff63

View File

@ -903,7 +903,7 @@ class DQNExperimentBuilder(
super().__init__(env_factory, experiment_config, sampling_config)
self._params: DQNParams = DQNParams()
self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
)
def with_dqn_params(self, params: DQNParams) -> Self:
@ -911,9 +911,34 @@ class DQNExperimentBuilder(
return self
def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
    """Set the factory used to create the model of the Q function.

    :param module_factory: factory for a module which maps environment observations
        to a vector of Q-values (one for each action)
    :return: the builder
    """
    # Replaces whatever factory was configured before; returning the builder keeps calls chainable.
    self._model_factory = module_factory
    return self
def with_model_factory_default(
    self,
    hidden_sizes: Sequence[int],
    hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self:
    """Configure the default factory for the model of the Q function.

    The model maps environment observations to a vector of Q-values (one for each
    action). The default model is a multi-layer perceptron.

    :param hidden_sizes: the sequence of dimensions used for hidden layers
    :param hidden_activation: the activation function to use for hidden layers
        (not used for the output layer)
    :return: the builder
    """
    # No softmax on the output: the module is documented as producing Q-values,
    # not a probability distribution over actions.
    q_function_actor_factory = ActorFactoryDefault(
        ContinuousActorType.UNSUPPORTED,
        hidden_sizes=hidden_sizes,
        hidden_activation=hidden_activation,
        discrete_softmax=False,
    )
    self._model_factory = IntermediateModuleFactoryFromActorFactory(q_function_actor_factory)
    return self
def _create_agent_factory(self) -> AgentFactory:
return DQNAgentFactory(
self._params,
@ -934,7 +959,7 @@ class IQNExperimentBuilder(ExperimentBuilder):
self._params: IQNParams = IQNParams()
self._preprocess_network_factory: IntermediateModuleFactory = (
IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
)
)