Add DDPG high-level API and MuJoCo example

Dominik Jain 2023-10-03 20:26:39 +02:00
parent 6b6d9ea609
commit 2671580c6c
7 changed files with 234 additions and 8 deletions

View File

@@ -247,6 +247,9 @@ class ActorFactoryAtariDQN(ActorFactory):
class FeatureNetFactoryDQN(ModuleFactory):
    def create_module(self, envs: Environments, device: TDevice) -> Module:
        dqn = DQN(
-            *envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
+            *envs.get_observation_shape(),
+            envs.get_action_shape(),
+            device,
+            features_only=True,
        )
        return Module(dqn.net, dqn.output_dim)

View File

@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import datetime
import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    DDPGExperimentBuilder,
    RLExperimentConfig,
)
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams


def main(
    experiment_config: RLExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 1000000,
    hidden_sizes: Sequence[int] = (256, 256),
    actor_lr: float = 1e-3,
    critic_lr: float = 1e-3,
    gamma: float = 0.99,
    tau: float = 0.005,
    exploration_noise: float = 0.1,
    start_timesteps: int = 25000,
    epoch: int = 200,
    step_per_epoch: int = 5000,
    step_per_collect: int = 1,
    update_per_step: int = 1,
    n_step: int = 1,
    batch_size: int = 256,
    training_num: int = 1,
    test_num: int = 10,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        start_timesteps=start_timesteps,
        start_timesteps_random=True,
    )

    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

    experiment = (
        DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_ddpg_params(
            DDPGParams(
                actor_lr=actor_lr,
                critic_lr=critic_lr,
                gamma=gamma,
                tau=tau,
                exploration_noise=MaxActionScaledGaussian(exploration_noise),
                estimation_step=n_step,
            ),
        )
        .with_actor_factory_default(hidden_sizes)
        .with_critic_factory_default(hidden_sizes)
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
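
For reference, a sketch of driving the example programmatically instead of through the jsonargparse CLI; the module path examples.mujoco.mujoco_ddpg_hl and a default-constructible RLExperimentConfig are assumptions not shown in this diff.

# Hypothetical programmatic equivalent of the CLI entry point above.
# Assumes the script is importable as examples.mujoco.mujoco_ddpg_hl and that
# RLExperimentConfig is a dataclass whose fields (seed, ...) all have defaults.
from examples.mujoco.mujoco_ddpg_hl import main
from tianshou.highlevel.experiment import RLExperimentConfig

main(
    RLExperimentConfig(),
    task="Hopper-v3",      # any MuJoCo task id accepted by MujocoEnvFactory
    epoch=5,               # shortened run for a quick smoke test
    start_timesteps=1000,  # fewer random warm-up steps than the default 25000
)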

View File

@@ -22,6 +22,7 @@ from tianshou.highlevel.module.module_opt import (
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
    A2CParams,
    DDPGParams,
    Params,
    ParamTransformerData,
    PPOParams,
@@ -29,7 +30,14 @@ from tianshou.highlevel.params.policy_params import (
    TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
+from tianshou.policy import (
+    A2CPolicy,
+    BasePolicy,
+    DDPGPolicy,
+    PPOPolicy,
+    SACPolicy,
+    TD3Policy,
+)
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic
@@ -71,7 +79,8 @@ class AgentFactory(ABC):
        return train_collector, test_collector

    def set_policy_wrapper_factory(
-        self, policy_wrapper_factory: PolicyWrapperFactory | None,
+        self,
+        policy_wrapper_factory: PolicyWrapperFactory | None,
    ) -> None:
        self.policy_wrapper_factory = policy_wrapper_factory
@@ -83,7 +92,10 @@
        policy = self._create_policy(envs, device)
        if self.policy_wrapper_factory is not None:
            policy = self.policy_wrapper_factory.create_wrapped_policy(
-                policy, envs, self.optim_factory, device,
+                policy,
+                envs,
+                self.optim_factory,
+                device,
            )
        return policy
@@ -372,6 +384,49 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
        return self.create_actor_critic_module_opt(envs, device, self.params.lr)


class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
    def __init__(
        self,
        params: DDPGParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config, optim_factory)
        _ActorAndCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params
        self.optim_factory = optim_factory

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic,
            ),
        )
        return DDPGPolicy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic.module,
            critic_optim=critic.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )


class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    def __init__(
        self,

View File

@@ -13,7 +13,7 @@ class RLSamplingConfig:
    num_test_envs: int = 10
    buffer_size: int = 4096
    step_per_collect: int = 2048
-    repeat_per_collect: int = 10
+    repeat_per_collect: int | None = 10
    update_per_step: int = 1
    start_timesteps: int = 0
    start_timesteps_random: bool = False
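
The widened annotation supports the off-policy case: the new DDPG example passes repeat_per_collect=None, since repeating updates over the collected data only applies to on-policy algorithms. A minimal sketch with illustrative values:

# Sketch: an off-policy sampling configuration in the style of the new DDPG
# example; repeat_per_collect=None is now a valid value under int | None.
from tianshou.highlevel.config import RLSamplingConfig

sampling_config = RLSamplingConfig(
    num_epochs=200,
    step_per_epoch=5000,
    buffer_size=1000000,
    step_per_collect=1,
    update_per_step=1,
    repeat_per_collect=None,  # unused by off-policy trainers such as DDPG's
    start_timesteps=25000,
    start_timesteps_random=True,
)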

View File

@@ -11,6 +11,7 @@ from tianshou.data import Collector
from tianshou.highlevel.agent import (
    A2CAgentFactory,
    AgentFactory,
    DDPGAgentFactory,
    PPOAgentFactory,
    SACAgentFactory,
    TD3AgentFactory,
@@ -27,6 +28,7 @@ from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
    A2CParams,
    DDPGParams,
    PPOParams,
    SACParams,
    TD3Params,
@@ -406,13 +408,11 @@ class PPOExperimentBuilder(
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
-        env_config: PersistableConfigProtocol | None = None,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticFactory.__init__(self)
        self._params: PPOParams = PPOParams()
-        self._env_config = env_config

    def with_ppo_params(self, params: PPOParams) -> Self:
        self._params = params
@@ -430,6 +430,39 @@ class PPOExperimentBuilder(
        )


class DDPGExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousDeterministic,
    _BuilderMixinSingleCriticFactory,
):
    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
        env_config: PersistableConfigProtocol | None = None,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
        _BuilderMixinSingleCriticFactory.__init__(self)
        self._params: DDPGParams = DDPGParams()
        self._env_config = env_config

    def with_ddpg_params(self, params: DDPGParams) -> Self:
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return DDPGAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )


class SACExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,

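A usage sketch for the new builder, mirroring the MuJoCo example above; the config values are illustrative and RLExperimentConfig() with all-default fields is an assumption not confirmed by this diff.

# Sketch: fluent configuration of DDPGExperimentBuilder. with_ddpg_params injects
# the hyperparameter dataclass; the with_*_factory_default calls select the default
# deterministic actor and single critic networks with the given hidden sizes.
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import DDPGExperimentBuilder, RLExperimentConfig
from tianshou.highlevel.params.policy_params import DDPGParams

experiment_config = RLExperimentConfig()
sampling_config = RLSamplingConfig(num_epochs=10, repeat_per_collect=None)
env_factory = MujocoEnvFactory("Ant-v3", experiment_config.seed, sampling_config)
experiment = (
    DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ddpg_params(DDPGParams(exploration_noise="default", estimation_step=3))
    .with_actor_factory_default((64, 64))   # default deterministic actor net
    .with_critic_factory_default((64, 64))  # default single critic net
    .build()
)
experiment.run("ddpg-example")
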
View File

@@ -128,6 +128,28 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
        params[self.key_scheduler] = lr_scheduler


class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic = key_scheduler_factory_critic
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    def __init__(
        self,
@@ -232,6 +254,24 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
        ]


@dataclass
class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    critic_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self):
        return [
            ParamTransformerDrop("actor_lr", "critic_lr"),
            ParamTransformerActorAndCriticLRScheduler(
                "actor_lr_scheduler_factory",
                "critic_lr_scheduler_factory",
                "lr_scheduler",
            ),
        ]


@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""
@@ -316,6 +356,22 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
        return transformers


@dataclass
class DDPGParams(Params, ParamsMixinActorAndCritic):
    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self):
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        return transformers


@dataclass
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005

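A construction sketch for the new parameter class; GaussianNoise from tianshou.exploration is used here only for illustration, while MaxActionScaledGaussian and the field names are taken from this diff.

# Sketch: DDPGParams combines the fields above with actor_lr / critic_lr and the
# optional LR scheduler factories inherited from ParamsMixinActorAndCritic.
# exploration_noise accepts "default", None, a concrete BaseNoise instance, or a
# NoiseFactory that builds the noise from the environment.
from tianshou.exploration import GaussianNoise
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams

fixed_noise_params = DDPGParams(
    actor_lr=1e-3,
    critic_lr=1e-3,
    exploration_noise=GaussianNoise(sigma=0.1),  # concrete BaseNoise instance
)
scaled_noise_params = DDPGParams(
    tau=0.005,
    gamma=0.99,
    estimation_step=1,
    exploration_noise=MaxActionScaledGaussian(0.1),  # NoiseFactory, as in the example
)
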
View File

@@ -25,7 +25,8 @@ class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
class PolicyWrapperFactoryIntrinsicCuriosity(
-    Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy],
+    Generic[TPolicyIn],
+    PolicyWrapperFactory[TPolicyIn, ICMPolicy],
):
    def __init__(
        self,