Support discrete SAC in high-level API
* Changed machanism for reusing actor's preprocessing module in critics
to avoid special handling in AgentFactory implementations, improving
separation of concerns:
- Added CriticFactoryReuseActor as the new critic factory
- Added ActorFactoryTransientStorageDecorator to pass on the actor
data
- Added helper classes ActorFuture, ActorFutureProviderProtocol
* Add example atari_sac_hl
This commit is contained in:
parent
305b30a6c1
commit
799beb79b4
105
examples/atari/atari_sac_hl.py
Normal file
105
examples/atari/atari_sac_hl.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from jsonargparse import CLI
|
||||||
|
|
||||||
|
from examples.atari.atari_network import (
|
||||||
|
ActorFactoryAtariDQN,
|
||||||
|
FeatureNetFactoryDQN,
|
||||||
|
)
|
||||||
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
|
from tianshou.highlevel.experiment import (
|
||||||
|
DiscreteSACExperimentBuilder,
|
||||||
|
ExperimentConfig,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
||||||
|
from tianshou.highlevel.params.policy_params import DiscreteSACParams
|
||||||
|
from tianshou.highlevel.params.policy_wrapper import (
|
||||||
|
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||||
|
)
|
||||||
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
|
task: str = "PongNoFrameskip-v4",
|
||||||
|
scale_obs: bool = False,
|
||||||
|
buffer_size: int = 100000,
|
||||||
|
actor_lr: float = 1e-5,
|
||||||
|
critic_lr: float = 1e-5,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
n_step: int = 3,
|
||||||
|
tau: float = 0.005,
|
||||||
|
alpha: float = 0.05,
|
||||||
|
auto_alpha: bool = False,
|
||||||
|
alpha_lr: float = 3e-4,
|
||||||
|
epoch: int = 100,
|
||||||
|
step_per_epoch: int = 100000,
|
||||||
|
step_per_collect: int = 10,
|
||||||
|
update_per_step: float = 0.1,
|
||||||
|
batch_size: int = 64,
|
||||||
|
hidden_size: int = 512,
|
||||||
|
training_num: int = 10,
|
||||||
|
test_num: int = 10,
|
||||||
|
frames_stack: int = 4,
|
||||||
|
save_buffer_name: str | None = None, # TODO add support in high-level API?
|
||||||
|
icm_lr_scale: float = 0.0,
|
||||||
|
icm_reward_scale: float = 0.01,
|
||||||
|
icm_forward_loss_weight: float = 0.2,
|
||||||
|
):
|
||||||
|
log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag())
|
||||||
|
|
||||||
|
sampling_config = SamplingConfig(
|
||||||
|
num_epochs=epoch,
|
||||||
|
step_per_epoch=step_per_epoch,
|
||||||
|
update_per_step=update_per_step,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_train_envs=training_num,
|
||||||
|
num_test_envs=test_num,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
step_per_collect=step_per_collect,
|
||||||
|
repeat_per_collect=None,
|
||||||
|
replay_buffer_stack_num=frames_stack,
|
||||||
|
replay_buffer_ignore_obs_next=True,
|
||||||
|
replay_buffer_save_only_last_obs=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
|
||||||
|
|
||||||
|
builder = (
|
||||||
|
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
.with_sac_params(
|
||||||
|
DiscreteSACParams(
|
||||||
|
actor_lr=actor_lr,
|
||||||
|
critic1_lr=critic_lr,
|
||||||
|
critic2_lr=critic_lr,
|
||||||
|
gamma=gamma,
|
||||||
|
tau=tau,
|
||||||
|
alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
|
||||||
|
estimation_step=n_step,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
|
||||||
|
.with_common_critic_factory_use_actor()
|
||||||
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
|
)
|
||||||
|
if icm_lr_scale > 0:
|
||||||
|
builder.with_policy_wrapper_factory(
|
||||||
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
|
FeatureNetFactoryDQN(),
|
||||||
|
[hidden_size],
|
||||||
|
actor_lr,
|
||||||
|
icm_lr_scale,
|
||||||
|
icm_reward_scale,
|
||||||
|
icm_forward_loss_weight,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
experiment = builder.build()
|
||||||
|
experiment.run(log_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.run_main(lambda: CLI(main))
|
||||||
@ -22,9 +22,11 @@ from tianshou.highlevel.optim import OptimizerFactory
|
|||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
A2CParams,
|
A2CParams,
|
||||||
DDPGParams,
|
DDPGParams,
|
||||||
|
DiscreteSACParams,
|
||||||
DQNParams,
|
DQNParams,
|
||||||
NPGParams,
|
NPGParams,
|
||||||
Params,
|
Params,
|
||||||
|
ParamsMixinActorAndDualCritics,
|
||||||
ParamTransformerData,
|
ParamTransformerData,
|
||||||
PGParams,
|
PGParams,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
@ -39,6 +41,7 @@ from tianshou.policy import (
|
|||||||
A2CPolicy,
|
A2CPolicy,
|
||||||
BasePolicy,
|
BasePolicy,
|
||||||
DDPGPolicy,
|
DDPGPolicy,
|
||||||
|
DiscreteSACPolicy,
|
||||||
DQNPolicy,
|
DQNPolicy,
|
||||||
NPGPolicy,
|
NPGPolicy,
|
||||||
PGPolicy,
|
PGPolicy,
|
||||||
@ -49,13 +52,13 @@ from tianshou.policy import (
|
|||||||
TRPOPolicy,
|
TRPOPolicy,
|
||||||
)
|
)
|
||||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net.common import ActorCritic
|
||||||
from tianshou.utils.net.common import ActorCritic, BaseActor
|
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||||
TParams = TypeVar("TParams", bound=Params)
|
TParams = TypeVar("TParams", bound=Params)
|
||||||
|
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +250,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _ActorCriticMixin:
|
class _ActorCriticMixin: # TODO merge
|
||||||
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -256,13 +259,11 @@ class _ActorCriticMixin:
|
|||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
critic_use_action: bool,
|
critic_use_action: bool,
|
||||||
critic_use_actor_module: bool,
|
|
||||||
):
|
):
|
||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
self.critic_factory = critic_factory
|
self.critic_factory = critic_factory
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
self.critic_use_action = critic_use_action
|
self.critic_use_action = critic_use_action
|
||||||
self.critic_use_actor_module = critic_use_actor_module
|
|
||||||
|
|
||||||
def create_actor_critic_module_opt(
|
def create_actor_critic_module_opt(
|
||||||
self,
|
self,
|
||||||
@ -271,28 +272,7 @@ class _ActorCriticMixin:
|
|||||||
lr: float,
|
lr: float,
|
||||||
) -> ActorCriticModuleOpt:
|
) -> ActorCriticModuleOpt:
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
actor = self.actor_factory.create_module(envs, device)
|
||||||
critic: torch.nn.Module
|
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
|
||||||
if self.critic_use_actor_module:
|
|
||||||
if self.critic_use_action:
|
|
||||||
raise ValueError(
|
|
||||||
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
|
|
||||||
)
|
|
||||||
if not isinstance(actor, BaseActor):
|
|
||||||
raise ValueError(
|
|
||||||
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
|
|
||||||
)
|
|
||||||
if envs.get_type().is_discrete():
|
|
||||||
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
|
|
||||||
elif envs.get_type().is_continuous():
|
|
||||||
critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
else:
|
|
||||||
critic = self.critic_factory.create_module(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
use_action=self.critic_use_action,
|
|
||||||
)
|
|
||||||
actor_critic = ActorCritic(actor, critic)
|
actor_critic = ActorCritic(actor, critic)
|
||||||
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
||||||
return ActorCriticModuleOpt(actor_critic, optim)
|
return ActorCriticModuleOpt(actor_critic, optim)
|
||||||
@ -349,7 +329,6 @@ class ActorCriticAgentFactory(
|
|||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
policy_class: type[TPolicy],
|
policy_class: type[TPolicy],
|
||||||
critic_use_actor_module: bool,
|
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory=optimizer_factory)
|
super().__init__(sampling_config, optim_factory=optimizer_factory)
|
||||||
_ActorCriticMixin.__init__(
|
_ActorCriticMixin.__init__(
|
||||||
@ -358,7 +337,6 @@ class ActorCriticAgentFactory(
|
|||||||
critic_factory,
|
critic_factory,
|
||||||
optimizer_factory,
|
optimizer_factory,
|
||||||
critic_use_action=False,
|
critic_use_action=False,
|
||||||
critic_use_actor_module=critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
self.params = params
|
self.params = params
|
||||||
self.policy_class = policy_class
|
self.policy_class = policy_class
|
||||||
@ -395,7 +373,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
|||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
critic_use_actor_module: bool,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
params,
|
params,
|
||||||
@ -404,7 +381,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
|||||||
critic_factory,
|
critic_factory,
|
||||||
optimizer_factory,
|
optimizer_factory,
|
||||||
A2CPolicy,
|
A2CPolicy,
|
||||||
critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
@ -419,7 +395,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
|||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
critic_use_actor_module: bool,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
params,
|
params,
|
||||||
@ -428,7 +403,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
|||||||
critic_factory,
|
critic_factory,
|
||||||
optimizer_factory,
|
optimizer_factory,
|
||||||
PPOPolicy,
|
PPOPolicy,
|
||||||
critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
@ -443,7 +417,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
|
|||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
critic_use_actor_module: bool,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
params,
|
params,
|
||||||
@ -452,7 +425,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
|
|||||||
critic_factory,
|
critic_factory,
|
||||||
optimizer_factory,
|
optimizer_factory,
|
||||||
NPGPolicy,
|
NPGPolicy,
|
||||||
critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
@ -467,7 +439,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
|||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
critic_use_actor_module: bool,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
params,
|
params,
|
||||||
@ -476,7 +447,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
|||||||
critic_factory,
|
critic_factory,
|
||||||
optimizer_factory,
|
optimizer_factory,
|
||||||
TRPOPolicy,
|
TRPOPolicy,
|
||||||
critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
@ -619,6 +589,81 @@ class REDQAgentFactory(OffpolicyAgentFactory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActorDualCriticsAgentFactory(
|
||||||
|
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: TActorDualCriticsParams,
|
||||||
|
sampling_config: SamplingConfig,
|
||||||
|
actor_factory: ActorFactory,
|
||||||
|
critic1_factory: CriticFactory,
|
||||||
|
critic2_factory: CriticFactory,
|
||||||
|
optim_factory: OptimizerFactory,
|
||||||
|
):
|
||||||
|
super().__init__(sampling_config, optim_factory)
|
||||||
|
self.params = params
|
||||||
|
self.actor_factory = actor_factory
|
||||||
|
self.critic1_factory = critic1_factory
|
||||||
|
self.critic2_factory = critic2_factory
|
||||||
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_policy_class(self) -> type[TPolicy]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
|
actor = self.actor_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.actor_lr,
|
||||||
|
)
|
||||||
|
use_action_shape = self._get_discrete_last_size_use_action_shape()
|
||||||
|
critic1 = self.critic1_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic1_lr,
|
||||||
|
discrete_last_size_use_action_shape=use_action_shape,
|
||||||
|
)
|
||||||
|
critic2 = self.critic2_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic2_lr,
|
||||||
|
discrete_last_size_use_action_shape=use_action_shape,
|
||||||
|
)
|
||||||
|
kwargs = self.params.create_kwargs(
|
||||||
|
ParamTransformerData(
|
||||||
|
envs=envs,
|
||||||
|
device=device,
|
||||||
|
optim_factory=self.optim_factory,
|
||||||
|
actor=actor,
|
||||||
|
critic1=critic1,
|
||||||
|
critic2=critic2,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
policy_class = self._get_policy_class()
|
||||||
|
return policy_class(
|
||||||
|
actor=actor.module,
|
||||||
|
actor_optim=actor.optim,
|
||||||
|
critic=critic1.module,
|
||||||
|
critic_optim=critic1.optim,
|
||||||
|
critic2=critic2.module,
|
||||||
|
critic2_optim=critic2.optim,
|
||||||
|
action_space=envs.get_action_space(),
|
||||||
|
observation_space=envs.get_observation_space(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory):
|
class SACAgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -680,6 +725,14 @@ class SACAgentFactory(OffpolicyAgentFactory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
|
||||||
|
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_policy_class(self) -> type[TPolicy]:
|
||||||
|
return DiscreteSACPolicy
|
||||||
|
|
||||||
|
|
||||||
class TD3AgentFactory(OffpolicyAgentFactory):
|
class TD3AgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
|
|||||||
A2CAgentFactory,
|
A2CAgentFactory,
|
||||||
AgentFactory,
|
AgentFactory,
|
||||||
DDPGAgentFactory,
|
DDPGAgentFactory,
|
||||||
|
DiscreteSACAgentFactory,
|
||||||
DQNAgentFactory,
|
DQNAgentFactory,
|
||||||
NPGAgentFactory,
|
NPGAgentFactory,
|
||||||
PGAgentFactory,
|
PGAgentFactory,
|
||||||
@ -28,6 +29,9 @@ from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
|||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
ActorFactoryDefault,
|
ActorFactoryDefault,
|
||||||
|
ActorFactoryTransientStorageDecorator,
|
||||||
|
ActorFuture,
|
||||||
|
ActorFutureProviderProtocol,
|
||||||
ContinuousActorType,
|
ContinuousActorType,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.module.critic import (
|
from tianshou.highlevel.module.critic import (
|
||||||
@ -35,11 +39,13 @@ from tianshou.highlevel.module.critic import (
|
|||||||
CriticEnsembleFactoryDefault,
|
CriticEnsembleFactoryDefault,
|
||||||
CriticFactory,
|
CriticFactory,
|
||||||
CriticFactoryDefault,
|
CriticFactoryDefault,
|
||||||
|
CriticFactoryReuseActor,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
A2CParams,
|
A2CParams,
|
||||||
DDPGParams,
|
DDPGParams,
|
||||||
|
DiscreteSACParams,
|
||||||
DQNParams,
|
DQNParams,
|
||||||
NPGParams,
|
NPGParams,
|
||||||
PGParams,
|
PGParams,
|
||||||
@ -263,9 +269,10 @@ class ExperimentBuilder:
|
|||||||
return experiment
|
return experiment
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinActorFactory:
|
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||||
def __init__(self, continuous_actor_type: ContinuousActorType):
|
def __init__(self, continuous_actor_type: ContinuousActorType):
|
||||||
self._continuous_actor_type = continuous_actor_type
|
self._continuous_actor_type = continuous_actor_type
|
||||||
|
self._actor_future = ActorFuture()
|
||||||
self._actor_factory: ActorFactory | None = None
|
self._actor_factory: ActorFactory | None = None
|
||||||
|
|
||||||
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
||||||
@ -286,11 +293,16 @@ class _BuilderMixinActorFactory:
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_actor_future(self) -> ActorFuture:
|
||||||
|
return self._actor_future
|
||||||
|
|
||||||
def _get_actor_factory(self) -> ActorFactory:
|
def _get_actor_factory(self) -> ActorFactory:
|
||||||
|
actor_factory: ActorFactory
|
||||||
if self._actor_factory is None:
|
if self._actor_factory is None:
|
||||||
return ActorFactoryDefault(self._continuous_actor_type)
|
actor_factory = ActorFactoryDefault(self._continuous_actor_type)
|
||||||
else:
|
else:
|
||||||
return self._actor_factory
|
actor_factory = self._actor_factory
|
||||||
|
return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future)
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||||
@ -325,7 +337,8 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
|||||||
|
|
||||||
|
|
||||||
class _BuilderMixinCriticsFactory:
|
class _BuilderMixinCriticsFactory:
|
||||||
def __init__(self, num_critics: int):
|
def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
|
||||||
|
self._actor_future_provider = actor_future_provider
|
||||||
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
|
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
|
||||||
|
|
||||||
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
|
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
|
||||||
@ -336,6 +349,12 @@ class _BuilderMixinCriticsFactory:
|
|||||||
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
|
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _with_critic_factory_use_actor(self, idx: int) -> Self:
|
||||||
|
self._critic_factories[idx] = CriticFactoryReuseActor(
|
||||||
|
self._actor_future_provider.get_actor_future(),
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
def _get_critic_factory(self, idx: int) -> CriticFactory:
|
def _get_critic_factory(self, idx: int) -> CriticFactory:
|
||||||
factory = self._critic_factories[idx]
|
factory = self._critic_factories[idx]
|
||||||
if factory is None:
|
if factory is None:
|
||||||
@ -345,8 +364,8 @@ class _BuilderMixinCriticsFactory:
|
|||||||
|
|
||||||
|
|
||||||
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||||
def __init__(self) -> None:
|
def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
|
||||||
super().__init__(1)
|
super().__init__(1, actor_future_provider)
|
||||||
|
|
||||||
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
self._with_critic_factory(0, critic_factory)
|
self._with_critic_factory(0, critic_factory)
|
||||||
@ -361,19 +380,17 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
|
|
||||||
|
|
||||||
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
|
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
|
||||||
def __init__(self) -> None:
|
def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
|
||||||
super().__init__()
|
super().__init__(actor_future_provider)
|
||||||
self._critic_use_actor_module = False
|
|
||||||
|
|
||||||
def with_critic_factory_use_actor(self) -> Self:
|
def with_critic_factory_use_actor(self) -> Self:
|
||||||
"""Makes the critic use the same network as the actor."""
|
"""Makes the critic use the same network as the actor."""
|
||||||
self._critic_use_actor_module = True
|
return self._with_critic_factory_use_actor(0)
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||||
def __init__(self) -> None:
|
def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
|
||||||
super().__init__(2)
|
super().__init__(2, actor_future_provider)
|
||||||
|
|
||||||
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
for i in range(len(self._critic_factories)):
|
for i in range(len(self._critic_factories)):
|
||||||
@ -388,6 +405,12 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self._with_critic_factory_default(i, hidden_sizes)
|
self._with_critic_factory_default(i, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_common_critic_factory_use_actor(self) -> Self:
|
||||||
|
"""Makes all critics use the same network as the actor."""
|
||||||
|
for i in range(len(self._critic_factories)):
|
||||||
|
self._with_critic_factory_use_actor(i)
|
||||||
|
return self
|
||||||
|
|
||||||
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
self._with_critic_factory(0, critic_factory)
|
self._with_critic_factory(0, critic_factory)
|
||||||
return self
|
return self
|
||||||
@ -399,6 +422,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_critic1_factory_use_actor(self) -> Self:
|
||||||
|
"""Makes the critic use the same network as the actor."""
|
||||||
|
return self._with_critic_factory_use_actor(0)
|
||||||
|
|
||||||
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
self._with_critic_factory(1, critic_factory)
|
self._with_critic_factory(1, critic_factory)
|
||||||
return self
|
return self
|
||||||
@ -410,6 +437,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_critic2_factory_use_actor(self) -> Self:
|
||||||
|
"""Makes the second critic use the same network as the actor."""
|
||||||
|
return self._with_critic_factory_use_actor(1)
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinCriticEnsembleFactory:
|
class _BuilderMixinCriticEnsembleFactory:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -475,7 +506,7 @@ class A2CExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
|
||||||
self._params: A2CParams = A2CParams()
|
self._params: A2CParams = A2CParams()
|
||||||
self._env_config = None
|
self._env_config = None
|
||||||
|
|
||||||
@ -483,7 +514,6 @@ class A2CExperimentBuilder(
|
|||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return A2CAgentFactory(
|
return A2CAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -491,7 +521,6 @@ class A2CExperimentBuilder(
|
|||||||
self._get_actor_factory(),
|
self._get_actor_factory(),
|
||||||
self._get_critic_factory(0),
|
self._get_critic_factory(0),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
self._critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -508,14 +537,13 @@ class PPOExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
|
||||||
self._params: PPOParams = PPOParams()
|
self._params: PPOParams = PPOParams()
|
||||||
|
|
||||||
def with_ppo_params(self, params: PPOParams) -> Self:
|
def with_ppo_params(self, params: PPOParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return PPOAgentFactory(
|
return PPOAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -523,7 +551,6 @@ class PPOExperimentBuilder(
|
|||||||
self._get_actor_factory(),
|
self._get_actor_factory(),
|
||||||
self._get_critic_factory(0),
|
self._get_critic_factory(0),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
self._critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -540,14 +567,13 @@ class NPGExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
|
||||||
self._params: NPGParams = NPGParams()
|
self._params: NPGParams = NPGParams()
|
||||||
|
|
||||||
def with_npg_params(self, params: NPGParams) -> Self:
|
def with_npg_params(self, params: NPGParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return NPGAgentFactory(
|
return NPGAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -555,7 +581,6 @@ class NPGExperimentBuilder(
|
|||||||
self._get_actor_factory(),
|
self._get_actor_factory(),
|
||||||
self._get_critic_factory(0),
|
self._get_critic_factory(0),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
self._critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -572,14 +597,13 @@ class TRPOExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
|
||||||
self._params: TRPOParams = TRPOParams()
|
self._params: TRPOParams = TRPOParams()
|
||||||
|
|
||||||
def with_trpo_params(self, params: TRPOParams) -> Self:
|
def with_trpo_params(self, params: TRPOParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return TRPOAgentFactory(
|
return TRPOAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -587,7 +611,6 @@ class TRPOExperimentBuilder(
|
|||||||
self._get_actor_factory(),
|
self._get_actor_factory(),
|
||||||
self._get_critic_factory(0),
|
self._get_critic_factory(0),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
self._critic_use_actor_module,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -609,7 +632,6 @@ class DQNExperimentBuilder(
|
|||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return DQNAgentFactory(
|
return DQNAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -632,14 +654,13 @@ class DDPGExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
|
||||||
self._params: DDPGParams = DDPGParams()
|
self._params: DDPGParams = DDPGParams()
|
||||||
|
|
||||||
def with_ddpg_params(self, params: DDPGParams) -> Self:
|
def with_ddpg_params(self, params: DDPGParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return DDPGAgentFactory(
|
return DDPGAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -670,7 +691,6 @@ class REDQExperimentBuilder(
|
|||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return REDQAgentFactory(
|
return REDQAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
@ -694,7 +714,7 @@ class SACExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinDualCriticFactory.__init__(self)
|
_BuilderMixinDualCriticFactory.__init__(self, self)
|
||||||
self._params: SACParams = SACParams()
|
self._params: SACParams = SACParams()
|
||||||
|
|
||||||
def with_sac_params(self, params: SACParams) -> Self:
|
def with_sac_params(self, params: SACParams) -> Self:
|
||||||
@ -712,6 +732,37 @@ class SACExperimentBuilder(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscreteSACExperimentBuilder(
|
||||||
|
ExperimentBuilder,
|
||||||
|
_BuilderMixinActorFactory,
|
||||||
|
_BuilderMixinDualCriticFactory,
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig | None = None,
|
||||||
|
sampling_config: SamplingConfig | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
|
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||||
|
_BuilderMixinDualCriticFactory.__init__(self, self)
|
||||||
|
self._params: DiscreteSACParams = DiscreteSACParams()
|
||||||
|
|
||||||
|
def with_sac_params(self, params: DiscreteSACParams) -> Self:
|
||||||
|
self._params = params
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
|
return DiscreteSACAgentFactory(
|
||||||
|
self._params,
|
||||||
|
self._sampling_config,
|
||||||
|
self._get_actor_factory(),
|
||||||
|
self._get_critic_factory(0),
|
||||||
|
self._get_critic_factory(1),
|
||||||
|
self._get_optim_factory(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TD3ExperimentBuilder(
|
class TD3ExperimentBuilder(
|
||||||
ExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic,
|
_BuilderMixinActorFactory_ContinuousDeterministic,
|
||||||
@ -725,7 +776,7 @@ class TD3ExperimentBuilder(
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
_BuilderMixinDualCriticFactory.__init__(self)
|
_BuilderMixinDualCriticFactory.__init__(self, self)
|
||||||
self._params: TD3Params = TD3Params()
|
self._params: TD3Params = TD3Params()
|
||||||
|
|
||||||
def with_td3_params(self, params: TD3Params) -> Self:
|
def with_td3_params(self, params: TD3Params) -> Self:
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -20,6 +22,18 @@ class ContinuousActorType(Enum):
|
|||||||
UNSUPPORTED = "unsupported"
|
UNSUPPORTED = "unsupported"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActorFuture:
|
||||||
|
"""Container, which, in the future, will hold an actor instance."""
|
||||||
|
|
||||||
|
actor: BaseActor | nn.Module | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ActorFutureProviderProtocol(Protocol):
|
||||||
|
def get_actor_future(self) -> ActorFuture:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ActorFactory(ToStringMixin, ABC):
|
class ActorFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
||||||
@ -175,3 +189,26 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
|||||||
hidden_sizes=(),
|
hidden_sizes=(),
|
||||||
device=device,
|
device=device,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
class ActorFactoryTransientStorageDecorator(ActorFactory):
|
||||||
|
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
|
||||||
|
self.actor_factory = actor_factory
|
||||||
|
self._actor_future = actor_future
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
d = dict(self.__dict__)
|
||||||
|
del d["_actor_future"]
|
||||||
|
return d
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__ = state
|
||||||
|
self._actor_future = ActorFuture()
|
||||||
|
|
||||||
|
def _tostring_excludes(self):
|
||||||
|
return [*super()._tostring_excludes(), "_actor_future"]
|
||||||
|
|
||||||
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
||||||
|
module = self.actor_factory.create_module(envs, device)
|
||||||
|
self._actor_future.actor = module
|
||||||
|
return module
|
||||||
|
|||||||
@ -4,18 +4,30 @@ from collections.abc import Sequence
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments, EnvType
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
|
from tianshou.highlevel.module.actor import ActorFuture
|
||||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||||
from tianshou.highlevel.module.module_opt import ModuleOpt
|
from tianshou.highlevel.module.module_opt import ModuleOpt
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import EnsembleLinear, Net
|
from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
class CriticFactory(ToStringMixin, ABC):
|
class CriticFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(
|
||||||
pass
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
use_action: bool,
|
||||||
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
|
) -> nn.Module:
|
||||||
|
""":param envs: the environments
|
||||||
|
:param device: the torch device
|
||||||
|
:param use_action: whether to (additionally) expect the action as input
|
||||||
|
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
|
||||||
|
:return: the module
|
||||||
|
"""
|
||||||
|
|
||||||
def create_module_opt(
|
def create_module_opt(
|
||||||
self,
|
self,
|
||||||
@ -24,8 +36,14 @@ class CriticFactory(ToStringMixin, ABC):
|
|||||||
use_action: bool,
|
use_action: bool,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
lr: float,
|
lr: float,
|
||||||
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> ModuleOpt:
|
) -> ModuleOpt:
|
||||||
module = self.create_module(envs, device, use_action)
|
module = self.create_module(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
use_action,
|
||||||
|
discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
|
||||||
|
)
|
||||||
opt = optim_factory.create_optimizer(module, lr)
|
opt = optim_factory.create_optimizer(module, lr)
|
||||||
return ModuleOpt(module, opt)
|
return ModuleOpt(module, opt)
|
||||||
|
|
||||||
@ -38,7 +56,13 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
|
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(
|
||||||
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
use_action: bool,
|
||||||
|
discrete_last_size_use_action_shape=False,
|
||||||
|
) -> nn.Module:
|
||||||
factory: CriticFactory
|
factory: CriticFactory
|
||||||
env_type = envs.get_type()
|
env_type = envs.get_type()
|
||||||
match env_type:
|
match env_type:
|
||||||
@ -48,14 +72,25 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
|
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"{env_type} not supported")
|
raise ValueError(f"{env_type} not supported")
|
||||||
return factory.create_module(envs, device, use_action)
|
return factory.create_module(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
use_action,
|
||||||
|
discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CriticFactoryContinuousNet(CriticFactory):
|
class CriticFactoryContinuousNet(CriticFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int]):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(
|
||||||
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
use_action: bool,
|
||||||
|
discrete_last_size_use_action_shape=False,
|
||||||
|
) -> nn.Module:
|
||||||
action_shape = envs.get_action_shape() if use_action else 0
|
action_shape = envs.get_action_shape() if use_action else 0
|
||||||
net_c = Net(
|
net_c = Net(
|
||||||
envs.get_observation_shape(),
|
envs.get_observation_shape(),
|
||||||
@ -74,7 +109,13 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int]):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(
|
||||||
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
use_action: bool,
|
||||||
|
discrete_last_size_use_action_shape=False,
|
||||||
|
) -> nn.Module:
|
||||||
action_shape = envs.get_action_shape() if use_action else 0
|
action_shape = envs.get_action_shape() if use_action else 0
|
||||||
net_c = Net(
|
net_c = Net(
|
||||||
envs.get_observation_shape(),
|
envs.get_observation_shape(),
|
||||||
@ -84,11 +125,50 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
activation=nn.Tanh,
|
activation=nn.Tanh,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
critic = discrete.Critic(net_c, device=device).to(device)
|
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
|
||||||
|
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
return critic
|
return critic
|
||||||
|
|
||||||
|
|
||||||
|
class CriticFactoryReuseActor(CriticFactory):
|
||||||
|
"""A critic factory which reuses the actor's preprocessing component.
|
||||||
|
|
||||||
|
This class is for internal use in experiment builders only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, actor_future: ActorFuture):
|
||||||
|
""":param actor_future: the object, which will hold the actor instance later when the critic is to be created"""
|
||||||
|
self.actor_future = actor_future
|
||||||
|
|
||||||
|
def _tostring_excludes(self) -> list[str]:
|
||||||
|
return ["actor_future"]
|
||||||
|
|
||||||
|
def create_module(
|
||||||
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
use_action: bool,
|
||||||
|
discrete_last_size_use_action_shape=False,
|
||||||
|
) -> nn.Module:
|
||||||
|
actor = self.actor_future.actor
|
||||||
|
if not isinstance(actor, BaseActor):
|
||||||
|
raise ValueError(
|
||||||
|
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
|
||||||
|
)
|
||||||
|
if envs.get_type().is_discrete():
|
||||||
|
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
|
||||||
|
return discrete.Critic(
|
||||||
|
actor.get_preprocess_net(),
|
||||||
|
device=device,
|
||||||
|
last_size=last_size,
|
||||||
|
).to(device)
|
||||||
|
elif envs.get_type().is_continuous():
|
||||||
|
return continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
class CriticEnsembleFactory:
|
class CriticEnsembleFactory:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(
|
def create_module(
|
||||||
|
|||||||
@ -359,6 +359,20 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
|
|||||||
return transformers
|
return transformers
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiscreteSACParams(Params, ParamsMixinActorAndDualCritics):
|
||||||
|
tau: float = 0.005
|
||||||
|
gamma: float = 0.99
|
||||||
|
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
|
||||||
|
estimation_step: int = 1
|
||||||
|
|
||||||
|
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||||
|
transformers = super()._get_param_transformers()
|
||||||
|
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
|
||||||
|
transformers.append(ParamTransformerAutoAlpha("alpha"))
|
||||||
|
return transformers
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
|
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
|
||||||
discount_factor: float = 0.99
|
discount_factor: float = 0.99
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user