Reorder ExperimentBuilder args (EnvFactory first)

This commit is contained in:
Dominik Jain 2023-10-06 13:53:45 +02:00
parent d269063e6a
commit 837ff13c04
10 changed files with 24 additions and 24 deletions

View File

@@ -94,7 +94,7 @@ def main(
policy.set_eps(eps_test)
builder = (
-DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
+DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_dqn_params(
DQNParams(
discount_factor=gamma,

View File

@@ -13,8 +13,8 @@ from examples.atari.atari_network import (
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
-PPOExperimentBuilder,
ExperimentConfig,
+PPOExperimentBuilder,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
@@ -76,7 +76,7 @@ def main(
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
builder = (
-PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,

View File

@@ -58,7 +58,7 @@ def main(
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
experiment = (
-A2CExperimentBuilder(experiment_config, env_factory, sampling_config)
+A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_a2c_params(
A2CParams(
discount_factor=gamma,

View File

@@ -57,7 +57,7 @@ def main(
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
experiment = (
-DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
+DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ddpg_params(
DDPGParams(
actor_lr=actor_lr,

View File

@@ -11,8 +11,8 @@ from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
-PPOExperimentBuilder,
ExperimentConfig,
+PPOExperimentBuilder,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
@@ -66,7 +66,7 @@ def main(
return Independent(Normal(*logits), 1)
experiment = (
-PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,

View File

@@ -58,7 +58,7 @@ def main(
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
experiment = (
-SACExperimentBuilder(experiment_config, env_factory, sampling_config)
+SACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
SACParams(
tau=tau,

View File

@@ -62,7 +62,7 @@ def main(
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
experiment = (
-TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
+TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_td3_params(
TD3Params(
tau=tau,

View File

@@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DDPGExperimentBuilder,
-PPOExperimentBuilder,
ExperimentConfig,
+PPOExperimentBuilder,
SACExperimentBuilder,
TD3ExperimentBuilder,
)

View File

@@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
A2CExperimentBuilder,
DQNExperimentBuilder,
-PPOExperimentBuilder,
ExperimentConfig,
+PPOExperimentBuilder,
)

View File

@@ -156,8 +156,8 @@ TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
class ExperimentBuilder:
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
):
self._config = experiment_config
@@ -400,12 +400,12 @@ class A2CExperimentBuilder(
):
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
env_config: PersistableConfigProtocol | None = None,
):
-super().__init__(experiment_config, env_factory, sampling_config)
+super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
self._params: A2CParams = A2CParams()
@@ -434,11 +434,11 @@ class PPOExperimentBuilder(
):
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
):
-super().__init__(experiment_config, env_factory, sampling_config)
+super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
self._params: PPOParams = PPOParams()
@@ -465,11 +465,11 @@ class DQNExperimentBuilder(
):
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
):
-super().__init__(experiment_config, env_factory, sampling_config)
+super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
self._params: DQNParams = DQNParams()
@@ -494,11 +494,11 @@ class DDPGExperimentBuilder(
):
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
):
-super().__init__(experiment_config, env_factory, sampling_config)
+super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
self._params: DDPGParams = DDPGParams()
@@ -525,11 +525,11 @@ class SACExperimentBuilder(
):
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
):
-super().__init__(experiment_config, env_factory, sampling_config)
+super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
self._params: SACParams = SACParams()
@@ -556,11 +556,11 @@ class TD3ExperimentBuilder(
):
def __init__(
self,
-experiment_config: ExperimentConfig,
env_factory: EnvFactory,
+experiment_config: ExperimentConfig,
sampling_config: SamplingConfig,
):
-super().__init__(experiment_config, env_factory, sampling_config)
+super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
self._params: TD3Params = TD3Params()