ExperimentBuilder: Allow experiment_config and sampling_config to be None

This commit is contained in:
Dominik Jain 2023-10-06 13:57:00 +02:00
parent 837ff13c04
commit a8dc75fbab

View File

@ -157,9 +157,13 @@ class ExperimentBuilder:
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
if experiment_config is None:
experiment_config = ExperimentConfig()
if sampling_config is None:
sampling_config = SamplingConfig()
self._config = experiment_config
self._env_factory = env_factory
self._sampling_config = sampling_config
@ -401,15 +405,14 @@ class A2CExperimentBuilder(
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
env_config: PersistableConfigProtocol | None = None,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
self._params: A2CParams = A2CParams()
self._env_config = env_config
self._env_config = None
def with_a2c_params(self, params: A2CParams) -> Self:
self._params = params
@ -435,8 +438,8 @@ class PPOExperimentBuilder(
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@ -466,8 +469,8 @@ class DQNExperimentBuilder(
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
@ -495,8 +498,8 @@ class DDPGExperimentBuilder(
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
@ -526,8 +529,8 @@ class SACExperimentBuilder(
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@ -557,8 +560,8 @@ class TD3ExperimentBuilder(
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)