Add support for discrete PPO

* Refactor module `module` (split into submodules)
* Add basic support for discrete environments
* Implement Atari environment factory
* Implement DQN-based actor factory
* Implement option to reuse the agent's preprocessing network for the critic
* Add example atari_ppo_hl
Dominik Jain 2023-09-28 20:07:52 +02:00
parent e0e7349b0a
commit 6b6d9ea609
15 changed files with 566 additions and 143 deletions
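For orientation, the workflow this commit enables looks roughly as follows. This is a condensed sketch of the new example (examples/atari/atari_ppo_hl.py, shown in full below); it assumes `experiment_config`, `env_factory` and `sampling_config` are constructed as in that example, and the builder method names are taken directly from the diff.

from examples.atari.atari_network import ActorFactoryAtariDQN
from tianshou.highlevel.experiment import PPOExperimentBuilder
from tianshou.highlevel.params.policy_params import PPOParams

# experiment_config, env_factory and sampling_config built as in the example file below
experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(PPOParams(discount_factor=0.99, eps_clip=0.1, lr=2.5e-4))
    .with_actor_factory(ActorFactoryAtariDQN(512, scale_obs=False))
    .with_critic_factory_use_actor()  # critic reuses the actor's DQN feature trunk
    .build()
)
experiment.run("atari-ppo-example")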

examples/atari/atari_network.py

@@ -5,7 +5,11 @@ import numpy as np
import torch
from torch import nn

-from tianshou.utils.net.discrete import NoisyLinear
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.actor import ActorFactory
+from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.utils.net.common import BaseActor
+from tianshou.utils.net.discrete import Actor, NoisyLinear


def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
@@ -220,3 +224,29 @@ class QRDQN(DQN):
obs, state = super().forward(obs)
obs = obs.view(-1, self.action_num, self.num_quantiles)
return obs, state
class ActorFactoryAtariDQN(ActorFactory):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
self.hidden_size = hidden_size
self.scale_obs = scale_obs
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN
net = net_cls(
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
features_only=True,
output_dim=self.hidden_size,
layer_init=layer_init,
)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
class FeatureNetFactoryDQN(ModuleFactory):
def create_module(self, envs: Environments, device: TDevice) -> Module:
dqn = DQN(
*envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
)
return Module(dqn.net, dqn.output_dim)

examples/atari/atari_ppo_hl.py

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
PPOExperimentBuilder,
RLExperimentConfig,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
def main(
experiment_config: RLExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: bool = True,
buffer_size: int = 100000,
lr: float = 2.5e-4,
gamma: float = 0.99,
epoch: int = 100,
step_per_epoch: int = 100000,
step_per_collect: int = 1000,
repeat_per_collect: int = 4,
batch_size: int = 256,
hidden_sizes: int | Sequence[int] = 512,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = False,
vf_coef: float = 0.25,
ent_coef: float = 0.01,
gae_lambda: float = 0.95,
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.1,
dual_clip: float | None = None,
value_clip: bool = True,
norm_adv: bool = True,
recompute_adv: bool = False,
frames_stack: int = 4,
save_buffer_name: str | None = None, # TODO add support in high-level API?
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
replay_buffer_stack_num=frames_stack,
replay_buffer_ignore_obs_next=True,
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
builder = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
),
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
.with_critic_factory_use_actor()
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
[hidden_sizes],
lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
),
)
experiment = builder.build()
experiment.run(log_name)
if __name__ == "__main__":
CLI(main)

examples/atari/atari_wrapper.py

@@ -9,6 +9,8 @@ import gymnasium as gym
import numpy as np

from tianshou.env import ShmemVectorEnv
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory

try:
import envpool
@@ -369,3 +371,22 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
train_envs.seed(seed)
test_envs.seed(seed)
return env, train_envs, test_envs
class AtariEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
self.task = task
self.sampling_config = sampling_config
self.seed = seed
self.frame_stack = frame_stack
def create_envs(self, config=None) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env(
task=self.task,
seed=self.seed,
training_num=self.sampling_config.num_train_envs,
test_num=self.sampling_config.num_test_envs,
scale=0,
frame_stack=self.frame_stack,
)
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)

tianshou/highlevel/agent.py

@@ -1,4 +1,3 @@
-import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Generic, TypeVar
@@ -9,14 +8,16 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
-from tianshou.highlevel.module import (
-ActorCriticModuleOpt,
+from tianshou.highlevel.module.actor import (
ActorFactory,
+)
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.module.module_opt import (
+ActorCriticModuleOpt,
ActorModuleOptFactory,
-CriticFactory,
CriticModuleOptFactory,
ModuleOpt,
-TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
@@ -27,8 +28,10 @@ from tianshou.highlevel.params.policy_params import (
SACParams,
TD3Params,
)
+from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
+from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -38,34 +41,62 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class AgentFactory(ABC):
-def __init__(self, sampling_config: RLSamplingConfig):
+def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
self.sampling_config = sampling_config
self.optim_factory = optim_factory
self.policy_wrapper_factory: PolicyWrapperFactory | None = None
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
if len(train_envs) > 1:
-buffer = VectorReplayBuffer(buffer_size, len(train_envs))
+buffer = VectorReplayBuffer(
buffer_size,
len(train_envs),
stack_num=self.sampling_config.replay_buffer_stack_num,
save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
)
else:
-buffer = ReplayBuffer(buffer_size)
+buffer = ReplayBuffer(
buffer_size,
stack_num=self.sampling_config.replay_buffer_stack_num,
save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
if self.sampling_config.start_timesteps > 0:
train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
return train_collector, test_collector
def set_policy_wrapper_factory(
self, policy_wrapper_factory: PolicyWrapperFactory | None,
) -> None:
self.policy_wrapper_factory = policy_wrapper_factory
@abstractmethod
-def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
pass
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
policy = self._create_policy(envs, device)
if self.policy_wrapper_factory is not None:
policy = self.policy_wrapper_factory.create_wrapped_policy(
policy, envs, self.optim_factory, device,
)
return policy
@staticmethod
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
def save_best_fn(pol: torch.nn.Module) -> None:
-state = {
-CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
-CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
-}
-torch.save(state, os.path.join(log_path, "policy.pth"))
+pass
+# TODO: Fix saving in general (code works only for mujoco)
+# state = {
+#     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
+#     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
# }
# torch.save(state, os.path.join(log_path, "policy.pth"))
return save_best_fn
@@ -160,11 +191,13 @@ class _ActorCriticMixin:
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
+critic_use_actor_module: bool,
):
self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.critic_use_action = critic_use_action
+self.critic_use_actor_module = critic_use_actor_module

def create_actor_critic_module_opt(
self,
@@ -173,7 +206,23 @@ class _ActorCriticMixin:
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
-critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
+if self.critic_use_actor_module:
if self.critic_use_action:
raise ValueError(
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
)
if envs.get_type().is_discrete():
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
elif envs.get_type().is_continuous():
critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
else:
raise ValueError
else:
critic = self.critic_factory.create_module(
envs,
device,
use_action=self.critic_use_action,
)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
@@ -237,14 +286,16 @@ class ActorCriticAgentFactory(
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
policy_class: type[TPolicy],
+critic_use_actor_module: bool,
):
-super().__init__(sampling_config)
+super().__init__(sampling_config, optim_factory=optimizer_factory)
_ActorCriticMixin.__init__(
self,
actor_factory,
critic_factory,
optimizer_factory,
critic_use_action=False,
+critic_use_actor_module=critic_use_actor_module,
)
self.params = params
self.policy_class = policy_class
@@ -269,7 +320,7 @@ class ActorCriticAgentFactory(
kwargs["action_space"] = envs.get_action_space()
return kwargs

-def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
return self.policy_class(**self._create_kwargs(envs, device))
@@ -281,6 +332,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
+critic_use_actor_module: bool,
):
super().__init__(
params,
@@ -289,6 +341,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
critic_factory,
optimizer_factory,
A2CPolicy,
+critic_use_actor_module,
)

def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -303,6 +356,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
+critic_use_actor_module: bool,
):
super().__init__(
params,
@@ -311,6 +365,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
critic_factory,
optimizer_factory,
PPOPolicy,
+critic_use_actor_module,
)

def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -327,7 +382,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
-super().__init__(sampling_config)
+super().__init__(sampling_config, optim_factory)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
@@ -339,7 +394,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
self.params = params
self.optim_factory = optim_factory

-def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
@@ -376,7 +431,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
-super().__init__(sampling_config)
+super().__init__(sampling_config, optim_factory)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
@@ -388,7 +443,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
self.params = params
self.optim_factory = optim_factory

-def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)

tianshou/highlevel/config.py

@@ -17,3 +17,7 @@ class RLSamplingConfig:
update_per_step: int = 1
start_timesteps: int = 0
start_timesteps_random: bool = False
# TODO can we set the parameters below more intelligently? Perhaps based on env. representation?
replay_buffer_ignore_obs_next: bool = False
replay_buffer_save_only_last_obs: bool = False
replay_buffer_stack_num: int = 1

tianshou/highlevel/env.py

@@ -97,6 +97,22 @@ class ContinuousEnvironments(Environments):
return EnvType.CONTINUOUS
class DiscreteEnvironments(Environments):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.observation_shape = env.observation_space.shape or env.observation_space.n
self.action_shape = env.action_space.shape or env.action_space.n
def get_action_shape(self) -> TShape:
return self.action_shape
def get_observation_shape(self) -> TShape:
return self.observation_shape
def get_type(self) -> EnvType:
return EnvType.DISCRETE
class EnvFactory(ABC):
@abstractmethod
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
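The shape logic in DiscreteEnvironments above follows the usual gymnasium convention: Box spaces report a non-empty `.shape`, while Discrete spaces report `.n` (their `.shape` is the empty tuple, hence the `or` fallback). A quick illustration with CartPole, used here only because it has one space of each kind:

import gymnasium as gym

env = gym.make("CartPole-v1")
# Box observation space: .shape is (4,), so the first operand is used.
obs_shape = env.observation_space.shape or env.observation_space.n  # -> (4,)
# Discrete action space: .shape is (), which is falsy, so .n is used instead.
act_shape = env.action_space.shape or env.action_space.n  # -> 2
print(obs_shape, act_shape)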

tianshou/highlevel/experiment.py

@@ -18,13 +18,12 @@ from tianshou.highlevel.agent import (
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
-from tianshou.highlevel.module import (
+from tianshou.highlevel.module.actor import (
ActorFactory,
ActorFactoryDefault,
ContinuousActorType,
-CriticFactory,
-CriticFactoryDefault,
)
+from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
A2CParams,
@@ -32,6 +31,7 @@ from tianshou.highlevel.params.policy_params import (
SACParams,
TD3Params,
)
+from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
@@ -154,6 +154,7 @@ class RLExperimentBuilder:
self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None
self._env_config: PersistableConfigProtocol | None = None
+self._policy_wrapper_factory: PolicyWrapperFactory | None = None

def with_env_config(self, config: PersistableConfigProtocol) -> Self:
self._env_config = config
@@ -163,6 +164,10 @@ class RLExperimentBuilder:
self._logger_factory = logger_factory
return self

+def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
+self._policy_wrapper_factory = policy_wrapper_factory
+return self
+
def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
self._optim_factory = optim_factory
return self
@@ -194,10 +199,13 @@ class RLExperimentBuilder:
return self._optim_factory

def build(self) -> RLExperiment:
+agent_factory = self._create_agent_factory()
+if self._policy_wrapper_factory:
+agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
return RLExperiment(
self._config,
self._env_factory,
-self._create_agent_factory(),
+agent_factory,
self._logger_factory,
env_config=self._env_config,
)
@@ -287,6 +295,7 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
super().__init__(1)
+self._critic_use_actor_module = False

def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory"
@@ -301,6 +310,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(0, hidden_sizes)
return self

+def with_critic_factory_use_actor(self) -> Self:
+"""Makes the critic use the same network as the actor."""
+self._critic_use_actor_module = True
+return self
+
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
@@ -378,6 +392,7 @@ class A2CExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
+self._critic_use_actor_module,
)
@@ -411,6 +426,7 @@ class PPOExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
+self._critic_use_actor_module,
)
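What `with_critic_factory_use_actor` amounts to internally (see the `_ActorCriticMixin` change in agent.py above) is roughly the following, sketched here with tianshou's discrete nets and toy shapes; it assumes the `get_preprocess_net` accessor on the actor that the new code relies on:

import torch

from tianshou.utils.net import discrete
from tianshou.utils.net.common import ActorCritic, Net

# Toy shapes, discrete action space; the point is the shared feature trunk.
device = "cpu"
preprocess = Net(state_shape=(8,), hidden_sizes=[128, 128], device=device)
actor = discrete.Actor(preprocess, action_shape=4, device=device, softmax_output=False).to(device)
# The critic head is attached to the actor's preprocessing network rather than
# to a second, independently created feature extractor:
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=2.5e-4)

Both heads then share gradients through the common trunk, which is the standard shared-backbone PPO setup for Atari.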

tianshou/highlevel/module/__init__.py (new, empty file)

tianshou/highlevel/module/actor.py

@@ -1,29 +1,13 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
-from dataclasses import dataclass
-from typing import TypeAlias
-import numpy as np
import torch
from torch import nn

from tianshou.highlevel.env import Environments, EnvType
-from tianshou.highlevel.optim import OptimizerFactory
-from tianshou.utils.net import continuous
-from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.utils.net import continuous, discrete
+from tianshou.utils.net.common import BaseActor, Net

-TDevice: TypeAlias = str | int | torch.device
-
-def init_linear_orthogonal(module: torch.nn.Module):
-"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
-
-:param module: the module whose submodules are to be processed
-"""
-for m in module.modules():
-if isinstance(m, torch.nn.Linear):
-torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
-torch.nn.init.zeros_(m.bias)
-

class ContinuousActorType:
@@ -33,7 +17,7 @@ class ContinuousActorType:
class ActorFactory(ABC):
@abstractmethod
-def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
pass

@staticmethod
@@ -70,18 +54,18 @@ class ActorFactoryDefault(ActorFactory):
self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes

-def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
match self.continuous_actor_type:
case ContinuousActorType.GAUSSIAN:
-factory = ActorFactoryContinuousGaussian(
+factory = ActorFactoryContinuousGaussianNet(
self.hidden_sizes,
unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
case ContinuousActorType.DETERMINISTIC:
-factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
+factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
@@ -95,11 +79,11 @@ class ActorFactoryContinuous(ActorFactory, ABC):
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""

-class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
+class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes

-def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
@@ -113,13 +97,13 @@ class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
).to(device)

-class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
+class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma

-def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
@@ -142,97 +126,19 @@ class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
return actor

-class CriticFactory(ABC):
-@abstractmethod
-def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
-pass
-
-class CriticFactoryDefault(CriticFactory):
-"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
-
-DEFAULT_HIDDEN_SIZES = (64, 64)
-
-def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
-self.hidden_sizes = hidden_sizes
-
-def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
-env_type = envs.get_type()
-if env_type == EnvType.CONTINUOUS:
-factory = CriticFactoryContinuousNet(self.hidden_sizes)
-return factory.create_module(envs, device, use_action)
-elif env_type == EnvType.DISCRETE:
-raise NotImplementedError
-else:
-raise ValueError(f"{env_type} not supported")
-
-class CriticFactoryContinuous(CriticFactory, ABC):
-pass
-
-class CriticFactoryContinuousNet(CriticFactoryContinuous):
+class ActorFactoryDiscreteNet(ActorFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes

-def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
-action_shape = envs.get_action_shape() if use_action else 0
-net_c = Net(
+def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
+net_a = Net(
envs.get_observation_shape(),
-action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
-concat=use_action,
-activation=nn.Tanh,
device=device,
)
-critic = continuous.Critic(net_c, device=device).to(device)
-init_linear_orthogonal(critic)
-return critic
+return discrete.Actor(
+net_a,
+envs.get_action_shape(),
+hidden_sizes=(),
+device=device,
+).to(device)

-@dataclass
-class ModuleOpt:
-module: torch.nn.Module
-optim: torch.optim.Optimizer
-
-@dataclass
-class ActorCriticModuleOpt:
-actor_critic_module: ActorCritic
-optim: torch.optim.Optimizer
-
-@property
-def actor(self):
-return self.actor_critic_module.actor
-
-@property
-def critic(self):
-return self.actor_critic_module.critic
-
-class ActorModuleOptFactory:
-def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
-self.actor_factory = actor_factory
-self.optim_factory = optim_factory
-
-def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-actor = self.actor_factory.create_module(envs, device)
-opt = self.optim_factory.create_optimizer(actor, lr)
-return ModuleOpt(actor, opt)
-
-class CriticModuleOptFactory:
-def __init__(
-self,
-critic_factory: CriticFactory,
-optim_factory: OptimizerFactory,
-use_action: bool,
-):
-self.critic_factory = critic_factory
-self.optim_factory = optim_factory
-self.use_action = use_action
-
-def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-critic = self.critic_factory.create_module(envs, device, self.use_action)
-opt = self.optim_factory.create_optimizer(critic, lr)
-return ModuleOpt(critic, opt)

tianshou/highlevel/module/core.py

@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
TDevice: TypeAlias = str | int | torch.device
def init_linear_orthogonal(module: torch.nn.Module):
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
:param module: the module whose submodules are to be processed
"""
for m in module.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
@dataclass
class Module:
module: torch.nn.Module
output_dim: int
class ModuleFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> Module:
pass
class ModuleFactoryNet(ModuleFactory):
def __init__(self, hidden_sizes: int | Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> Module:
module = Net(envs.get_observation_shape())
return Module(module, module.output_dim)
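The `Module`/`ModuleFactory` pair is the minimal contract the intrinsic-curiosity wrapper further below needs: a torch module plus its feature dimension. A custom factory is only a few lines; the class name here is hypothetical and not part of the commit:

import numpy as np
from torch import nn

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice


class FlattenMLPFeatureFactory(ModuleFactory):  # hypothetical example
    """Builds a small flat MLP feature net and reports its output dimension."""

    def __init__(self, hidden_dim: int = 128):
        self.hidden_dim = hidden_dim

    def create_module(self, envs: Environments, device: TDevice) -> Module:
        obs_dim = int(np.prod(envs.get_observation_shape()))
        net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(obs_dim, self.hidden_dim),
            nn.ReLU(),
        ).to(device)
        # Consumers such as the ICM wrapper read both the module and output_dim.
        return Module(net, self.hidden_dim)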

tianshou/highlevel/module/critic.py

@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous
from tianshou.utils.net.common import Net
class CriticFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
class CriticFactoryDefault(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
class CriticFactoryContinuous(CriticFactory, ABC):
pass
class CriticFactoryContinuousNet(CriticFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic

tianshou/highlevel/module/module_opt.py

@@ -0,0 +1,58 @@
from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic
@dataclass
class ModuleOpt:
module: torch.nn.Module
optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
actor_critic_module: ActorCritic
optim: torch.optim.Optimizer
@property
def actor(self):
return self.actor_critic_module.actor
@property
def critic(self):
return self.actor_critic_module.critic
class ActorModuleOptFactory:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
class CriticModuleOptFactory:
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)

tianshou/highlevel/params/alpha.py

@@ -4,7 +4,7 @@ import numpy as np
import torch

from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module import TDevice
+from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory

tianshou/highlevel/params/policy_params.py

@@ -6,7 +6,8 @@ import torch
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module import ModuleOpt, TDevice
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
@@ -66,6 +67,18 @@ class ParamTransformerDrop(ParamTransformer):
del kwargs[k]
class ParamTransformerChangeValue(ParamTransformer):
def __init__(self, key: str):
self.key = key
def transform(self, params: dict[str, Any], data: ParamTransformerData):
params[self.key] = self.change_value(params[self.key], data)
@abstractmethod
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
pass
class ParamTransformerLRScheduler(ParamTransformer):
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
a learning rate scheduler (added) for the data member `optim`.
@@ -182,6 +195,14 @@ class ParamTransformerDistributionFunction(ParamTransformer):
kwargs[self.key] = value.create_dist_fn(data.envs)
class ParamTransformerActionScaling(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if value == "default":
return data.envs.get_type().is_continuous()
else:
return value
class GetParamTransformersProtocol(Protocol):
def _get_param_transformers(self) -> list[ParamTransformer]:
pass
@@ -218,9 +239,15 @@ class PGParams(Params):
discount_factor: float = 0.99
reward_normalization: bool = False
deterministic_eval: bool = False
-action_scaling: bool = True
+action_scaling: bool | Literal["default"] = "default"
+"""whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.append(ParamTransformerActionScaling("action_scaling"))
return transformers
@dataclass
class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
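The new `"default"` value is resolved at policy-construction time by `ParamTransformerActionScaling` above; in effect (a small sketch, assuming the PGParams defaults seen in this diff):

from tianshou.highlevel.params.policy_params import PGParams

# "default" becomes True for continuous action spaces and False for discrete ones,
# so Atari/discrete experiments no longer need to disable action scaling by hand.
params = PGParams(action_scaling="default")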

tianshou/highlevel/params/policy_wrapper.py

@@ -0,0 +1,72 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
@abstractmethod
def create_wrapped_policy(
self,
policy: TPolicyIn,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> TPolicyOut:
pass
class PolicyWrapperFactoryIntrinsicCuriosity(
Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy],
):
def __init__(
self,
feature_net_factory: ModuleFactory,
hidden_sizes: Sequence[int],
lr: float,
lr_scale: float,
reward_scale: float,
forward_loss_weight,
):
self.feature_net_factory = feature_net_factory
self.hidden_sizes = hidden_sizes
self.lr = lr
self.lr_scale = lr_scale
self.reward_scale = reward_scale
self.forward_loss_weight = forward_loss_weight
def create_wrapped_policy(
self,
policy: TPolicyIn,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> ICMPolicy:
feature_net = self.feature_net_factory.create_module(envs, device)
action_dim = envs.get_action_shape()
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.module,
feature_dim,
action_dim,
hidden_sizes=self.hidden_sizes,
device=device,
)
icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
return ICMPolicy(
policy=policy,
model=icm_net,
optim=icm_optim,
action_space=envs.get_action_space(),
lr_scale=self.lr_scale,
reward_scale=self.reward_scale,
forward_loss_weight=self.forward_loss_weight,
).to(device)
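Since `PolicyWrapperFactory` is generic over the input and output policy types, other post-processing wrappers can follow the same pattern as the curiosity wrapper above. A minimal no-op sketch of the contract (hypothetical class, not part of this commit):

from typing import Generic

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory, TPolicyIn


class PolicyWrapperFactoryNoop(Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, TPolicyIn]):  # hypothetical
    """Returns the policy unchanged; shows the minimal interface a wrapper factory must satisfy."""

    def create_wrapped_policy(
        self,
        policy: TPolicyIn,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyIn:
        return policy

Such a factory would be registered the same way as the ICM one, via builder.with_policy_wrapper_factory(PolicyWrapperFactoryNoop()).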