Reorder ExperimentBuilder args (EnvFactory first)
This commit is contained in:
parent
d269063e6a
commit
837ff13c04
@ -94,7 +94,7 @@ def main(
|
||||
policy.set_eps(eps_test)
|
||||
|
||||
builder = (
|
||||
DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_dqn_params(
|
||||
DQNParams(
|
||||
discount_factor=gamma,
|
||||
|
@ -13,8 +13,8 @@ from examples.atari.atari_network import (
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
PPOExperimentBuilder,
|
||||
ExperimentConfig,
|
||||
PPOExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import PPOParams
|
||||
@ -76,7 +76,7 @@ def main(
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
|
||||
|
||||
builder = (
|
||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_ppo_params(
|
||||
PPOParams(
|
||||
discount_factor=gamma,
|
||||
|
@ -58,7 +58,7 @@ def main(
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
|
||||
experiment = (
|
||||
A2CExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_a2c_params(
|
||||
A2CParams(
|
||||
discount_factor=gamma,
|
||||
|
@ -57,7 +57,7 @@ def main(
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
|
||||
experiment = (
|
||||
DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_ddpg_params(
|
||||
DDPGParams(
|
||||
actor_lr=actor_lr,
|
||||
|
@ -11,8 +11,8 @@ from torch.distributions import Independent, Normal
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
PPOExperimentBuilder,
|
||||
ExperimentConfig,
|
||||
PPOExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import PPOParams
|
||||
@ -66,7 +66,7 @@ def main(
|
||||
return Independent(Normal(*logits), 1)
|
||||
|
||||
experiment = (
|
||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_ppo_params(
|
||||
PPOParams(
|
||||
discount_factor=gamma,
|
||||
|
@ -58,7 +58,7 @@ def main(
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
|
||||
experiment = (
|
||||
SACExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_sac_params(
|
||||
SACParams(
|
||||
tau=tau,
|
||||
|
@ -62,7 +62,7 @@ def main(
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
|
||||
experiment = (
|
||||
TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_td3_params(
|
||||
TD3Params(
|
||||
tau=tau,
|
||||
|
@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
A2CExperimentBuilder,
|
||||
DDPGExperimentBuilder,
|
||||
PPOExperimentBuilder,
|
||||
ExperimentConfig,
|
||||
PPOExperimentBuilder,
|
||||
SACExperimentBuilder,
|
||||
TD3ExperimentBuilder,
|
||||
)
|
||||
|
@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
A2CExperimentBuilder,
|
||||
DQNExperimentBuilder,
|
||||
PPOExperimentBuilder,
|
||||
ExperimentConfig,
|
||||
PPOExperimentBuilder,
|
||||
)
|
||||
|
||||
|
||||
|
@ -156,8 +156,8 @@ TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
||||
class ExperimentBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
):
|
||||
self._config = experiment_config
|
||||
@ -400,12 +400,12 @@ class A2CExperimentBuilder(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||
self._params: A2CParams = A2CParams()
|
||||
@ -434,11 +434,11 @@ class PPOExperimentBuilder(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||
self._params: PPOParams = PPOParams()
|
||||
@ -465,11 +465,11 @@ class DQNExperimentBuilder(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||
self._params: DQNParams = DQNParams()
|
||||
|
||||
@ -494,11 +494,11 @@ class DDPGExperimentBuilder(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||
self._params: DDPGParams = DDPGParams()
|
||||
@ -525,11 +525,11 @@ class SACExperimentBuilder(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||
_BuilderMixinDualCriticFactory.__init__(self)
|
||||
self._params: SACParams = SACParams()
|
||||
@ -556,11 +556,11 @@ class TD3ExperimentBuilder(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: ExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
experiment_config: ExperimentConfig,
|
||||
sampling_config: SamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||
_BuilderMixinDualCriticFactory.__init__(self)
|
||||
self._params: TD3Params = TD3Params()
|
||||
|
Loading…
x
Reference in New Issue
Block a user