Add support for discrete PPO
* Refactored module `module` (split into submodules)
* Basic support for discrete environments
* Implement Atari env. factory
* Implement DQN-based actor factory
* Implement notion of reusing agent preprocessing network for critic
* Add example atari_ppo_hl
Parent: e0e7349b0a
Commit: 6b6d9ea609
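For orientation, here is a condensed usage sketch of the high-level API pieces introduced by this commit, distilled from the full example added below (examples/atari/atari_ppo_hl.py). The hyperparameter values are illustrative, and constructing RLExperimentConfig and PPOParams from only a few keyword arguments assumes their remaining fields have defaults.

# Condensed usage sketch (illustrative values; see examples/atari/atari_ppo_hl.py below).
from examples.atari.atari_network import ActorFactoryAtariDQN
from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import PPOExperimentBuilder, RLExperimentConfig
from tianshou.highlevel.params.policy_params import PPOParams

experiment_config = RLExperimentConfig(seed=42)  # assumes defaults for the other fields
sampling_config = RLSamplingConfig(
    num_epochs=100,
    step_per_epoch=100_000,
    batch_size=256,
    num_train_envs=10,
    num_test_envs=10,
    buffer_size=100_000,
    step_per_collect=1000,
    repeat_per_collect=4,
    # fields added in this commit, here used for frame stacking in the replay buffer
    replay_buffer_stack_num=4,
    replay_buffer_ignore_obs_next=True,
    replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory("PongNoFrameskip-v4", experiment_config.seed, sampling_config, 4)

experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(PPOParams(discount_factor=0.99, lr=2.5e-4))  # assumes defaults elsewhere
    .with_actor_factory(ActorFactoryAtariDQN(512, scale_obs=True))
    .with_critic_factory_use_actor()  # critic reuses the actor's DQN preprocessing network
    .build()
)
experiment.run("atari_ppo_sketch")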
@@ -5,7 +5,11 @@ import numpy as np
 import torch
 from torch import nn

-from tianshou.utils.net.discrete import NoisyLinear
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.actor import ActorFactory
+from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.utils.net.common import BaseActor
+from tianshou.utils.net.discrete import Actor, NoisyLinear


 def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
@@ -220,3 +224,29 @@ class QRDQN(DQN):
         obs, state = super().forward(obs)
         obs = obs.view(-1, self.action_num, self.num_quantiles)
         return obs, state
+
+
+class ActorFactoryAtariDQN(ActorFactory):
+    def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
+        self.hidden_size = hidden_size
+        self.scale_obs = scale_obs
+
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
+        net_cls = scale_obs(DQN) if self.scale_obs else DQN
+        net = net_cls(
+            *envs.get_observation_shape(),
+            envs.get_action_shape(),
+            device=device,
+            features_only=True,
+            output_dim=self.hidden_size,
+            layer_init=layer_init,
+        )
+        return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
+
+
+class FeatureNetFactoryDQN(ModuleFactory):
+    def create_module(self, envs: Environments, device: TDevice) -> Module:
+        dqn = DQN(
+            *envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
+        )
+        return Module(dqn.net, dqn.output_dim)
examples/atari/atari_ppo_hl.py (new file, 117 lines)
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+from collections.abc import Sequence
+
+from jsonargparse import CLI
+
+from examples.atari.atari_network import (
+    ActorFactoryAtariDQN,
+    FeatureNetFactoryDQN,
+)
+from examples.atari.atari_wrapper import AtariEnvFactory
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.experiment import (
+    PPOExperimentBuilder,
+    RLExperimentConfig,
+)
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
+from tianshou.highlevel.params.policy_params import PPOParams
+from tianshou.highlevel.params.policy_wrapper import (
+    PolicyWrapperFactoryIntrinsicCuriosity,
+)
+
+
+def main(
+    experiment_config: RLExperimentConfig,
+    task: str = "PongNoFrameskip-v4",
+    scale_obs: bool = True,
+    buffer_size: int = 100000,
+    lr: float = 2.5e-4,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    step_per_epoch: int = 100000,
+    step_per_collect: int = 1000,
+    repeat_per_collect: int = 4,
+    batch_size: int = 256,
+    hidden_sizes: int | Sequence[int] = 512,
+    training_num: int = 10,
+    test_num: int = 10,
+    rew_norm: bool = False,
+    vf_coef: float = 0.25,
+    ent_coef: float = 0.01,
+    gae_lambda: float = 0.95,
+    lr_decay: bool = True,
+    max_grad_norm: float = 0.5,
+    eps_clip: float = 0.1,
+    dual_clip: float | None = None,
+    value_clip: bool = True,
+    norm_adv: bool = True,
+    recompute_adv: bool = False,
+    frames_stack: int = 4,
+    save_buffer_name: str | None = None,  # TODO add support in high-level API?
+    icm_lr_scale: float = 0.0,
+    icm_reward_scale: float = 0.01,
+    icm_forward_loss_weight: float = 0.2,
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+
+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        repeat_per_collect=repeat_per_collect,
+        replay_buffer_stack_num=frames_stack,
+        replay_buffer_ignore_obs_next=True,
+        replay_buffer_save_only_last_obs=True,
+    )
+
+    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
+
+    builder = (
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ppo_params(
+            PPOParams(
+                discount_factor=gamma,
+                gae_lambda=gae_lambda,
+                reward_normalization=rew_norm,
+                ent_coef=ent_coef,
+                vf_coef=vf_coef,
+                max_grad_norm=max_grad_norm,
+                value_clip=value_clip,
+                advantage_normalization=norm_adv,
+                eps_clip=eps_clip,
+                dual_clip=dual_clip,
+                recompute_advantage=recompute_adv,
+                lr=lr,
+                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
+                if lr_decay
+                else None,
+            ),
+        )
+        .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
+        .with_critic_factory_use_actor()
+    )
+    if icm_lr_scale > 0:
+        builder.with_policy_wrapper_factory(
+            PolicyWrapperFactoryIntrinsicCuriosity(
+                FeatureNetFactoryDQN(),
+                [hidden_sizes],
+                lr,
+                icm_lr_scale,
+                icm_reward_scale,
+                icm_forward_loss_weight,
+            ),
+        )
+    experiment = builder.build()
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    CLI(main)
@@ -9,6 +9,8 @@ import gymnasium as gym
 import numpy as np

 from tianshou.env import ShmemVectorEnv
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory

 try:
     import envpool
@@ -369,3 +371,22 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
     train_envs.seed(seed)
     test_envs.seed(seed)
     return env, train_envs, test_envs
+
+
+class AtariEnvFactory(EnvFactory):
+    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
+        self.task = task
+        self.sampling_config = sampling_config
+        self.seed = seed
+        self.frame_stack = frame_stack
+
+    def create_envs(self, config=None) -> DiscreteEnvironments:
+        env, train_envs, test_envs = make_atari_env(
+            task=self.task,
+            seed=self.seed,
+            training_num=self.sampling_config.num_train_envs,
+            test_num=self.sampling_config.num_test_envs,
+            scale=0,
+            frame_stack=self.frame_stack,
+        )
+        return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
@@ -1,4 +1,3 @@
-import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from typing import Generic, TypeVar
@@ -9,14 +8,16 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
-from tianshou.highlevel.module import (
-    ActorCriticModuleOpt,
+from tianshou.highlevel.module.actor import (
     ActorFactory,
+)
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.module.module_opt import (
+    ActorCriticModuleOpt,
     ActorModuleOptFactory,
-    CriticFactory,
     CriticModuleOptFactory,
     ModuleOpt,
-    TDevice,
 )
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
@@ -27,8 +28,10 @@ from tianshou.highlevel.params.policy_params import (
     SACParams,
     TD3Params,
 )
+from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
+from tianshou.utils.net import continuous, discrete
 from tianshou.utils.net.common import ActorCritic

 CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -38,34 +41,62 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)


 class AgentFactory(ABC):
-    def __init__(self, sampling_config: RLSamplingConfig):
+    def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
         self.sampling_config = sampling_config
+        self.optim_factory = optim_factory
+        self.policy_wrapper_factory: PolicyWrapperFactory | None = None

     def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
         buffer_size = self.sampling_config.buffer_size
         train_envs = envs.train_envs
         if len(train_envs) > 1:
-            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
+            buffer = VectorReplayBuffer(
+                buffer_size,
+                len(train_envs),
+                stack_num=self.sampling_config.replay_buffer_stack_num,
+                save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
+                ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
+            )
         else:
-            buffer = ReplayBuffer(buffer_size)
+            buffer = ReplayBuffer(
+                buffer_size,
+                stack_num=self.sampling_config.replay_buffer_stack_num,
+                save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
+                ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
+            )
         train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
         test_collector = Collector(policy, envs.test_envs)
         if self.sampling_config.start_timesteps > 0:
             train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
         return train_collector, test_collector

+    def set_policy_wrapper_factory(
+        self, policy_wrapper_factory: PolicyWrapperFactory | None,
+    ) -> None:
+        self.policy_wrapper_factory = policy_wrapper_factory
+
     @abstractmethod
-    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         pass

+    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        policy = self._create_policy(envs, device)
+        if self.policy_wrapper_factory is not None:
+            policy = self.policy_wrapper_factory.create_wrapped_policy(
+                policy, envs, self.optim_factory, device,
+            )
+        return policy
+
     @staticmethod
     def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
         def save_best_fn(pol: torch.nn.Module) -> None:
-            state = {
-                CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
-                CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
-            }
-            torch.save(state, os.path.join(log_path, "policy.pth"))
+            pass
+            # TODO: Fix saving in general (code works only for mujoco)
+            # state = {
+            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
+            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
+            # }
+            # torch.save(state, os.path.join(log_path, "policy.pth"))

         return save_best_fn

@@ -160,11 +191,13 @@ class _ActorCriticMixin:
         critic_factory: CriticFactory,
         optim_factory: OptimizerFactory,
         critic_use_action: bool,
+        critic_use_actor_module: bool,
     ):
         self.actor_factory = actor_factory
         self.critic_factory = critic_factory
         self.optim_factory = optim_factory
         self.critic_use_action = critic_use_action
+        self.critic_use_actor_module = critic_use_actor_module

     def create_actor_critic_module_opt(
         self,
@@ -173,7 +206,23 @@ class _ActorCriticMixin:
         lr: float,
     ) -> ActorCriticModuleOpt:
         actor = self.actor_factory.create_module(envs, device)
-        critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
+        if self.critic_use_actor_module:
+            if self.critic_use_action:
+                raise ValueError(
+                    "The options critic_use_actor_module and critic_use_action are mutually exclusive",
+                )
+            if envs.get_type().is_discrete():
+                critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
+            elif envs.get_type().is_continuous():
+                critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
+            else:
+                raise ValueError
+        else:
+            critic = self.critic_factory.create_module(
+                envs,
+                device,
+                use_action=self.critic_use_action,
+            )
         actor_critic = ActorCritic(actor, critic)
         optim = self.optim_factory.create_optimizer(actor_critic, lr)
         return ActorCriticModuleOpt(actor_critic, optim)
@@ -237,14 +286,16 @@ class ActorCriticAgentFactory(
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
         policy_class: type[TPolicy],
+        critic_use_actor_module: bool,
     ):
-        super().__init__(sampling_config)
+        super().__init__(sampling_config, optim_factory=optimizer_factory)
         _ActorCriticMixin.__init__(
             self,
             actor_factory,
             critic_factory,
             optimizer_factory,
             critic_use_action=False,
+            critic_use_actor_module=critic_use_actor_module,
         )
         self.params = params
         self.policy_class = policy_class
@@ -269,7 +320,7 @@ class ActorCriticAgentFactory(
         kwargs["action_space"] = envs.get_action_space()
         return kwargs

-    def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
         return self.policy_class(**self._create_kwargs(envs, device))


@@ -281,6 +332,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
+        critic_use_actor_module: bool,
     ):
         super().__init__(
             params,
@@ -289,6 +341,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
             critic_factory,
             optimizer_factory,
             A2CPolicy,
+            critic_use_actor_module,
         )

     def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -303,6 +356,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
+        critic_use_actor_module: bool,
     ):
         super().__init__(
             params,
@@ -311,6 +365,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
             critic_factory,
             optimizer_factory,
             PPOPolicy,
+            critic_use_actor_module,
         )

     def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -327,7 +382,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         critic2_factory: CriticFactory,
         optim_factory: OptimizerFactory,
     ):
-        super().__init__(sampling_config)
+        super().__init__(sampling_config, optim_factory)
         _ActorAndDualCriticsMixin.__init__(
             self,
             actor_factory,
@@ -339,7 +394,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         self.params = params
         self.optim_factory = optim_factory

-    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
         critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
         critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
@@ -376,7 +431,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         critic2_factory: CriticFactory,
         optim_factory: OptimizerFactory,
     ):
-        super().__init__(sampling_config)
+        super().__init__(sampling_config, optim_factory)
         _ActorAndDualCriticsMixin.__init__(
             self,
             actor_factory,
@@ -388,7 +443,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         self.params = params
         self.optim_factory = optim_factory

-    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
         critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
         critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
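As background for the critic_use_actor_module option wired into _ActorCriticMixin above: in low-level tianshou terms, reusing the agent's preprocessing network for the critic is the familiar pattern of sharing one feature Net between a discrete Actor head and a Critic head, so both losses train the same feature extractor. A minimal sketch of that idea (shapes and learning rate are placeholders, not taken from this commit):

import torch

from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

state_shape, action_shape = (4,), 2  # placeholder shapes for some discrete task
device = "cpu"

# One shared feature extractor ...
feature_net = Net(state_shape, hidden_sizes=[64, 64], device=device)
# ... feeds both the actor head and the critic head, so gradients from the
# policy loss and the value loss update the same preprocessing parameters.
actor = Actor(feature_net, action_shape, device=device).to(device)
critic = Critic(feature_net, device=device).to(device)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=2.5e-4)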
@@ -17,3 +17,7 @@ class RLSamplingConfig:
     update_per_step: int = 1
     start_timesteps: int = 0
     start_timesteps_random: bool = False
+    # TODO can we set the parameters below more intelligently? Perhaps based on env. representation?
+    replay_buffer_ignore_obs_next: bool = False
+    replay_buffer_save_only_last_obs: bool = False
+    replay_buffer_stack_num: int = 1
@@ -97,6 +97,22 @@ class ContinuousEnvironments(Environments):
         return EnvType.CONTINUOUS


+class DiscreteEnvironments(Environments):
+    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+        super().__init__(env, train_envs, test_envs)
+        self.observation_shape = env.observation_space.shape or env.observation_space.n
+        self.action_shape = env.action_space.shape or env.action_space.n
+
+    def get_action_shape(self) -> TShape:
+        return self.action_shape
+
+    def get_observation_shape(self) -> TShape:
+        return self.observation_shape
+
+    def get_type(self) -> EnvType:
+        return EnvType.DISCRETE
+
+
 class EnvFactory(ABC):
     @abstractmethod
     def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
@@ -18,13 +18,12 @@ from tianshou.highlevel.agent import (
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
 from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
-from tianshou.highlevel.module import (
+from tianshou.highlevel.module.actor import (
     ActorFactory,
     ActorFactoryDefault,
     ContinuousActorType,
-    CriticFactory,
-    CriticFactoryDefault,
 )
+from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
@@ -32,6 +31,7 @@ from tianshou.highlevel.params.policy_params import (
     SACParams,
     TD3Params,
 )
+from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import PersistableConfigProtocol
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
@@ -154,6 +154,7 @@ class RLExperimentBuilder:
         self._logger_factory: LoggerFactory | None = None
         self._optim_factory: OptimizerFactory | None = None
         self._env_config: PersistableConfigProtocol | None = None
+        self._policy_wrapper_factory: PolicyWrapperFactory | None = None

     def with_env_config(self, config: PersistableConfigProtocol) -> Self:
         self._env_config = config
@@ -163,6 +164,10 @@ class RLExperimentBuilder:
         self._logger_factory = logger_factory
         return self

+    def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
+        self._policy_wrapper_factory = policy_wrapper_factory
+        return self
+
     def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
         self._optim_factory = optim_factory
         return self
@@ -194,10 +199,13 @@ class RLExperimentBuilder:
         return self._optim_factory

     def build(self) -> RLExperiment:
+        agent_factory = self._create_agent_factory()
+        if self._policy_wrapper_factory:
+            agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
         return RLExperiment(
             self._config,
             self._env_factory,
-            self._create_agent_factory(),
+            agent_factory,
             self._logger_factory,
             env_config=self._env_config,
         )
@@ -287,6 +295,7 @@ class _BuilderMixinCriticsFactory:
 class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
     def __init__(self):
         super().__init__(1)
+        self._critic_use_actor_module = False

     def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
         self: TBuilder | "_BuilderMixinSingleCriticFactory"
@@ -301,6 +310,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
         self._with_critic_factory_default(0, hidden_sizes)
         return self

+    def with_critic_factory_use_actor(self) -> Self:
+        """Makes the critic use the same network as the actor."""
+        self._critic_use_actor_module = True
+        return self
+

 class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def __init__(self):
@@ -378,6 +392,7 @@ class A2CExperimentBuilder(
             self._get_actor_factory(),
             self._get_critic_factory(0),
             self._get_optim_factory(),
+            self._critic_use_actor_module,
         )


@@ -411,6 +426,7 @@ class PPOExperimentBuilder(
             self._get_actor_factory(),
             self._get_critic_factory(0),
             self._get_optim_factory(),
+            self._critic_use_actor_module,
         )

tianshou/highlevel/module/__init__.py (new empty file)
@@ -1,29 +1,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from dataclasses import dataclass
-from typing import TypeAlias

-import numpy as np
 import torch
 from torch import nn

 from tianshou.highlevel.env import Environments, EnvType
-from tianshou.highlevel.optim import OptimizerFactory
-from tianshou.utils.net import continuous
-from tianshou.utils.net.common import ActorCritic, Net
-
-TDevice: TypeAlias = str | int | torch.device
-
-
-def init_linear_orthogonal(module: torch.nn.Module):
-    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
-
-    :param module: the module whose submodules are to be processed
-    """
-    for m in module.modules():
-        if isinstance(m, torch.nn.Linear):
-            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
-            torch.nn.init.zeros_(m.bias)
+from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.utils.net import continuous, discrete
+from tianshou.utils.net.common import BaseActor, Net


 class ContinuousActorType:
@@ -33,7 +17,7 @@ class ContinuousActorType:

 class ActorFactory(ABC):
     @abstractmethod
-    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         pass

     @staticmethod
@@ -70,18 +54,18 @@ class ActorFactoryDefault(ActorFactory):
         self.continuous_conditioned_sigma = continuous_conditioned_sigma
         self.hidden_sizes = hidden_sizes

-    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         env_type = envs.get_type()
         if env_type == EnvType.CONTINUOUS:
             match self.continuous_actor_type:
                 case ContinuousActorType.GAUSSIAN:
-                    factory = ActorFactoryContinuousGaussian(
+                    factory = ActorFactoryContinuousGaussianNet(
                         self.hidden_sizes,
                         unbounded=self.continuous_unbounded,
                         conditioned_sigma=self.continuous_conditioned_sigma,
                     )
                 case ContinuousActorType.DETERMINISTIC:
-                    factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
+                    factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
                 case _:
                     raise ValueError(self.continuous_actor_type)
             return factory.create_module(envs, device)
@@ -95,11 +79,11 @@ class ActorFactoryContinuous(ActorFactory, ABC):
     """Serves as a type bound for actor factories that are suitable for continuous action spaces."""


-class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
+class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
     def __init__(self, hidden_sizes: Sequence[int]):
         self.hidden_sizes = hidden_sizes

-    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
             envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
@@ -113,13 +97,13 @@ class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
         ).to(device)


-class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
+class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
     def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma

-    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
             envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
@@ -142,97 +126,19 @@ class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
         return actor


-class CriticFactory(ABC):
-    @abstractmethod
-    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
-        pass
-
-
-class CriticFactoryDefault(CriticFactory):
-    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
-
-    DEFAULT_HIDDEN_SIZES = (64, 64)
-
-    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
-        self.hidden_sizes = hidden_sizes
-
-    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
-        env_type = envs.get_type()
-        if env_type == EnvType.CONTINUOUS:
-            factory = CriticFactoryContinuousNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
-        elif env_type == EnvType.DISCRETE:
-            raise NotImplementedError
-        else:
-            raise ValueError(f"{env_type} not supported")
-
-
-class CriticFactoryContinuous(CriticFactory, ABC):
-    pass
-
-
-class CriticFactoryContinuousNet(CriticFactoryContinuous):
+class ActorFactoryDiscreteNet(ActorFactory):
     def __init__(self, hidden_sizes: Sequence[int]):
         self.hidden_sizes = hidden_sizes

-    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
-        action_shape = envs.get_action_shape() if use_action else 0
-        net_c = Net(
+    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
+        net_a = Net(
             envs.get_observation_shape(),
-            action_shape=action_shape,
             hidden_sizes=self.hidden_sizes,
-            concat=use_action,
-            activation=nn.Tanh,
             device=device,
         )
-        critic = continuous.Critic(net_c, device=device).to(device)
-        init_linear_orthogonal(critic)
-        return critic
-
-
-@dataclass
-class ModuleOpt:
-    module: torch.nn.Module
-    optim: torch.optim.Optimizer
-
-
-@dataclass
-class ActorCriticModuleOpt:
-    actor_critic_module: ActorCritic
-    optim: torch.optim.Optimizer
-
-    @property
-    def actor(self):
-        return self.actor_critic_module.actor
-
-    @property
-    def critic(self):
-        return self.actor_critic_module.critic
-
-
-class ActorModuleOptFactory:
-    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
-        self.actor_factory = actor_factory
-        self.optim_factory = optim_factory
-
-    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-        actor = self.actor_factory.create_module(envs, device)
-        opt = self.optim_factory.create_optimizer(actor, lr)
-        return ModuleOpt(actor, opt)
-
-
-class CriticModuleOptFactory:
-    def __init__(
-        self,
-        critic_factory: CriticFactory,
-        optim_factory: OptimizerFactory,
-        use_action: bool,
-    ):
-        self.critic_factory = critic_factory
-        self.optim_factory = optim_factory
-        self.use_action = use_action
-
-    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-        critic = self.critic_factory.create_module(envs, device, self.use_action)
-        opt = self.optim_factory.create_optimizer(critic, lr)
-        return ModuleOpt(critic, opt)
+        return discrete.Actor(
+            net_a,
+            envs.get_action_shape(),
+            hidden_sizes=(),
+            device=device,
+        ).to(device)
tianshou/highlevel/module/core.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TypeAlias
+
+import numpy as np
+import torch
+
+from tianshou.highlevel.env import Environments
+from tianshou.utils.net.common import Net
+
+TDevice: TypeAlias = str | int | torch.device
+
+
+def init_linear_orthogonal(module: torch.nn.Module):
+    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
+
+    :param module: the module whose submodules are to be processed
+    """
+    for m in module.modules():
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+            torch.nn.init.zeros_(m.bias)
+
+
+@dataclass
+class Module:
+    module: torch.nn.Module
+    output_dim: int
+
+
+class ModuleFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> Module:
+        pass
+
+
+class ModuleFactoryNet(ModuleFactory):
+    def __init__(self, hidden_sizes: int | Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice) -> Module:
+        module = Net(envs.get_observation_shape())
+        return Module(module, module.output_dim)
tianshou/highlevel/module/critic.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+
+from torch import nn
+
+from tianshou.highlevel.env import Environments, EnvType
+from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.utils.net import continuous
+from tianshou.utils.net.common import Net
+
+
+class CriticFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        pass
+
+
+class CriticFactoryDefault(CriticFactory):
+    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        env_type = envs.get_type()
+        if env_type == EnvType.CONTINUOUS:
+            factory = CriticFactoryContinuousNet(self.hidden_sizes)
+            return factory.create_module(envs, device, use_action)
+        elif env_type == EnvType.DISCRETE:
+            raise NotImplementedError
+        else:
+            raise ValueError(f"{env_type} not supported")
+
+
+class CriticFactoryContinuous(CriticFactory, ABC):
+    pass
+
+
+class CriticFactoryContinuousNet(CriticFactoryContinuous):
+    def __init__(self, hidden_sizes: Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        action_shape = envs.get_action_shape() if use_action else 0
+        net_c = Net(
+            envs.get_observation_shape(),
+            action_shape=action_shape,
+            hidden_sizes=self.hidden_sizes,
+            concat=use_action,
+            activation=nn.Tanh,
+            device=device,
+        )
+        critic = continuous.Critic(net_c, device=device).to(device)
+        init_linear_orthogonal(critic)
+        return critic
tianshou/highlevel/module/module_opt.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+from dataclasses import dataclass
+
+import torch
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.actor import ActorFactory
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.optim import OptimizerFactory
+from tianshou.utils.net.common import ActorCritic
+
+
+@dataclass
+class ModuleOpt:
+    module: torch.nn.Module
+    optim: torch.optim.Optimizer
+
+
+@dataclass
+class ActorCriticModuleOpt:
+    actor_critic_module: ActorCritic
+    optim: torch.optim.Optimizer
+
+    @property
+    def actor(self):
+        return self.actor_critic_module.actor
+
+    @property
+    def critic(self):
+        return self.actor_critic_module.critic
+
+
+class ActorModuleOptFactory:
+    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
+        self.actor_factory = actor_factory
+        self.optim_factory = optim_factory
+
+    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
+        actor = self.actor_factory.create_module(envs, device)
+        opt = self.optim_factory.create_optimizer(actor, lr)
+        return ModuleOpt(actor, opt)
+
+
+class CriticModuleOptFactory:
+    def __init__(
+        self,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+        use_action: bool,
+    ):
+        self.critic_factory = critic_factory
+        self.optim_factory = optim_factory
+        self.use_action = use_action
+
+    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
+        critic = self.critic_factory.create_module(envs, device, self.use_action)
+        opt = self.optim_factory.create_optimizer(critic, lr)
+        return ModuleOpt(critic, opt)
@@ -4,7 +4,7 @@ import numpy as np
 import torch

 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module import TDevice
+from tianshou.highlevel.module.core import TDevice
 from tianshou.highlevel.optim import OptimizerFactory

@@ -6,7 +6,8 @@ import torch

 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module import ModuleOpt, TDevice
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.alpha import AutoAlphaFactory
 from tianshou.highlevel.params.dist_fn import (
@@ -66,6 +67,18 @@ class ParamTransformerDrop(ParamTransformer):
             del kwargs[k]


+class ParamTransformerChangeValue(ParamTransformer):
+    def __init__(self, key: str):
+        self.key = key
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData):
+        params[self.key] = self.change_value(params[self.key], data)
+
+    @abstractmethod
+    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
+        pass
+
+
 class ParamTransformerLRScheduler(ParamTransformer):
     """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
     a learning rate scheduler (added) for the data member `optim`.
@@ -182,6 +195,14 @@ class ParamTransformerDistributionFunction(ParamTransformer):
             kwargs[self.key] = value.create_dist_fn(data.envs)


+class ParamTransformerActionScaling(ParamTransformerChangeValue):
+    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
+        if value == "default":
+            return data.envs.get_type().is_continuous()
+        else:
+            return value
+
+
 class GetParamTransformersProtocol(Protocol):
     def _get_param_transformers(self) -> list[ParamTransformer]:
         pass
@@ -218,9 +239,15 @@ class PGParams(Params):
     discount_factor: float = 0.99
     reward_normalization: bool = False
     deterministic_eval: bool = False
-    action_scaling: bool = True
+    action_scaling: bool | Literal["default"] = "default"
+    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
     action_bound_method: Literal["clip", "tanh"] | None = "clip"

+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerActionScaling("action_scaling"))
+        return transformers
+

 @dataclass
 class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
tianshou/highlevel/params/policy_wrapper.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Generic, TypeVar
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.optim import OptimizerFactory
+from tianshou.policy import BasePolicy, ICMPolicy
+from tianshou.utils.net.discrete import IntrinsicCuriosityModule
+
+TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
+TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
+
+
+class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
+    @abstractmethod
+    def create_wrapped_policy(
+        self,
+        policy: TPolicyIn,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
+    ) -> TPolicyOut:
+        pass
+
+
+class PolicyWrapperFactoryIntrinsicCuriosity(
+    Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy],
+):
+    def __init__(
+        self,
+        feature_net_factory: ModuleFactory,
+        hidden_sizes: Sequence[int],
+        lr: float,
+        lr_scale: float,
+        reward_scale: float,
+        forward_loss_weight,
+    ):
+        self.feature_net_factory = feature_net_factory
+        self.hidden_sizes = hidden_sizes
+        self.lr = lr
+        self.lr_scale = lr_scale
+        self.reward_scale = reward_scale
+        self.forward_loss_weight = forward_loss_weight
+
+    def create_wrapped_policy(
+        self,
+        policy: TPolicyIn,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
+    ) -> ICMPolicy:
+        feature_net = self.feature_net_factory.create_module(envs, device)
+        action_dim = envs.get_action_shape()
+        feature_dim = feature_net.output_dim
+        icm_net = IntrinsicCuriosityModule(
+            feature_net.module,
+            feature_dim,
+            action_dim,
+            hidden_sizes=self.hidden_sizes,
+            device=device,
+        )
+        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
+        return ICMPolicy(
+            policy=policy,
+            model=icm_net,
+            optim=icm_optim,
+            action_space=envs.get_action_space(),
+            lr_scale=self.lr_scale,
+            reward_scale=self.reward_scale,
+            forward_loss_weight=self.forward_loss_weight,
+        ).to(device)