Reorder ExperimentBuilder args (EnvFactory first)
This commit is contained in:
parent
d269063e6a
commit
837ff13c04
@ -94,7 +94,7 @@ def main(
|
|||||||
policy.set_eps(eps_test)
|
policy.set_eps(eps_test)
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
|
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_dqn_params(
|
.with_dqn_params(
|
||||||
DQNParams(
|
DQNParams(
|
||||||
discount_factor=gamma,
|
discount_factor=gamma,
|
||||||
|
@ -13,8 +13,8 @@ from examples.atari.atari_network import (
|
|||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
PPOExperimentBuilder,
|
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
PPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
@ -76,7 +76,7 @@ def main(
|
|||||||
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
|
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_ppo_params(
|
.with_ppo_params(
|
||||||
PPOParams(
|
PPOParams(
|
||||||
discount_factor=gamma,
|
discount_factor=gamma,
|
||||||
|
@ -58,7 +58,7 @@ def main(
|
|||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
A2CExperimentBuilder(experiment_config, env_factory, sampling_config)
|
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_a2c_params(
|
.with_a2c_params(
|
||||||
A2CParams(
|
A2CParams(
|
||||||
discount_factor=gamma,
|
discount_factor=gamma,
|
||||||
|
@ -57,7 +57,7 @@ def main(
|
|||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
|
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_ddpg_params(
|
.with_ddpg_params(
|
||||||
DDPGParams(
|
DDPGParams(
|
||||||
actor_lr=actor_lr,
|
actor_lr=actor_lr,
|
||||||
|
@ -11,8 +11,8 @@ from torch.distributions import Independent, Normal
|
|||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
PPOExperimentBuilder,
|
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
PPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
@ -66,7 +66,7 @@ def main(
|
|||||||
return Independent(Normal(*logits), 1)
|
return Independent(Normal(*logits), 1)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_ppo_params(
|
.with_ppo_params(
|
||||||
PPOParams(
|
PPOParams(
|
||||||
discount_factor=gamma,
|
discount_factor=gamma,
|
||||||
|
@ -58,7 +58,7 @@ def main(
|
|||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
SACExperimentBuilder(experiment_config, env_factory, sampling_config)
|
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_sac_params(
|
.with_sac_params(
|
||||||
SACParams(
|
SACParams(
|
||||||
tau=tau,
|
tau=tau,
|
||||||
|
@ -62,7 +62,7 @@ def main(
|
|||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
|
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_td3_params(
|
.with_td3_params(
|
||||||
TD3Params(
|
TD3Params(
|
||||||
tau=tau,
|
tau=tau,
|
||||||
|
@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig
|
|||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
DDPGExperimentBuilder,
|
DDPGExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
PPOExperimentBuilder,
|
||||||
SACExperimentBuilder,
|
SACExperimentBuilder,
|
||||||
TD3ExperimentBuilder,
|
TD3ExperimentBuilder,
|
||||||
)
|
)
|
||||||
|
@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig
|
|||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
DQNExperimentBuilder,
|
DQNExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
|
PPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -156,8 +156,8 @@ TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
|||||||
class ExperimentBuilder:
|
class ExperimentBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
self._config = experiment_config
|
self._config = experiment_config
|
||||||
@ -400,12 +400,12 @@ class A2CExperimentBuilder(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
env_config: PersistableConfigProtocol | None = None,
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||||
self._params: A2CParams = A2CParams()
|
self._params: A2CParams = A2CParams()
|
||||||
@ -434,11 +434,11 @@ class PPOExperimentBuilder(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||||
self._params: PPOParams = PPOParams()
|
self._params: PPOParams = PPOParams()
|
||||||
@ -465,11 +465,11 @@ class DQNExperimentBuilder(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||||
self._params: DQNParams = DQNParams()
|
self._params: DQNParams = DQNParams()
|
||||||
|
|
||||||
@ -494,11 +494,11 @@ class DDPGExperimentBuilder(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||||
self._params: DDPGParams = DDPGParams()
|
self._params: DDPGParams = DDPGParams()
|
||||||
@ -525,11 +525,11 @@ class SACExperimentBuilder(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinDualCriticFactory.__init__(self)
|
_BuilderMixinDualCriticFactory.__init__(self)
|
||||||
self._params: SACParams = SACParams()
|
self._params: SACParams = SACParams()
|
||||||
@ -556,11 +556,11 @@ class TD3ExperimentBuilder(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: ExperimentConfig,
|
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
_BuilderMixinDualCriticFactory.__init__(self)
|
_BuilderMixinDualCriticFactory.__init__(self)
|
||||||
self._params: TD3Params = TD3Params()
|
self._params: TD3Params = TD3Params()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user