Add support for discrete PPO

* Refactor module `module` (split into submodules)
* Add basic support for discrete environments
* Implement Atari environment factory
* Implement DQN-based actor factory
* Implement the notion of reusing the agent's preprocessing network for the critic
* Add example atari_ppo_hl (a condensed usage sketch follows the commit metadata below)
Dominik Jain 2023-09-28 20:07:52 +02:00
parent e0e7349b0a
commit 6b6d9ea609
15 changed files with 566 additions and 143 deletions
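
For orientation, here is a condensed sketch of how the pieces introduced in this commit fit together for discrete PPO on Atari. It mirrors the full example file atari_ppo_hl shown further down; the numeric values are illustrative, remaining parameters keep their defaults, and it is assumed that RLExperimentConfig can be instantiated with defaults (in the example it is populated via the CLI).

from examples.atari.atari_network import ActorFactoryAtariDQN
from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import PPOExperimentBuilder, RLExperimentConfig
from tianshou.highlevel.params.policy_params import PPOParams

experiment_config = RLExperimentConfig()  # assumed to provide defaults, including a seed
sampling_config = RLSamplingConfig(
    num_epochs=100,
    step_per_epoch=100000,
    batch_size=256,
    num_train_envs=10,
    num_test_envs=10,
    buffer_size=100000,
    step_per_collect=1000,
    repeat_per_collect=4,
    replay_buffer_stack_num=4,  # new field: frame stacking handled by the replay buffer
    replay_buffer_ignore_obs_next=True,  # new field
    replay_buffer_save_only_last_obs=True,  # new field
)
env_factory = AtariEnvFactory("PongNoFrameskip-v4", experiment_config.seed, sampling_config, 4)
experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(PPOParams(discount_factor=0.99, eps_clip=0.1, lr=2.5e-4))
    .with_actor_factory(ActorFactoryAtariDQN(512, scale_obs=True))
    .with_critic_factory_use_actor()  # critic reuses the actor's DQN preprocessing network
    .build()
)
experiment.run("PongNoFrameskip-v4/ppo/sketch")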

View File

@ -5,7 +5,11 @@ import numpy as np
import torch
from torch import nn
from tianshou.utils.net.discrete import NoisyLinear
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
from tianshou.utils.net.common import BaseActor
from tianshou.utils.net.discrete import Actor, NoisyLinear
def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
@ -220,3 +224,29 @@ class QRDQN(DQN):
obs, state = super().forward(obs)
obs = obs.view(-1, self.action_num, self.num_quantiles)
return obs, state
class ActorFactoryAtariDQN(ActorFactory):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
self.hidden_size = hidden_size
self.scale_obs = scale_obs
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN
net = net_cls(
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
features_only=True,
output_dim=self.hidden_size,
layer_init=layer_init,
)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
class FeatureNetFactoryDQN(ModuleFactory):
def create_module(self, envs: Environments, device: TDevice) -> Module:
dqn = DQN(
*envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
)
return Module(dqn.net, dqn.output_dim)
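
For intuition, the factory above amounts to roughly the following direct construction, using the names defined and imported in this file (the shapes are illustrative: 4 stacked 84x84 frames, 6 actions):

net = DQN(4, 84, 84, 6, device="cpu", features_only=True, output_dim=512, layer_init=layer_init)
actor = Actor(net, 6, device="cpu", softmax_output=False)
# FeatureNetFactoryDQN analogously builds a features-only DQN and wraps its trunk as
# Module(dqn.net, dqn.output_dim) for consumption by the intrinsic-curiosity wrapper.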

View File

@ -0,0 +1,117 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
PPOExperimentBuilder,
RLExperimentConfig,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
def main(
experiment_config: RLExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: bool = True,
buffer_size: int = 100000,
lr: float = 2.5e-4,
gamma: float = 0.99,
epoch: int = 100,
step_per_epoch: int = 100000,
step_per_collect: int = 1000,
repeat_per_collect: int = 4,
batch_size: int = 256,
hidden_sizes: int | Sequence[int] = 512,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = False,
vf_coef: float = 0.25,
ent_coef: float = 0.01,
gae_lambda: float = 0.95,
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.1,
dual_clip: float | None = None,
value_clip: bool = True,
norm_adv: bool = True,
recompute_adv: bool = False,
frames_stack: int = 4,
save_buffer_name: str | None = None, # TODO add support in high-level API?
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
replay_buffer_stack_num=frames_stack,
replay_buffer_ignore_obs_next=True,
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
builder = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
),
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
.with_critic_factory_use_actor()
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
[hidden_sizes],
lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
),
)
experiment = builder.build()
experiment.run(log_name)
if __name__ == "__main__":
CLI(main)

View File

@ -9,6 +9,8 @@ import gymnasium as gym
import numpy as np
from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
try:
import envpool
@ -369,3 +371,22 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
train_envs.seed(seed)
test_envs.seed(seed)
return env, train_envs, test_envs
class AtariEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
self.task = task
self.sampling_config = sampling_config
self.seed = seed
self.frame_stack = frame_stack
def create_envs(self, config=None) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env(
task=self.task,
seed=self.seed,
training_num=self.sampling_config.num_train_envs,
test_num=self.sampling_config.num_test_envs,
scale=0,
frame_stack=self.frame_stack,
)
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
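
Standalone use of the factory would look roughly as follows (a sketch, assuming a sampling_config like the one in the sketch near the top of this page and the Atari dependencies required by make_atari_env):

factory = AtariEnvFactory("PongNoFrameskip-v4", 0, sampling_config, 4)
envs = factory.create_envs()
envs.get_observation_shape()  # e.g. (4, 84, 84) with frame stacking
envs.get_action_shape()  # e.g. 6 for Pong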

View File

@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Generic, TypeVar
@ -9,14 +8,16 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import (
ActorCriticModuleOpt,
from tianshou.highlevel.module.actor import (
ActorFactory,
)
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.module.module_opt import (
ActorCriticModuleOpt,
ActorModuleOptFactory,
CriticFactory,
CriticModuleOptFactory,
ModuleOpt,
TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
@ -27,8 +28,10 @@ from tianshou.highlevel.params.policy_params import (
SACParams,
TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic
CHECKPOINT_DICT_KEY_MODEL = "model"
@ -38,34 +41,62 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class AgentFactory(ABC):
def __init__(self, sampling_config: RLSamplingConfig):
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
self.sampling_config = sampling_config
self.optim_factory = optim_factory
self.policy_wrapper_factory: PolicyWrapperFactory | None = None
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
buffer = VectorReplayBuffer(
buffer_size,
len(train_envs),
stack_num=self.sampling_config.replay_buffer_stack_num,
save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
)
else:
buffer = ReplayBuffer(buffer_size)
buffer = ReplayBuffer(
buffer_size,
stack_num=self.sampling_config.replay_buffer_stack_num,
save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
if self.sampling_config.start_timesteps > 0:
train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
return train_collector, test_collector
def set_policy_wrapper_factory(
self, policy_wrapper_factory: PolicyWrapperFactory | None,
) -> None:
self.policy_wrapper_factory = policy_wrapper_factory
@abstractmethod
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
pass
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
policy = self._create_policy(envs, device)
if self.policy_wrapper_factory is not None:
policy = self.policy_wrapper_factory.create_wrapped_policy(
policy, envs, self.optim_factory, device,
)
return policy
@staticmethod
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
def save_best_fn(pol: torch.nn.Module) -> None:
state = {
CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
}
torch.save(state, os.path.join(log_path, "policy.pth"))
pass
# TODO: Fix saving in general (code works only for mujoco)
# state = {
# CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
# CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
# }
# torch.save(state, os.path.join(log_path, "policy.pth"))
return save_best_fn
@ -160,11 +191,13 @@ class _ActorCriticMixin:
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
critic_use_actor_module: bool,
):
self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.critic_use_action = critic_use_action
self.critic_use_actor_module = critic_use_actor_module
def create_actor_critic_module_opt(
self,
@ -173,7 +206,23 @@ class _ActorCriticMixin:
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
if self.critic_use_actor_module:
if self.critic_use_action:
raise ValueError(
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
)
if envs.get_type().is_discrete():
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
elif envs.get_type().is_continuous():
critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
else:
raise ValueError
else:
critic = self.critic_factory.create_module(
envs,
device,
use_action=self.critic_use_action,
)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
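
With critic_use_actor_module enabled, the value head is thus attached to the actor's feature extractor. Wrapping both modules in ActorCritic and creating a single optimizer does not optimize the shared trunk twice, since nn.Module.parameters() de-duplicates shared parameter tensors. A minimal standalone sketch of the discrete case, assuming actor_factory, optim_factory, envs and device are in scope:

actor = actor_factory.create_module(envs, device)  # e.g. ActorFactoryAtariDQN
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
actor_critic = ActorCritic(actor, critic)
optim = optim_factory.create_optimizer(actor_critic, lr=2.5e-4)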
@ -237,14 +286,16 @@ class ActorCriticAgentFactory(
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
policy_class: type[TPolicy],
critic_use_actor_module: bool,
):
super().__init__(sampling_config)
super().__init__(sampling_config, optim_factory=optimizer_factory)
_ActorCriticMixin.__init__(
self,
actor_factory,
critic_factory,
optimizer_factory,
critic_use_action=False,
critic_use_actor_module=critic_use_actor_module,
)
self.params = params
self.policy_class = policy_class
@ -269,7 +320,7 @@ class ActorCriticAgentFactory(
kwargs["action_space"] = envs.get_action_space()
return kwargs
def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
return self.policy_class(**self._create_kwargs(envs, device))
@ -281,6 +332,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
):
super().__init__(
params,
@ -289,6 +341,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
critic_factory,
optimizer_factory,
A2CPolicy,
critic_use_actor_module,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -303,6 +356,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
):
super().__init__(
params,
@ -311,6 +365,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
critic_factory,
optimizer_factory,
PPOPolicy,
critic_use_actor_module,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -327,7 +382,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config)
super().__init__(sampling_config, optim_factory)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
@ -339,7 +394,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
self.params = params
self.optim_factory = optim_factory
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
@ -376,7 +431,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config)
super().__init__(sampling_config, optim_factory)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
@ -388,7 +443,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
self.params = params
self.optim_factory = optim_factory
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)

View File

@ -17,3 +17,7 @@ class RLSamplingConfig:
update_per_step: int = 1
start_timesteps: int = 0
start_timesteps_random: bool = False
# TODO can we set the parameters below more intelligently? Perhaps based on env. representation?
replay_buffer_ignore_obs_next: bool = False
replay_buffer_save_only_last_obs: bool = False
replay_buffer_stack_num: int = 1

View File

@ -97,6 +97,22 @@ class ContinuousEnvironments(Environments):
return EnvType.CONTINUOUS
class DiscreteEnvironments(Environments):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.observation_shape = env.observation_space.shape or env.observation_space.n
self.action_shape = env.action_space.shape or env.action_space.n
def get_action_shape(self) -> TShape:
return self.action_shape
def get_observation_shape(self) -> TShape:
return self.observation_shape
def get_type(self) -> EnvType:
return EnvType.DISCRETE
class EnvFactory(ABC):
@abstractmethod
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
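
DiscreteEnvironments falls back to the space's n whenever shape is empty, which a quick sketch with a toy discrete task illustrates (assuming tianshou's DummyVectorEnv as the vector env):

import gymnasium as gym
from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import DiscreteEnvironments

env = gym.make("CartPole-v1")
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
envs = DiscreteEnvironments(env, train_envs, test_envs)
envs.get_observation_shape()  # (4,), from the Box observation space's shape
envs.get_action_shape()  # 2, from Discrete(2).n since its shape is empty
envs.get_type()  # EnvType.DISCRETE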

View File

@ -18,13 +18,12 @@ from tianshou.highlevel.agent import (
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import (
from tianshou.highlevel.module.actor import (
ActorFactory,
ActorFactoryDefault,
ContinuousActorType,
CriticFactory,
CriticFactoryDefault,
)
from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
A2CParams,
@ -32,6 +31,7 @@ from tianshou.highlevel.params.policy_params import (
SACParams,
TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
@ -154,6 +154,7 @@ class RLExperimentBuilder:
self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None
self._env_config: PersistableConfigProtocol | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
self._env_config = config
@ -163,6 +164,10 @@ class RLExperimentBuilder:
self._logger_factory = logger_factory
return self
def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
self._policy_wrapper_factory = policy_wrapper_factory
return self
def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
self._optim_factory = optim_factory
return self
@ -194,10 +199,13 @@ class RLExperimentBuilder:
return self._optim_factory
def build(self) -> RLExperiment:
agent_factory = self._create_agent_factory()
if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
return RLExperiment(
self._config,
self._env_factory,
self._create_agent_factory(),
agent_factory,
self._logger_factory,
env_config=self._env_config,
)
@ -287,6 +295,7 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
super().__init__(1)
self._critic_use_actor_module = False
def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory"
@ -301,6 +310,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(0, hidden_sizes)
return self
def with_critic_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor."""
self._critic_use_actor_module = True
return self
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
@ -378,6 +392,7 @@ class A2CExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._critic_use_actor_module,
)
@ -411,6 +426,7 @@ class PPOExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._critic_use_actor_module,
)

View File

View File

@ -1,29 +1,13 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import ActorCritic, Net
TDevice: TypeAlias = str | int | torch.device
def init_linear_orthogonal(module: torch.nn.Module):
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
:param module: the module whose submodules are to be processed
"""
for m in module.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, Net
class ContinuousActorType:
@ -33,7 +17,7 @@ class ContinuousActorType:
class ActorFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
pass
@staticmethod
@ -70,18 +54,18 @@ class ActorFactoryDefault(ActorFactory):
self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
match self.continuous_actor_type:
case ContinuousActorType.GAUSSIAN:
factory = ActorFactoryContinuousGaussian(
factory = ActorFactoryContinuousGaussianNet(
self.hidden_sizes,
unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
case ContinuousActorType.DETERMINISTIC:
factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
@ -95,11 +79,11 @@ class ActorFactoryContinuous(ActorFactory, ABC):
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
@ -113,13 +97,13 @@ class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
).to(device)
class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
@ -142,97 +126,19 @@ class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
return actor
class CriticFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
class CriticFactoryDefault(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
class CriticFactoryContinuous(CriticFactory, ABC):
pass
class CriticFactoryContinuousNet(CriticFactoryContinuous):
class ActorFactoryDiscreteNet(ActorFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net(
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
@dataclass
class ModuleOpt:
module: torch.nn.Module
optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
actor_critic_module: ActorCritic
optim: torch.optim.Optimizer
@property
def actor(self):
return self.actor_critic_module.actor
@property
def critic(self):
return self.actor_critic_module.critic
class ActorModuleOptFactory:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
class CriticModuleOptFactory:
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)
return discrete.Actor(
net_a,
envs.get_action_shape(),
hidden_sizes=(),
device=device,
).to(device)

View File

@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
TDevice: TypeAlias = str | int | torch.device
def init_linear_orthogonal(module: torch.nn.Module):
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
:param module: the module whose submodules are to be processed
"""
for m in module.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
@dataclass
class Module:
module: torch.nn.Module
output_dim: int
class ModuleFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> Module:
pass
class ModuleFactoryNet(ModuleFactory):
def __init__(self, hidden_sizes: int | Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> Module:
module = Net(envs.get_observation_shape())
return Module(module, module.output_dim)
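
As a quick worked example, init_linear_orthogonal visits every nn.Linear in the module tree, orthogonally initializing its weight with gain sqrt(2) and zeroing its bias:

import torch
from torch import nn
from tianshou.highlevel.module.core import init_linear_orthogonal

mlp = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 1))
init_linear_orthogonal(mlp)
assert torch.allclose(mlp[0].bias, torch.zeros(64))  # biases are zeroed
assert torch.allclose(mlp[0].weight.T @ mlp[0].weight, 2 * torch.eye(8), atol=1e-5)  # orthogonal columns, scaled by gain^2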

View File

@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous
from tianshou.utils.net.common import Net
class CriticFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
class CriticFactoryDefault(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
class CriticFactoryContinuous(CriticFactory, ABC):
pass
class CriticFactoryContinuousNet(CriticFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
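
Note that CriticFactoryDefault still raises NotImplementedError for discrete environments; the Atari PPO example side-steps this by deriving the critic from the actor instead (see the experiment builder changes above):

builder.with_critic_factory_use_actor()  # attaches a discrete.Critic head to the actor's preprocessing network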

View File

@ -0,0 +1,58 @@
from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic
@dataclass
class ModuleOpt:
module: torch.nn.Module
optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
actor_critic_module: ActorCritic
optim: torch.optim.Optimizer
@property
def actor(self):
return self.actor_critic_module.actor
@property
def critic(self):
return self.actor_critic_module.critic
class ActorModuleOptFactory:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
class CriticModuleOptFactory:
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)

View File

@ -4,7 +4,7 @@ import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory

View File

@ -6,7 +6,8 @@ import torch
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import ModuleOpt, TDevice
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
@ -66,6 +67,18 @@ class ParamTransformerDrop(ParamTransformer):
del kwargs[k]
class ParamTransformerChangeValue(ParamTransformer):
def __init__(self, key: str):
self.key = key
def transform(self, params: dict[str, Any], data: ParamTransformerData):
params[self.key] = self.change_value(params[self.key], data)
@abstractmethod
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
pass
class ParamTransformerLRScheduler(ParamTransformer):
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
a learning rate scheduler (added) for the data member `optim`.
@ -182,6 +195,14 @@ class ParamTransformerDistributionFunction(ParamTransformer):
kwargs[self.key] = value.create_dist_fn(data.envs)
class ParamTransformerActionScaling(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if value == "default":
return data.envs.get_type().is_continuous()
else:
return value
class GetParamTransformersProtocol(Protocol):
def _get_param_transformers(self) -> list[ParamTransformer]:
pass
@ -218,9 +239,15 @@ class PGParams(Params):
discount_factor: float = 0.99
reward_normalization: bool = False
deterministic_eval: bool = False
action_scaling: bool = True
action_scaling: bool | Literal["default"] = "default"
"""whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.append(ParamTransformerActionScaling("action_scaling"))
return transformers
@dataclass
class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
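
With the new "default" value, action scaling no longer needs to be switched off manually for discrete tasks. A sketch, assuming the remaining parameter fields keep their defaults:

from tianshou.highlevel.params.policy_params import PGParams

params = PGParams(action_scaling="default")
# At policy-creation time, ParamTransformerActionScaling resolves the value from the
# environment type: True for continuous action spaces, False for discrete ones.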

View File

@ -0,0 +1,72 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
@abstractmethod
def create_wrapped_policy(
self,
policy: TPolicyIn,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> TPolicyOut:
pass
class PolicyWrapperFactoryIntrinsicCuriosity(
Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy],
):
def __init__(
self,
feature_net_factory: ModuleFactory,
hidden_sizes: Sequence[int],
lr: float,
lr_scale: float,
reward_scale: float,
forward_loss_weight: float,
):
self.feature_net_factory = feature_net_factory
self.hidden_sizes = hidden_sizes
self.lr = lr
self.lr_scale = lr_scale
self.reward_scale = reward_scale
self.forward_loss_weight = forward_loss_weight
def create_wrapped_policy(
self,
policy: TPolicyIn,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> ICMPolicy:
feature_net = self.feature_net_factory.create_module(envs, device)
action_dim = envs.get_action_shape()
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.module,
feature_dim,
action_dim,
hidden_sizes=self.hidden_sizes,
device=device,
)
icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
return ICMPolicy(
policy=policy,
model=icm_net,
optim=icm_optim,
action_space=envs.get_action_space(),
lr_scale=self.lr_scale,
reward_scale=self.reward_scale,
forward_loss_weight=self.forward_loss_weight,
).to(device)
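
Wiring the ICM wrapper into an experiment builder mirrors the Atari example near the top of this commit (values are illustrative; builder is assumed to be a PPOExperimentBuilder as above, before build() is called):

from examples.atari.atari_network import FeatureNetFactoryDQN
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactoryIntrinsicCuriosity

builder.with_policy_wrapper_factory(
    PolicyWrapperFactoryIntrinsicCuriosity(
        FeatureNetFactoryDQN(),  # feature net built from the Atari DQN trunk
        [512],  # hidden sizes of the ICM heads
        2.5e-4,  # lr
        0.1,  # lr_scale (the example enables ICM only when this is > 0)
        0.01,  # reward_scale
        0.2,  # forward_loss_weight
    ),
)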