Add support for discrete PPO
* Refactor module `module` (split into submodules)
* Add basic support for discrete environments
* Implement Atari environment factory
* Implement DQN-based actor factory
* Implement the notion of reusing the agent's preprocessing network for the critic
* Add example atari_ppo_hl
This commit is contained in: parent e0e7349b0a, commit 6b6d9ea609
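For orientation, here is a condensed usage sketch of the new high-level API for discrete PPO, assembled from the example added below (examples/atari/atari_ppo_hl.py). It assumes that RLExperimentConfig and PPOParams can be instantiated with their default values; treat it as an illustration of the intended call flow rather than a tested snippet.

    from examples.atari.atari_network import ActorFactoryAtariDQN
    from examples.atari.atari_wrapper import AtariEnvFactory
    from tianshou.highlevel.config import RLSamplingConfig
    from tianshou.highlevel.experiment import PPOExperimentBuilder, RLExperimentConfig
    from tianshou.highlevel.params.policy_params import PPOParams

    experiment_config = RLExperimentConfig()  # assumed to have usable defaults (seed etc.)
    sampling_config = RLSamplingConfig(
        num_epochs=100,
        step_per_epoch=100000,
        batch_size=256,
        num_train_envs=10,
        num_test_envs=10,
        buffer_size=100000,
        step_per_collect=1000,
        repeat_per_collect=4,
        replay_buffer_stack_num=4,  # frame stacking is handled by the replay buffer
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )
    env_factory = AtariEnvFactory("PongNoFrameskip-v4", experiment_config.seed, sampling_config, 4)
    experiment = (
        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_ppo_params(PPOParams(discount_factor=0.99, gae_lambda=0.95, lr=2.5e-4))  # remaining params left at assumed defaults
        .with_actor_factory(ActorFactoryAtariDQN(512, scale_obs=False))  # DQN-based actor factory from this commit
        .with_critic_factory_use_actor()  # critic reuses the actor's preprocessing network
        .build()
    )
    experiment.run("atari-ppo-example")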
examples/atari/atari_network.py
@@ -5,7 +5,11 @@ import numpy as np
import torch
from torch import nn

from tianshou.utils.net.discrete import NoisyLinear
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
from tianshou.utils.net.common import BaseActor
from tianshou.utils.net.discrete import Actor, NoisyLinear


def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
@@ -220,3 +224,29 @@ class QRDQN(DQN):
        obs, state = super().forward(obs)
        obs = obs.view(-1, self.action_num, self.num_quantiles)
        return obs, state


class ActorFactoryAtariDQN(ActorFactory):
    def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
        self.hidden_size = hidden_size
        self.scale_obs = scale_obs

    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
        net_cls = scale_obs(DQN) if self.scale_obs else DQN
        net = net_cls(
            *envs.get_observation_shape(),
            envs.get_action_shape(),
            device=device,
            features_only=True,
            output_dim=self.hidden_size,
            layer_init=layer_init,
        )
        return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)


class FeatureNetFactoryDQN(ModuleFactory):
    def create_module(self, envs: Environments, device: TDevice) -> Module:
        dqn = DQN(
            *envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
        )
        return Module(dqn.net, dqn.output_dim)
examples/atari/atari_ppo_hl.py (new file, 117 lines)
@@ -0,0 +1,117 @@
#!/usr/bin/env python3

import datetime
import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.atari.atari_network import (
    ActorFactoryAtariDQN,
    FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    PPOExperimentBuilder,
    RLExperimentConfig,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.highlevel.params.policy_wrapper import (
    PolicyWrapperFactoryIntrinsicCuriosity,
)


def main(
    experiment_config: RLExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: bool = True,
    buffer_size: int = 100000,
    lr: float = 2.5e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 1000,
    repeat_per_collect: int = 4,
    batch_size: int = 256,
    hidden_sizes: int | Sequence[int] = 512,
    training_num: int = 10,
    test_num: int = 10,
    rew_norm: bool = False,
    vf_coef: float = 0.25,
    ent_coef: float = 0.01,
    gae_lambda: float = 0.95,
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
    eps_clip: float = 0.1,
    dual_clip: float | None = None,
    value_clip: bool = True,
    norm_adv: bool = True,
    recompute_adv: bool = False,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO add support in high-level API?
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)

    builder = (
        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_ppo_params(
            PPOParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
                reward_normalization=rew_norm,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                value_clip=value_clip,
                advantage_normalization=norm_adv,
                eps_clip=eps_clip,
                dual_clip=dual_clip,
                recompute_advantage=recompute_adv,
                lr=lr,
                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                if lr_decay
                else None,
            ),
        )
        .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
        .with_critic_factory_use_actor()
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
                FeatureNetFactoryDQN(),
                [hidden_sizes],
                lr,
                icm_lr_scale,
                icm_reward_scale,
                icm_forward_loss_weight,
            ),
        )
    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
examples/atari/atari_wrapper.py
@@ -9,6 +9,8 @@ import gymnasium as gym
import numpy as np

from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory

try:
    import envpool
@@ -369,3 +371,22 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
    train_envs.seed(seed)
    test_envs.seed(seed)
    return env, train_envs, test_envs


class AtariEnvFactory(EnvFactory):
    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
        self.task = task
        self.sampling_config = sampling_config
        self.seed = seed
        self.frame_stack = frame_stack

    def create_envs(self, config=None) -> DiscreteEnvironments:
        env, train_envs, test_envs = make_atari_env(
            task=self.task,
            seed=self.seed,
            training_num=self.sampling_config.num_train_envs,
            test_num=self.sampling_config.num_test_envs,
            scale=0,
            frame_stack=self.frame_stack,
        )
        return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
tianshou/highlevel/agent.py
@@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Generic, TypeVar
@@ -9,14 +8,16 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import (
    ActorCriticModuleOpt,
from tianshou.highlevel.module.actor import (
    ActorFactory,
)
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.module.module_opt import (
    ActorCriticModuleOpt,
    ActorModuleOptFactory,
    CriticFactory,
    CriticModuleOptFactory,
    ModuleOpt,
    TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
@@ -27,8 +28,10 @@ from tianshou.highlevel.params.policy_params import (
    SACParams,
    TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -38,34 +41,62 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)


class AgentFactory(ABC):
    def __init__(self, sampling_config: RLSamplingConfig):
    def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
        self.sampling_config = sampling_config
        self.optim_factory = optim_factory
        self.policy_wrapper_factory: PolicyWrapperFactory | None = None

    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
            buffer = VectorReplayBuffer(
                buffer_size,
                len(train_envs),
                stack_num=self.sampling_config.replay_buffer_stack_num,
                save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
                ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
            )
        else:
            buffer = ReplayBuffer(buffer_size)
            buffer = ReplayBuffer(
                buffer_size,
                stack_num=self.sampling_config.replay_buffer_stack_num,
                save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs,
                ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
            )
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        if self.sampling_config.start_timesteps > 0:
            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
        return train_collector, test_collector

    def set_policy_wrapper_factory(
        self, policy_wrapper_factory: PolicyWrapperFactory | None,
    ) -> None:
        self.policy_wrapper_factory = policy_wrapper_factory

    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        policy = self._create_policy(envs, device)
        if self.policy_wrapper_factory is not None:
            policy = self.policy_wrapper_factory.create_wrapped_policy(
                policy, envs, self.optim_factory, device,
            )
        return policy

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module) -> None:
            state = {
                CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
                CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            }
            torch.save(state, os.path.join(log_path, "policy.pth"))
            pass
            # TODO: Fix saving in general (code works only for mujoco)
            # state = {
            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            # }
            # torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

@@ -160,11 +191,13 @@ class _ActorCriticMixin:
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
        critic_use_actor_module: bool,
    ):
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.critic_use_action = critic_use_action
        self.critic_use_actor_module = critic_use_actor_module

    def create_actor_critic_module_opt(
        self,
@@ -173,7 +206,23 @@ class _ActorCriticMixin:
        lr: float,
    ) -> ActorCriticModuleOpt:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
        if self.critic_use_actor_module:
            if self.critic_use_action:
                raise ValueError(
                    "The options critic_use_actor_module and critic_use_action are mutually exclusive",
                )
            if envs.get_type().is_discrete():
                critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
            elif envs.get_type().is_continuous():
                critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
            else:
                raise ValueError
        else:
            critic = self.critic_factory.create_module(
                envs,
                device,
                use_action=self.critic_use_action,
            )
        actor_critic = ActorCritic(actor, critic)
        optim = self.optim_factory.create_optimizer(actor_critic, lr)
        return ActorCriticModuleOpt(actor_critic, optim)
@@ -237,14 +286,16 @@ class ActorCriticAgentFactory(
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        policy_class: type[TPolicy],
        critic_use_actor_module: bool,
    ):
        super().__init__(sampling_config)
        super().__init__(sampling_config, optim_factory=optimizer_factory)
        _ActorCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optimizer_factory,
            critic_use_action=False,
            critic_use_actor_module=critic_use_actor_module,
        )
        self.params = params
        self.policy_class = policy_class
@@ -269,7 +320,7 @@ class ActorCriticAgentFactory(
        kwargs["action_space"] = envs.get_action_space()
        return kwargs

    def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
        return self.policy_class(**self._create_kwargs(envs, device))


@@ -281,6 +332,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params,
@@ -289,6 +341,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
            critic_factory,
            optimizer_factory,
            A2CPolicy,
            critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -303,6 +356,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params,
@@ -311,6 +365,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
            critic_factory,
            optimizer_factory,
            PPOPolicy,
            critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -327,7 +382,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        super().__init__(sampling_config, optim_factory)
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
@@ -339,7 +394,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
@@ -376,7 +431,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        super().__init__(sampling_config, optim_factory)
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
@@ -388,7 +443,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
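To make the "reuse the agent's preprocessing network for the critic" item concrete: when critic_use_actor_module is set, _ActorCriticMixin above builds the critic head on top of actor.get_preprocess_net() instead of asking the critic factory for a fresh network. A minimal standalone sketch of that wiring for the discrete case (the DQN feature net and the observation/action shapes are illustrative placeholders):

    import torch

    from examples.atari.atari_network import DQN
    from tianshou.utils.net import discrete

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Shared feature extractor for 4 stacked 84x84 Atari frames and 6 discrete actions (placeholder shapes).
    feature_net = DQN(4, 84, 84, 6, device=device, features_only=True, output_dim=512)
    actor = discrete.Actor(feature_net, 6, device=device, softmax_output=False).to(device)
    # The critic head is attached to the actor's preprocessing network, so actor and critic share it.
    critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)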
tianshou/highlevel/config.py
@@ -17,3 +17,7 @@ class RLSamplingConfig:
    update_per_step: int = 1
    start_timesteps: int = 0
    start_timesteps_random: bool = False
    # TODO can we set the parameters below more intelligently? Perhaps based on env. representation?
    replay_buffer_ignore_obs_next: bool = False
    replay_buffer_save_only_last_obs: bool = False
    replay_buffer_stack_num: int = 1
tianshou/highlevel/env.py
@@ -97,6 +97,22 @@ class ContinuousEnvironments(Environments):
        return EnvType.CONTINUOUS


class DiscreteEnvironments(Environments):
    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        self.observation_shape = env.observation_space.shape or env.observation_space.n
        self.action_shape = env.action_space.shape or env.action_space.n

    def get_action_shape(self) -> TShape:
        return self.action_shape

    def get_observation_shape(self) -> TShape:
        return self.observation_shape

    def get_type(self) -> EnvType:
        return EnvType.DISCRETE


class EnvFactory(ABC):
    @abstractmethod
    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
tianshou/highlevel/experiment.py
@@ -18,13 +18,12 @@ from tianshou.highlevel.agent import (
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import (
from tianshou.highlevel.module.actor import (
    ActorFactory,
    ActorFactoryDefault,
    ContinuousActorType,
    CriticFactory,
    CriticFactoryDefault,
)
from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
    A2CParams,
@@ -32,6 +31,7 @@ from tianshou.highlevel.params.policy_params import (
    SACParams,
    TD3Params,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
@@ -154,6 +154,7 @@ class RLExperimentBuilder:
        self._logger_factory: LoggerFactory | None = None
        self._optim_factory: OptimizerFactory | None = None
        self._env_config: PersistableConfigProtocol | None = None
        self._policy_wrapper_factory: PolicyWrapperFactory | None = None

    def with_env_config(self, config: PersistableConfigProtocol) -> Self:
        self._env_config = config
@@ -163,6 +164,10 @@ class RLExperimentBuilder:
        self._logger_factory = logger_factory
        return self

    def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
        self._policy_wrapper_factory = policy_wrapper_factory
        return self

    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
        self._optim_factory = optim_factory
        return self
@@ -194,10 +199,13 @@ class RLExperimentBuilder:
        return self._optim_factory

    def build(self) -> RLExperiment:
        agent_factory = self._create_agent_factory()
        if self._policy_wrapper_factory:
            agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
        return RLExperiment(
            self._config,
            self._env_factory,
            self._create_agent_factory(),
            agent_factory,
            self._logger_factory,
            env_config=self._env_config,
        )
@@ -287,6 +295,7 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
    def __init__(self):
        super().__init__(1)
        self._critic_use_actor_module = False

    def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
        self: TBuilder | "_BuilderMixinSingleCriticFactory"
@@ -301,6 +310,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
        self._with_critic_factory_default(0, hidden_sizes)
        return self

    def with_critic_factory_use_actor(self) -> Self:
        """Makes the critic use the same network as the actor."""
        self._critic_use_actor_module = True
        return self


class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
    def __init__(self):
@@ -378,6 +392,7 @@ class A2CExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._critic_use_actor_module,
        )


@@ -411,6 +426,7 @@ class PPOExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._critic_use_actor_module,
        )

tianshou/highlevel/module/__init__.py (new file, 0 lines)
tianshou/highlevel/module/actor.py
@@ -1,29 +1,13 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

import numpy as np
import torch
from torch import nn

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import ActorCritic, Net

TDevice: TypeAlias = str | int | torch.device


def init_linear_orthogonal(module: torch.nn.Module):
    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, Net


class ContinuousActorType:
@@ -33,7 +17,7 @@ class ContinuousActorType:

class ActorFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
        pass

    @staticmethod
@@ -70,18 +54,18 @@ class ActorFactoryDefault(ActorFactory):
        self.continuous_conditioned_sigma = continuous_conditioned_sigma
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            match self.continuous_actor_type:
                case ContinuousActorType.GAUSSIAN:
                    factory = ActorFactoryContinuousGaussian(
                    factory = ActorFactoryContinuousGaussianNet(
                        self.hidden_sizes,
                        unbounded=self.continuous_unbounded,
                        conditioned_sigma=self.continuous_conditioned_sigma,
                    )
                case ContinuousActorType.DETERMINISTIC:
                    factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
                    factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
                case _:
                    raise ValueError(self.continuous_actor_type)
            return factory.create_module(envs, device)
@@ -95,11 +79,11 @@ class ActorFactoryContinuous(ActorFactory, ABC):
    """Serves as a type bound for actor factories that are suitable for continuous action spaces."""


class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
        net_a = Net(
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
@@ -113,13 +97,13 @@ class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
        ).to(device)


class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
        self.hidden_sizes = hidden_sizes
        self.unbounded = unbounded
        self.conditioned_sigma = conditioned_sigma

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
        net_a = Net(
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
@@ -142,97 +126,19 @@ class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
        return actor


class CriticFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        pass


class CriticFactoryDefault(CriticFactory):
    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            factory = CriticFactoryContinuousNet(self.hidden_sizes)
            return factory.create_module(envs, device, use_action)
        elif env_type == EnvType.DISCRETE:
            raise NotImplementedError
        else:
            raise ValueError(f"{env_type} not supported")


class CriticFactoryContinuous(CriticFactory, ABC):
    pass


class CriticFactoryContinuousNet(CriticFactoryContinuous):
class ActorFactoryDiscreteNet(ActorFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
        net_a = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        critic = continuous.Critic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic


@dataclass
class ModuleOpt:
    module: torch.nn.Module
    optim: torch.optim.Optimizer


@dataclass
class ActorCriticModuleOpt:
    actor_critic_module: ActorCritic
    optim: torch.optim.Optimizer

    @property
    def actor(self):
        return self.actor_critic_module.actor

    @property
    def critic(self):
        return self.actor_critic_module.critic


class ActorModuleOptFactory:
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        actor = self.actor_factory.create_module(envs, device)
        opt = self.optim_factory.create_optimizer(actor, lr)
        return ModuleOpt(actor, opt)


class CriticModuleOptFactory:
    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        critic = self.critic_factory.create_module(envs, device, self.use_action)
        opt = self.optim_factory.create_optimizer(critic, lr)
        return ModuleOpt(critic, opt)
        return discrete.Actor(
            net_a,
            envs.get_action_shape(),
            hidden_sizes=(),
            device=device,
        ).to(device)
tianshou/highlevel/module/core.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net

TDevice: TypeAlias = str | int | torch.device


def init_linear_orthogonal(module: torch.nn.Module):
    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)


@dataclass
class Module:
    module: torch.nn.Module
    output_dim: int


class ModuleFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> Module:
        pass


class ModuleFactoryNet(ModuleFactory):
    def __init__(self, hidden_sizes: int | Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> Module:
        module = Net(envs.get_observation_shape())
        return Module(module, module.output_dim)
tianshou/highlevel/module/critic.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence

from torch import nn

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous
from tianshou.utils.net.common import Net


class CriticFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        pass


class CriticFactoryDefault(CriticFactory):
    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            factory = CriticFactoryContinuousNet(self.hidden_sizes)
            return factory.create_module(envs, device, use_action)
        elif env_type == EnvType.DISCRETE:
            raise NotImplementedError
        else:
            raise ValueError(f"{env_type} not supported")


class CriticFactoryContinuous(CriticFactory, ABC):
    pass


class CriticFactoryContinuousNet(CriticFactoryContinuous):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        critic = continuous.Critic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic
tianshou/highlevel/module/module_opt.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from dataclasses import dataclass

import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic


@dataclass
class ModuleOpt:
    module: torch.nn.Module
    optim: torch.optim.Optimizer


@dataclass
class ActorCriticModuleOpt:
    actor_critic_module: ActorCritic
    optim: torch.optim.Optimizer

    @property
    def actor(self):
        return self.actor_critic_module.actor

    @property
    def critic(self):
        return self.actor_critic_module.critic


class ActorModuleOptFactory:
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        actor = self.actor_factory.create_module(envs, device)
        opt = self.optim_factory.create_optimizer(actor, lr)
        return ModuleOpt(actor, opt)


class CriticModuleOptFactory:
    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        critic = self.critic_factory.create_module(envs, device, self.use_action)
        opt = self.optim_factory.create_optimizer(critic, lr)
        return ModuleOpt(critic, opt)
@@ -4,7 +4,7 @@ import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory

tianshou/highlevel/params/policy_params.py
@@ -6,7 +6,8 @@ import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import ModuleOpt, TDevice
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
@@ -66,6 +67,18 @@ class ParamTransformerDrop(ParamTransformer):
            del kwargs[k]


class ParamTransformerChangeValue(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData):
        params[self.key] = self.change_value(params[self.key], data)

    @abstractmethod
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        pass


class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
    a learning rate scheduler (added) for the data member `optim`.
@@ -182,6 +195,14 @@ class ParamTransformerDistributionFunction(ParamTransformer):
        kwargs[self.key] = value.create_dist_fn(data.envs)


class ParamTransformerActionScaling(ParamTransformerChangeValue):
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        if value == "default":
            return data.envs.get_type().is_continuous()
        else:
            return value


class GetParamTransformersProtocol(Protocol):
    def _get_param_transformers(self) -> list[ParamTransformer]:
        pass
@@ -218,9 +239,15 @@ class PGParams(Params):
    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
    action_scaling: bool = True
    action_scaling: bool | Literal["default"] = "default"
    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.append(ParamTransformerActionScaling("action_scaling"))
        return transformers


@dataclass
class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
tianshou/highlevel/params/policy_wrapper.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: TPolicyIn,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        pass


class PolicyWrapperFactoryIntrinsicCuriosity(
    Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy],
):
    def __init__(
        self,
        feature_net_factory: ModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight,
    ):
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
        self,
        policy: TPolicyIn,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        feature_net = self.feature_net_factory.create_module(envs, device)
        action_dim = envs.get_action_shape()
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.module,
            feature_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
        return ICMPolicy(
            policy=policy,
            model=icm_net,
            optim=icm_optim,
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        ).to(device)
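Given a builder like the one in examples/atari/atari_ppo_hl.py, the intrinsic-curiosity wrapper above is attached via with_policy_wrapper_factory. A sketch of that call with the keyword arguments spelled out; the values are illustrative, and the comments paraphrase the usual ICMPolicy semantics rather than anything stated in this diff:

    from examples.atari.atari_network import FeatureNetFactoryDQN
    from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactoryIntrinsicCuriosity

    builder.with_policy_wrapper_factory(  # `builder` is a PPOExperimentBuilder as constructed above
        PolicyWrapperFactoryIntrinsicCuriosity(
            feature_net_factory=FeatureNetFactoryDQN(),
            hidden_sizes=[512],
            lr=2.5e-4,
            lr_scale=0.1,  # relative weight of the ICM loss
            reward_scale=0.01,  # scale of the intrinsic reward added to the extrinsic reward
            forward_loss_weight=0.2,
        ),
    )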