diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index e474741..8396900 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -5,7 +5,11 @@ import numpy as np import torch from torch import nn -from tianshou.utils.net.discrete import NoisyLinear +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.actor import ActorFactory +from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice +from tianshou.utils.net.common import BaseActor +from tianshou.utils.net.discrete import Actor, NoisyLinear def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module: @@ -220,3 +224,29 @@ class QRDQN(DQN): obs, state = super().forward(obs) obs = obs.view(-1, self.action_num, self.num_quantiles) return obs, state + + +class ActorFactoryAtariDQN(ActorFactory): + def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool): + self.hidden_size = hidden_size + self.scale_obs = scale_obs + + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + net_cls = scale_obs(DQN) if self.scale_obs else DQN + net = net_cls( + *envs.get_observation_shape(), + envs.get_action_shape(), + device=device, + features_only=True, + output_dim=self.hidden_size, + layer_init=layer_init, + ) + return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device) + + +class FeatureNetFactoryDQN(ModuleFactory): + def create_module(self, envs: Environments, device: TDevice) -> Module: + dqn = DQN( + *envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True, + ) + return Module(dqn.net, dqn.output_dim) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py new file mode 100644 index 0000000..12939e9 --- /dev/null +++ b/examples/atari/atari_ppo_hl.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +import datetime +import os +from collections.abc import Sequence + +from jsonargparse import CLI + +from examples.atari.atari_network import ( + ActorFactoryAtariDQN, + FeatureNetFactoryDQN, +) +from examples.atari.atari_wrapper import AtariEnvFactory +from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.experiment import ( + PPOExperimentBuilder, + RLExperimentConfig, +) +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.policy_wrapper import ( + PolicyWrapperFactoryIntrinsicCuriosity, +) + + +def main( + experiment_config: RLExperimentConfig, + task: str = "PongNoFrameskip-v4", + scale_obs: bool = True, + buffer_size: int = 100000, + lr: float = 2.5e-4, + gamma: float = 0.99, + epoch: int = 100, + step_per_epoch: int = 100000, + step_per_collect: int = 1000, + repeat_per_collect: int = 4, + batch_size: int = 256, + hidden_sizes: int | Sequence[int] = 512, + training_num: int = 10, + test_num: int = 10, + rew_norm: bool = False, + vf_coef: float = 0.25, + ent_coef: float = 0.01, + gae_lambda: float = 0.95, + lr_decay: bool = True, + max_grad_norm: float = 0.5, + eps_clip: float = 0.1, + dual_clip: float | None = None, + value_clip: bool = True, + norm_adv: bool = True, + recompute_adv: bool = False, + frames_stack: int = 4, + save_buffer_name: str | None = None, # TODO add support in high-level API? 
+ icm_lr_scale: float = 0.0, + icm_reward_scale: float = 0.01, + icm_forward_loss_weight: float = 0.2, +): + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) + + sampling_config = RLSamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + repeat_per_collect=repeat_per_collect, + replay_buffer_stack_num=frames_stack, + replay_buffer_ignore_obs_next=True, + replay_buffer_save_only_last_obs=True, + ) + + env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack) + + builder = ( + PPOExperimentBuilder(experiment_config, env_factory, sampling_config) + .with_ppo_params( + PPOParams( + discount_factor=gamma, + gae_lambda=gae_lambda, + reward_normalization=rew_norm, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + value_clip=value_clip, + advantage_normalization=norm_adv, + eps_clip=eps_clip, + dual_clip=dual_clip, + recompute_advantage=recompute_adv, + lr=lr, + lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) + if lr_decay + else None, + ), + ) + .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs)) + .with_critic_factory_use_actor() + ) + if icm_lr_scale > 0: + builder.with_policy_wrapper_factory( + PolicyWrapperFactoryIntrinsicCuriosity( + FeatureNetFactoryDQN(), + [hidden_sizes], + lr, + icm_lr_scale, + icm_reward_scale, + icm_forward_loss_weight, + ), + ) + experiment = builder.build() + experiment.run(log_name) + + +if __name__ == "__main__": + CLI(main) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 462fb67..347294a 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -9,6 +9,8 @@ import gymnasium as gym import numpy as np from tianshou.env import ShmemVectorEnv +from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory try: import envpool @@ -369,3 +371,22 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): train_envs.seed(seed) test_envs.seed(seed) return env, train_envs, test_envs + + +class AtariEnvFactory(EnvFactory): + def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int): + self.task = task + self.sampling_config = sampling_config + self.seed = seed + self.frame_stack = frame_stack + + def create_envs(self, config=None) -> DiscreteEnvironments: + env, train_envs, test_envs = make_atari_env( + task=self.task, + seed=self.seed, + training_num=self.sampling_config.num_train_envs, + test_num=self.sampling_config.num_test_envs, + scale=0, + frame_stack=self.frame_stack, + ) + return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index bd6cb81..16d90ef 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -1,4 +1,3 @@ -import os from abc import ABC, abstractmethod from collections.abc import Callable from typing import Generic, TypeVar @@ -9,14 +8,16 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import Logger -from tianshou.highlevel.module import ( - ActorCriticModuleOpt, +from 
tianshou.highlevel.module.actor import ( ActorFactory, +) +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.critic import CriticFactory +from tianshou.highlevel.module.module_opt import ( + ActorCriticModuleOpt, ActorModuleOptFactory, - CriticFactory, CriticModuleOptFactory, ModuleOpt, - TDevice, ) from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.policy_params import ( @@ -27,8 +28,10 @@ from tianshou.highlevel.params.policy_params import ( SACParams, TD3Params, ) +from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer +from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import ActorCritic CHECKPOINT_DICT_KEY_MODEL = "model" @@ -38,34 +41,62 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy) class AgentFactory(ABC): - def __init__(self, sampling_config: RLSamplingConfig): + def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory): self.sampling_config = sampling_config + self.optim_factory = optim_factory + self.policy_wrapper_factory: PolicyWrapperFactory | None = None def create_train_test_collector(self, policy: BasePolicy, envs: Environments): buffer_size = self.sampling_config.buffer_size train_envs = envs.train_envs if len(train_envs) > 1: - buffer = VectorReplayBuffer(buffer_size, len(train_envs)) + buffer = VectorReplayBuffer( + buffer_size, + len(train_envs), + stack_num=self.sampling_config.replay_buffer_stack_num, + save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, + ) else: - buffer = ReplayBuffer(buffer_size) + buffer = ReplayBuffer( + buffer_size, + stack_num=self.sampling_config.replay_buffer_stack_num, + save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, + ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, envs.test_envs) if self.sampling_config.start_timesteps > 0: train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True) return train_collector, test_collector + def set_policy_wrapper_factory( + self, policy_wrapper_factory: PolicyWrapperFactory | None, + ) -> None: + self.policy_wrapper_factory = policy_wrapper_factory + @abstractmethod - def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: pass + def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + policy = self._create_policy(envs, device) + if self.policy_wrapper_factory is not None: + policy = self.policy_wrapper_factory.create_wrapped_policy( + policy, envs, self.optim_factory, device, + ) + return policy + @staticmethod def _create_save_best_fn(envs: Environments, log_path: str) -> Callable: def save_best_fn(pol: torch.nn.Module) -> None: - state = { - CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(), - CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(), - } - torch.save(state, os.path.join(log_path, "policy.pth")) + pass + # TODO: Fix saving in general (code works only for mujoco) + # state = { + # CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(), + # CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(), + # } + 
# torch.save(state, os.path.join(log_path, "policy.pth")) return save_best_fn @@ -160,11 +191,13 @@ class _ActorCriticMixin: critic_factory: CriticFactory, optim_factory: OptimizerFactory, critic_use_action: bool, + critic_use_actor_module: bool, ): self.actor_factory = actor_factory self.critic_factory = critic_factory self.optim_factory = optim_factory self.critic_use_action = critic_use_action + self.critic_use_actor_module = critic_use_actor_module def create_actor_critic_module_opt( self, @@ -173,7 +206,23 @@ class _ActorCriticMixin: lr: float, ) -> ActorCriticModuleOpt: actor = self.actor_factory.create_module(envs, device) - critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) + if self.critic_use_actor_module: + if self.critic_use_action: + raise ValueError( + "The options critic_use_actor_module and critic_use_action are mutually exclusive", + ) + if envs.get_type().is_discrete(): + critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device) + elif envs.get_type().is_continuous(): + critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device) + else: + raise ValueError + else: + critic = self.critic_factory.create_module( + envs, + device, + use_action=self.critic_use_action, + ) actor_critic = ActorCritic(actor, critic) optim = self.optim_factory.create_optimizer(actor_critic, lr) return ActorCriticModuleOpt(actor_critic, optim) @@ -237,14 +286,16 @@ class ActorCriticAgentFactory( critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, policy_class: type[TPolicy], + critic_use_actor_module: bool, ): - super().__init__(sampling_config) + super().__init__(sampling_config, optim_factory=optimizer_factory) _ActorCriticMixin.__init__( self, actor_factory, critic_factory, optimizer_factory, critic_use_action=False, + critic_use_actor_module=critic_use_actor_module, ) self.params = params self.policy_class = policy_class @@ -269,7 +320,7 @@ class ActorCriticAgentFactory( kwargs["action_space"] = envs.get_action_space() return kwargs - def create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: return self.policy_class(**self._create_kwargs(envs, device)) @@ -281,6 +332,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, + critic_use_actor_module: bool, ): super().__init__( params, @@ -289,6 +341,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): critic_factory, optimizer_factory, A2CPolicy, + critic_use_actor_module, ) def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: @@ -303,6 +356,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, + critic_use_actor_module: bool, ): super().__init__( params, @@ -311,6 +365,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): critic_factory, optimizer_factory, PPOPolicy, + critic_use_actor_module, ) def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: @@ -327,7 +382,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): critic2_factory: CriticFactory, optim_factory: OptimizerFactory, ): - super().__init__(sampling_config) + super().__init__(sampling_config, optim_factory) 
_ActorAndDualCriticsMixin.__init__( self, actor_factory, @@ -339,7 +394,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): self.params = params self.optim_factory = optim_factory - def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: actor = self.create_actor_module_opt(envs, device, self.params.actor_lr) critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr) critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr) @@ -376,7 +431,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): critic2_factory: CriticFactory, optim_factory: OptimizerFactory, ): - super().__init__(sampling_config) + super().__init__(sampling_config, optim_factory) _ActorAndDualCriticsMixin.__init__( self, actor_factory, @@ -388,7 +443,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): self.params = params self.optim_factory = optim_factory - def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: actor = self.create_actor_module_opt(envs, device, self.params.actor_lr) critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr) critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index de5d247..d90dce1 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -17,3 +17,7 @@ class RLSamplingConfig: update_per_step: int = 1 start_timesteps: int = 0 start_timesteps_random: bool = False + # TODO can we set the parameters below more intelligently? Perhaps based on env. representation? 
+ replay_buffer_ignore_obs_next: bool = False + replay_buffer_save_only_last_obs: bool = False + replay_buffer_stack_num: int = 1 diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 4b7942e..5b28bfc 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -97,6 +97,22 @@ class ContinuousEnvironments(Environments): return EnvType.CONTINUOUS +class DiscreteEnvironments(Environments): + def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + super().__init__(env, train_envs, test_envs) + self.observation_shape = env.observation_space.shape or env.observation_space.n + self.action_shape = env.action_space.shape or env.action_space.n + + def get_action_shape(self) -> TShape: + return self.action_shape + + def get_observation_shape(self) -> TShape: + return self.observation_shape + + def get_type(self) -> EnvType: + return EnvType.DISCRETE + + class EnvFactory(ABC): @abstractmethod def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index eebd265..3abaed4 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -18,13 +18,12 @@ from tianshou.highlevel.agent import ( from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.env import EnvFactory, Environments from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory -from tianshou.highlevel.module import ( +from tianshou.highlevel.module.actor import ( ActorFactory, ActorFactoryDefault, ContinuousActorType, - CriticFactory, - CriticFactoryDefault, ) +from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam from tianshou.highlevel.params.policy_params import ( A2CParams, @@ -32,6 +31,7 @@ from tianshou.highlevel.params.policy_params import ( SACParams, TD3Params, ) +from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.policy import BasePolicy from tianshou.trainer import BaseTrainer @@ -154,6 +154,7 @@ class RLExperimentBuilder: self._logger_factory: LoggerFactory | None = None self._optim_factory: OptimizerFactory | None = None self._env_config: PersistableConfigProtocol | None = None + self._policy_wrapper_factory: PolicyWrapperFactory | None = None def with_env_config(self, config: PersistableConfigProtocol) -> Self: self._env_config = config @@ -163,6 +164,10 @@ class RLExperimentBuilder: self._logger_factory = logger_factory return self + def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self: + self._policy_wrapper_factory = policy_wrapper_factory + return self + def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder: self._optim_factory = optim_factory return self @@ -194,10 +199,13 @@ class RLExperimentBuilder: return self._optim_factory def build(self) -> RLExperiment: + agent_factory = self._create_agent_factory() + if self._policy_wrapper_factory: + agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) return RLExperiment( self._config, self._env_factory, - self._create_agent_factory(), + agent_factory, self._logger_factory, env_config=self._env_config, ) @@ -287,6 +295,7 @@ class _BuilderMixinCriticsFactory: class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): def __init__(self): 
super().__init__(1) + self._critic_use_actor_module = False def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder: self: TBuilder | "_BuilderMixinSingleCriticFactory" @@ -301,6 +310,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): self._with_critic_factory_default(0, hidden_sizes) return self + def with_critic_factory_use_actor(self) -> Self: + """Makes the critic use the same network as the actor.""" + self._critic_use_actor_module = True + return self + class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): def __init__(self): @@ -378,6 +392,7 @@ class A2CExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), + self._critic_use_actor_module, ) @@ -411,6 +426,7 @@ class PPOExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), + self._critic_use_actor_module, ) diff --git a/tianshou/highlevel/module/__init__.py b/tianshou/highlevel/module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module/actor.py similarity index 52% rename from tianshou/highlevel/module.py rename to tianshou/highlevel/module/actor.py index 3a9be10..488928f 100644 --- a/tianshou/highlevel/module.py +++ b/tianshou/highlevel/module/actor.py @@ -1,29 +1,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from dataclasses import dataclass -from typing import TypeAlias -import numpy as np import torch from torch import nn from tianshou.highlevel.env import Environments, EnvType -from tianshou.highlevel.optim import OptimizerFactory -from tianshou.utils.net import continuous -from tianshou.utils.net.common import ActorCritic, Net - -TDevice: TypeAlias = str | int | torch.device - - -def init_linear_orthogonal(module: torch.nn.Module): - """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0. 
- - :param module: the module whose submodules are to be processed - """ - for m in module.modules(): - if isinstance(m, torch.nn.Linear): - torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) - torch.nn.init.zeros_(m.bias) +from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal +from tianshou.utils.net import continuous, discrete +from tianshou.utils.net.common import BaseActor, Net class ContinuousActorType: @@ -33,7 +17,7 @@ class ContinuousActorType: class ActorFactory(ABC): @abstractmethod - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: pass @staticmethod @@ -70,18 +54,18 @@ class ActorFactoryDefault(ActorFactory): self.continuous_conditioned_sigma = continuous_conditioned_sigma self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: env_type = envs.get_type() if env_type == EnvType.CONTINUOUS: match self.continuous_actor_type: case ContinuousActorType.GAUSSIAN: - factory = ActorFactoryContinuousGaussian( + factory = ActorFactoryContinuousGaussianNet( self.hidden_sizes, unbounded=self.continuous_unbounded, conditioned_sigma=self.continuous_conditioned_sigma, ) case ContinuousActorType.DETERMINISTIC: - factory = ActorFactoryContinuousDeterministic(self.hidden_sizes) + factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes) case _: raise ValueError(self.continuous_actor_type) return factory.create_module(envs, device) @@ -95,11 +79,11 @@ class ActorFactoryContinuous(ActorFactory, ABC): """Serves as a type bound for actor factories that are suitable for continuous action spaces.""" -class ActorFactoryContinuousDeterministic(ActorFactoryContinuous): +class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): def __init__(self, hidden_sizes: Sequence[int]): self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, @@ -113,13 +97,13 @@ class ActorFactoryContinuousDeterministic(ActorFactoryContinuous): ).to(device) -class ActorFactoryContinuousGaussian(ActorFactoryContinuous): +class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False): self.hidden_sizes = hidden_sizes self.unbounded = unbounded self.conditioned_sigma = conditioned_sigma - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, @@ -142,97 +126,19 @@ class ActorFactoryContinuousGaussian(ActorFactoryContinuous): return actor -class CriticFactory(ABC): - @abstractmethod - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: - pass - - -class CriticFactoryDefault(CriticFactory): - """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" - - DEFAULT_HIDDEN_SIZES = (64, 64) - - def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): - self.hidden_sizes = hidden_sizes - - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: - env_type = envs.get_type() - 
if env_type == EnvType.CONTINUOUS: - factory = CriticFactoryContinuousNet(self.hidden_sizes) - return factory.create_module(envs, device, use_action) - elif env_type == EnvType.DISCRETE: - raise NotImplementedError - else: - raise ValueError(f"{env_type} not supported") - - -class CriticFactoryContinuous(CriticFactory, ABC): - pass - - -class CriticFactoryContinuousNet(CriticFactoryContinuous): +class ActorFactoryDiscreteNet(ActorFactory): def __init__(self, hidden_sizes: Sequence[int]): self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: - action_shape = envs.get_action_shape() if use_action else 0 - net_c = Net( + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + net_a = Net( envs.get_observation_shape(), - action_shape=action_shape, hidden_sizes=self.hidden_sizes, - concat=use_action, - activation=nn.Tanh, device=device, ) - critic = continuous.Critic(net_c, device=device).to(device) - init_linear_orthogonal(critic) - return critic - - -@dataclass -class ModuleOpt: - module: torch.nn.Module - optim: torch.optim.Optimizer - - -@dataclass -class ActorCriticModuleOpt: - actor_critic_module: ActorCritic - optim: torch.optim.Optimizer - - @property - def actor(self): - return self.actor_critic_module.actor - - @property - def critic(self): - return self.actor_critic_module.critic - - -class ActorModuleOptFactory: - def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory): - self.actor_factory = actor_factory - self.optim_factory = optim_factory - - def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: - actor = self.actor_factory.create_module(envs, device) - opt = self.optim_factory.create_optimizer(actor, lr) - return ModuleOpt(actor, opt) - - -class CriticModuleOptFactory: - def __init__( - self, - critic_factory: CriticFactory, - optim_factory: OptimizerFactory, - use_action: bool, - ): - self.critic_factory = critic_factory - self.optim_factory = optim_factory - self.use_action = use_action - - def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: - critic = self.critic_factory.create_module(envs, device, self.use_action) - opt = self.optim_factory.create_optimizer(critic, lr) - return ModuleOpt(critic, opt) + return discrete.Actor( + net_a, + envs.get_action_shape(), + hidden_sizes=(), + device=device, + ).to(device) diff --git a/tianshou/highlevel/module/core.py b/tianshou/highlevel/module/core.py new file mode 100644 index 0000000..438fd31 --- /dev/null +++ b/tianshou/highlevel/module/core.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TypeAlias + +import numpy as np +import torch + +from tianshou.highlevel.env import Environments +from tianshou.utils.net.common import Net + +TDevice: TypeAlias = str | int | torch.device + + +def init_linear_orthogonal(module: torch.nn.Module): + """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0. 
+ + :param module: the module whose submodules are to be processed + """ + for m in module.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + + +@dataclass +class Module: + module: torch.nn.Module + output_dim: int + + +class ModuleFactory(ABC): + @abstractmethod + def create_module(self, envs: Environments, device: TDevice) -> Module: + pass + + +class ModuleFactoryNet(ModuleFactory): + def __init__(self, hidden_sizes: int | Sequence[int]): + self.hidden_sizes = hidden_sizes + + def create_module(self, envs: Environments, device: TDevice) -> Module: + module = Net(envs.get_observation_shape()) + return Module(module, module.output_dim) diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py new file mode 100644 index 0000000..f15adbe --- /dev/null +++ b/tianshou/highlevel/module/critic.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence + +from torch import nn + +from tianshou.highlevel.env import Environments, EnvType +from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal +from tianshou.utils.net import continuous +from tianshou.utils.net.common import Net + + +class CriticFactory(ABC): + @abstractmethod + def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + pass + + +class CriticFactoryDefault(CriticFactory): + """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" + + DEFAULT_HIDDEN_SIZES = (64, 64) + + def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): + self.hidden_sizes = hidden_sizes + + def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + env_type = envs.get_type() + if env_type == EnvType.CONTINUOUS: + factory = CriticFactoryContinuousNet(self.hidden_sizes) + return factory.create_module(envs, device, use_action) + elif env_type == EnvType.DISCRETE: + raise NotImplementedError + else: + raise ValueError(f"{env_type} not supported") + + +class CriticFactoryContinuous(CriticFactory, ABC): + pass + + +class CriticFactoryContinuousNet(CriticFactoryContinuous): + def __init__(self, hidden_sizes: Sequence[int]): + self.hidden_sizes = hidden_sizes + + def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + action_shape = envs.get_action_shape() if use_action else 0 + net_c = Net( + envs.get_observation_shape(), + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + concat=use_action, + activation=nn.Tanh, + device=device, + ) + critic = continuous.Critic(net_c, device=device).to(device) + init_linear_orthogonal(critic) + return critic diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py new file mode 100644 index 0000000..8feff85 --- /dev/null +++ b/tianshou/highlevel/module/module_opt.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass + +import torch + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.actor import ActorFactory +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.critic import CriticFactory +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.utils.net.common import ActorCritic + + +@dataclass +class ModuleOpt: + module: torch.nn.Module + optim: torch.optim.Optimizer + + +@dataclass +class ActorCriticModuleOpt: + actor_critic_module: ActorCritic + optim: 
torch.optim.Optimizer + + @property + def actor(self): + return self.actor_critic_module.actor + + @property + def critic(self): + return self.actor_critic_module.critic + + +class ActorModuleOptFactory: + def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory): + self.actor_factory = actor_factory + self.optim_factory = optim_factory + + def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: + actor = self.actor_factory.create_module(envs, device) + opt = self.optim_factory.create_optimizer(actor, lr) + return ModuleOpt(actor, opt) + + +class CriticModuleOptFactory: + def __init__( + self, + critic_factory: CriticFactory, + optim_factory: OptimizerFactory, + use_action: bool, + ): + self.critic_factory = critic_factory + self.optim_factory = optim_factory + self.use_action = use_action + + def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: + critic = self.critic_factory.create_module(envs, device, self.use_action) + opt = self.optim_factory.create_optimizer(critic, lr) + return ModuleOpt(critic, opt) diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index d786a17..4adb93c 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -4,7 +4,7 @@ import numpy as np import torch from tianshou.highlevel.env import Environments -from tianshou.highlevel.module import TDevice +from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.optim import OptimizerFactory diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index dca4a60..2d6c6e9 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -6,7 +6,8 @@ import torch from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments -from tianshou.highlevel.module import ModuleOpt, TDevice +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory from tianshou.highlevel.params.dist_fn import ( @@ -66,6 +67,18 @@ class ParamTransformerDrop(ParamTransformer): del kwargs[k] +class ParamTransformerChangeValue(ParamTransformer): + def __init__(self, key: str): + self.key = key + + def transform(self, params: dict[str, Any], data: ParamTransformerData): + params[self.key] = self.change_value(params[self.key], data) + + @abstractmethod + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + pass + + class ParamTransformerLRScheduler(ParamTransformer): """Transforms a key containing a learning rate scheduler factory (removed) into a key containing a learning rate scheduler (added) for the data member `optim`. 
@@ -182,6 +195,14 @@ class ParamTransformerDistributionFunction(ParamTransformer): kwargs[self.key] = value.create_dist_fn(data.envs) +class ParamTransformerActionScaling(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if value == "default": + return data.envs.get_type().is_continuous() + else: + return value + + class GetParamTransformersProtocol(Protocol): def _get_param_transformers(self) -> list[ParamTransformer]: pass @@ -218,9 +239,15 @@ class PGParams(Params): discount_factor: float = 0.99 reward_normalization: bool = False deterministic_eval: bool = False - action_scaling: bool = True + action_scaling: bool | Literal["default"] = "default" + """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces""" action_bound_method: Literal["clip", "tanh"] | None = "clip" + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.append(ParamTransformerActionScaling("action_scaling")) + return transformers + @dataclass class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler): diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py new file mode 100644 index 0000000..008bfd8 --- /dev/null +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Generic, TypeVar + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import ModuleFactory, TDevice +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.policy import BasePolicy, ICMPolicy +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + +TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy) +TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) + + +class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC): + @abstractmethod + def create_wrapped_policy( + self, + policy: TPolicyIn, + envs: Environments, + optim_factory: OptimizerFactory, + device: TDevice, + ) -> TPolicyOut: + pass + + +class PolicyWrapperFactoryIntrinsicCuriosity( + Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy], +): + def __init__( + self, + feature_net_factory: ModuleFactory, + hidden_sizes: Sequence[int], + lr: float, + lr_scale: float, + reward_scale: float, + forward_loss_weight, + ): + self.feature_net_factory = feature_net_factory + self.hidden_sizes = hidden_sizes + self.lr = lr + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + + def create_wrapped_policy( + self, + policy: TPolicyIn, + envs: Environments, + optim_factory: OptimizerFactory, + device: TDevice, + ) -> ICMPolicy: + feature_net = self.feature_net_factory.create_module(envs, device) + action_dim = envs.get_action_shape() + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.module, + feature_dim, + action_dim, + hidden_sizes=self.hidden_sizes, + device=device, + ) + icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr) + return ICMPolicy( + policy=policy, + model=icm_net, + optim=icm_optim, + action_space=envs.get_action_space(), + lr_scale=self.lr_scale, + reward_scale=self.reward_scale, + forward_loss_weight=self.forward_loss_weight, + ).to(device)
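
Usage sketch (not part of the patch above): the new high-level API introduced in this diff is exercised end to end by examples/atari/atari_ppo_hl.py; the minimal script below shows the same builder flow without the jsonargparse CLI, using only classes and parameters that appear in the diff. It assumes that RLExperimentConfig can be instantiated with its defaults and that the PPOParams fields omitted here have defaults; neither is shown in this diff.

from examples.atari.atari_network import ActorFactoryAtariDQN
from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import PPOExperimentBuilder, RLExperimentConfig
from tianshou.highlevel.params.policy_params import PPOParams

# Experiment-wide settings; assumed to provide usable defaults (e.g. the seed).
experiment_config = RLExperimentConfig()

# Sampling/replay settings, including the new replay_buffer_* options that
# move Atari frame stacking into the replay buffer.
sampling_config = RLSamplingConfig(
    num_epochs=100,
    step_per_epoch=100000,
    batch_size=256,
    num_train_envs=10,
    num_test_envs=10,
    buffer_size=100000,
    step_per_collect=1000,
    repeat_per_collect=4,
    replay_buffer_stack_num=4,
    replay_buffer_ignore_obs_next=True,
    replay_buffer_save_only_last_obs=True,
)

# Environment factory added in atari_wrapper.py.
env_factory = AtariEnvFactory(
    "PongNoFrameskip-v4", experiment_config.seed, sampling_config, frame_stack=4,
)

experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(
        PPOParams(discount_factor=0.99, gae_lambda=0.95, eps_clip=0.1, lr=2.5e-4),
    )
    .with_actor_factory(ActorFactoryAtariDQN(512, scale_obs=True))
    # Reuse the actor's DQN feature extractor for the critic (new builder option).
    .with_critic_factory_use_actor()
    .build()
)
experiment.run("ppo-pong-sketch")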