Add DDPG high-level API and MuJoCo example

parent 6b6d9ea609
commit 2671580c6c
@@ -247,6 +247,9 @@ class ActorFactoryAtariDQN(ActorFactory):
 class FeatureNetFactoryDQN(ModuleFactory):
     def create_module(self, envs: Environments, device: TDevice) -> Module:
         dqn = DQN(
-            *envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
+            *envs.get_observation_shape(),
+            envs.get_action_shape(),
+            device,
+            features_only=True,
         )
         return Module(dqn.net, dqn.output_dim)
examples/mujoco/mujoco_ddpg_hl.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+from collections.abc import Sequence
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.experiment import (
+    DDPGExperimentBuilder,
+    RLExperimentConfig,
+)
+from tianshou.highlevel.params.noise import MaxActionScaledGaussian
+from tianshou.highlevel.params.policy_params import DDPGParams
+
+
+def main(
+    experiment_config: RLExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
+    hidden_sizes: Sequence[int] = (256, 256),
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    exploration_noise: float = 0.1,
+    start_timesteps: int = 25000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 1,
+    n_step: int = 1,
+    batch_size: int = 256,
+    training_num: int = 1,
+    test_num: int = 10,
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
+
+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ddpg_params(
+            DDPGParams(
+                actor_lr=actor_lr,
+                critic_lr=critic_lr,
+                gamma=gamma,
+                tau=tau,
+                exploration_noise=MaxActionScaledGaussian(exploration_noise),
+                estimation_step=n_step,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    CLI(main)
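Note: jsonargparse's CLI maps main()'s signature to command-line flags, so the script is fully configurable from the shell. As a minimal sketch (argument values are illustrative, and RLExperimentConfig fields other than seed are assumed to have usable defaults), the entry point can also be called programmatically:

# Sketch: invoking the example without the CLI wrapper.
from examples.mujoco.mujoco_ddpg_hl import main
from tianshou.highlevel.experiment import RLExperimentConfig

main(RLExperimentConfig(seed=42), task="Hopper-v3", epoch=5)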
tianshou/highlevel/agent.py
@@ -22,6 +22,7 @@ from tianshou.highlevel.module.module_opt import (
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
+    DDPGParams,
     Params,
     ParamTransformerData,
     PPOParams,
@@ -29,7 +30,14 @@ from tianshou.highlevel.params.policy_params import (
     TD3Params,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
+from tianshou.policy import (
+    A2CPolicy,
+    BasePolicy,
+    DDPGPolicy,
+    PPOPolicy,
+    SACPolicy,
+    TD3Policy,
+)
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 from tianshou.utils.net import continuous, discrete
 from tianshou.utils.net.common import ActorCritic
@@ -71,7 +79,8 @@ class AgentFactory(ABC):
         return train_collector, test_collector

     def set_policy_wrapper_factory(
-        self, policy_wrapper_factory: PolicyWrapperFactory | None,
+        self,
+        policy_wrapper_factory: PolicyWrapperFactory | None,
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory

@@ -83,7 +92,10 @@ class AgentFactory(ABC):
         policy = self._create_policy(envs, device)
         if self.policy_wrapper_factory is not None:
             policy = self.policy_wrapper_factory.create_wrapped_policy(
-                policy, envs, self.optim_factory, device,
+                policy,
+                envs,
+                self.optim_factory,
+                device,
             )
         return policy

@@ -372,6 +384,49 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
         return self.create_actor_critic_module_opt(envs, device, self.params.lr)


+class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
+    def __init__(
+        self,
+        params: DDPGParams,
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        _ActorAndCriticMixin.__init__(
+            self,
+            actor_factory,
+            critic_factory,
+            optim_factory,
+            critic_use_action=True,
+        )
+        self.params = params
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
+        critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic,
+            ),
+        )
+        return DDPGPolicy(
+            actor=actor.module,
+            actor_optim=actor.optim,
+            critic=critic.module,
+            critic_optim=critic.optim,
+            action_space=envs.get_action_space(),
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
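Note: DDPGAgentFactory follows the same pattern as the other off-policy factories: the param transformers turn the DDPGParams dataclass into DDPGPolicy keyword arguments, with actor/critic1 carrying the module-optimizer pairs. A hedged sketch of wiring it up by hand (actor_factory and critic_factory are placeholders, and OptimizerFactoryAdam's no-argument construction is an assumption):

# Sketch only; the experiment builder below normally does this wiring.
agent_factory = DDPGAgentFactory(
    DDPGParams(),           # defaults defined in policy_params.py below
    sampling_config,        # an RLSamplingConfig
    actor_factory,          # any ActorFactory yielding a deterministic actor
    critic_factory,         # any CriticFactory; critic_use_action=True is set internally
    OptimizerFactoryAdam(),
)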
tianshou/highlevel/config.py
@@ -13,7 +13,7 @@ class RLSamplingConfig:
     num_test_envs: int = 10
     buffer_size: int = 4096
     step_per_collect: int = 2048
-    repeat_per_collect: int = 10
+    repeat_per_collect: int | None = 10
     update_per_step: int = 1
     start_timesteps: int = 0
     start_timesteps_random: bool = False
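Note: repeat_per_collect only matters for on-policy updates, which is why it becomes optional: off-policy algorithms such as DDPG can now pass None, as the new example script does. A small sketch (assuming the remaining RLSamplingConfig fields all have defaults):

# On-policy configs keep an int; off-policy configs may opt out.
on_policy_cfg = RLSamplingConfig(repeat_per_collect=10)
off_policy_cfg = RLSamplingConfig(repeat_per_collect=None)  # e.g. DDPG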
tianshou/highlevel/experiment.py
@@ -11,6 +11,7 @@ from tianshou.data import Collector
 from tianshou.highlevel.agent import (
     A2CAgentFactory,
     AgentFactory,
+    DDPGAgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
@@ -27,6 +28,7 @@ from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
+    DDPGParams,
     PPOParams,
     SACParams,
     TD3Params,
@@ -406,13 +408,11 @@ class PPOExperimentBuilder(
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
-        env_config: PersistableConfigProtocol | None = None,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
         _BuilderMixinSingleCriticFactory.__init__(self)
         self._params: PPOParams = PPOParams()
-        self._env_config = env_config

     def with_ppo_params(self, params: PPOParams) -> Self:
         self._params = params
@@ -430,6 +430,38 @@ class PPOExperimentBuilder(
         )


+class DDPGExperimentBuilder(
+    RLExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousDeterministic,
+    _BuilderMixinSingleCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+        env_config: PersistableConfigProtocol | None = None,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: DDPGParams = DDPGParams()
+        self._env_config = env_config
+
+    def with_ddpg_params(self, params: DDPGParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return DDPGAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
 class SACExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
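Note: because _params is initialized to DDPGParams(), with_ddpg_params is optional; a minimal build can rely entirely on defaults. A sketch (experiment_config, env_factory, and sampling_config constructed as in the example script above):

# Minimal sketch: default DDPGParams plus default MLP actor and critic.
experiment = (
    DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_actor_factory_default((256, 256))
    .with_critic_factory_default((256, 256))
    .build()
)
experiment.run("ddpg-defaults")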
tianshou/highlevel/params/policy_params.py
@@ -128,6 +128,28 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
         params[self.key_scheduler] = lr_scheduler


+class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
+    def __init__(
+        self,
+        key_scheduler_factory_actor: str,
+        key_scheduler_factory_critic: str,
+        key_scheduler: str,
+    ):
+        self.key_factory_actor = key_scheduler_factory_actor
+        self.key_factory_critic = key_scheduler_factory_critic
+        self.key_scheduler = key_scheduler
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        transformer = ParamTransformerMultiLRScheduler(
+            [
+                (data.actor.optim, self.key_factory_actor),
+                (data.critic1.optim, self.key_factory_critic),
+            ],
+            self.key_scheduler,
+        )
+        transformer.transform(params, data)
+
+
 class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
     def __init__(
         self,
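Note: this transformer delegates to ParamTransformerMultiLRScheduler, pairing the actor's and critic's optimizers with their respective scheduler-factory entries and collapsing them into a single lr_scheduler kwarg. Conceptually (a sketch; the exact key handling inside ParamTransformerMultiLRScheduler is assumed):

# Before transform(): per-network factory entries in the params dict, e.g.
# params = {"actor_lr_scheduler_factory": ..., "critic_lr_scheduler_factory": ...}
# After transform(): the factory entries are consumed and replaced by a single
# params["lr_scheduler"] built over data.actor.optim and data.critic1.optim.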
@@ -232,6 +254,24 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
         ]


+@dataclass
+class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
+    actor_lr: float = 1e-3
+    critic_lr: float = 1e-3
+    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
+    critic_lr_scheduler_factory: LRSchedulerFactory | None = None
+
+    def _get_param_transformers(self):
+        return [
+            ParamTransformerDrop("actor_lr", "critic_lr"),
+            ParamTransformerActorAndCriticLRScheduler(
+                "actor_lr_scheduler_factory",
+                "critic_lr_scheduler_factory",
+                "lr_scheduler",
+            ),
+        ]
+
+
 @dataclass
 class PGParams(Params):
     """Config of general policy-gradient algorithms."""
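Note: any params dataclass inheriting ParamsMixinActorAndCritic gets per-network learning rates; the raw actor_lr/critic_lr values are dropped from the policy kwargs (ParamTransformerDrop) because they are consumed when the optimizers are created. A hedged sketch using the DDPGParams class defined just below (MyDecayFactory is a hypothetical LRSchedulerFactory implementation):

params = DDPGParams(
    actor_lr=1e-4,
    critic_lr=1e-3,
    actor_lr_scheduler_factory=MyDecayFactory(),  # hypothetical factory
)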
@@ -316,6 +356,22 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
         return transformers


+@dataclass
+class DDPGParams(Params, ParamsMixinActorAndCritic):
+    tau: float = 0.005
+    gamma: float = 0.99
+    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
+    estimation_step: int = 1
+    action_scaling: bool = True
+    action_bound_method: Literal["clip"] | None = "clip"
+
+    def _get_param_transformers(self):
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
+        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
+        return transformers
+
+
 @dataclass
 class TD3Params(Params, ParamsMixinActorAndDualCritics):
     tau: float = 0.005
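Note: exploration_noise accepts a concrete BaseNoise, a NoiseFactory resolved against the environment at build time, the literal "default", or None; ParamTransformerNoiseFactory turns whichever form is given into the final noise object. Grounded in the example script above:

params = DDPGParams(
    tau=0.005,
    gamma=0.99,
    exploration_noise=MaxActionScaledGaussian(0.1),  # presumably Gaussian noise scaled to the env's max action
    estimation_step=1,
)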
tianshou/highlevel/params/policy_wrapper.py
@@ -25,7 +25,8 @@ class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):


 class PolicyWrapperFactoryIntrinsicCuriosity(
-    Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy],
+    Generic[TPolicyIn],
+    PolicyWrapperFactory[TPolicyIn, ICMPolicy],
 ):
     def __init__(
         self,