Support discrete SAC in high-level API

* Changed machanism for reusing actor's preprocessing module in critics
  to avoid special handling in AgentFactory implementations, improving
  separation of concerns:
    - Added CriticFactoryReuseActor as the new critic factory
    - Added ActorFactoryTransientStorageDecorator to pass on the actor
      data
    - Added helper classes ActorFuture, ActorFutureProviderProtocol
* Add example atari_sac_hl
This commit is contained in:
Dominik Jain 2023-10-10 19:11:49 +02:00
parent 305b30a6c1
commit 799beb79b4
6 changed files with 417 additions and 77 deletions

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3
import os
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
DiscreteSACExperimentBuilder,
ExperimentConfig,
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import DiscreteSACParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: bool = False,
buffer_size: int = 100000,
actor_lr: float = 1e-5,
critic_lr: float = 1e-5,
gamma: float = 0.99,
n_step: int = 3,
tau: float = 0.005,
alpha: float = 0.05,
auto_alpha: bool = False,
alpha_lr: float = 3e-4,
epoch: int = 100,
step_per_epoch: int = 100000,
step_per_collect: int = 10,
update_per_step: float = 0.1,
batch_size: int = 64,
hidden_size: int = 512,
training_num: int = 10,
test_num: int = 10,
frames_stack: int = 4,
save_buffer_name: str | None = None, # TODO add support in high-level API?
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
update_per_step=update_per_step,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=None,
replay_buffer_stack_num=frames_stack,
replay_buffer_ignore_obs_next=True,
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
DiscreteSACParams(
actor_lr=actor_lr,
critic1_lr=critic_lr,
critic2_lr=critic_lr,
gamma=gamma,
tau=tau,
alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
estimation_step=n_step,
),
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
.with_common_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
[hidden_size],
actor_lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
),
)
experiment = builder.build()
experiment.run(log_name)
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))

View File

@ -22,9 +22,11 @@ from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
A2CParams,
DDPGParams,
DiscreteSACParams,
DQNParams,
NPGParams,
Params,
ParamsMixinActorAndDualCritics,
ParamTransformerData,
PGParams,
PPOParams,
@ -39,6 +41,7 @@ from tianshou.policy import (
A2CPolicy,
BasePolicy,
DDPGPolicy,
DiscreteSACPolicy,
DQNPolicy,
NPGPolicy,
PGPolicy,
@ -49,13 +52,13 @@ from tianshou.policy import (
TRPOPolicy,
)
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic, BaseActor
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -247,7 +250,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
)
class _ActorCriticMixin:
class _ActorCriticMixin: # TODO merge
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
def __init__(
@ -256,13 +259,11 @@ class _ActorCriticMixin:
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
critic_use_actor_module: bool,
):
self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.critic_use_action = critic_use_action
self.critic_use_actor_module = critic_use_actor_module
def create_actor_critic_module_opt(
self,
@ -271,28 +272,7 @@ class _ActorCriticMixin:
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic: torch.nn.Module
if self.critic_use_actor_module:
if self.critic_use_action:
raise ValueError(
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
)
if not isinstance(actor, BaseActor):
raise ValueError(
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete():
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
elif envs.get_type().is_continuous():
critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
else:
raise ValueError
else:
critic = self.critic_factory.create_module(
envs,
device,
use_action=self.critic_use_action,
)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
@ -349,7 +329,6 @@ class ActorCriticAgentFactory(
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
policy_class: type[TPolicy],
critic_use_actor_module: bool,
):
super().__init__(sampling_config, optim_factory=optimizer_factory)
_ActorCriticMixin.__init__(
@ -358,7 +337,6 @@ class ActorCriticAgentFactory(
critic_factory,
optimizer_factory,
critic_use_action=False,
critic_use_actor_module=critic_use_actor_module,
)
self.params = params
self.policy_class = policy_class
@ -395,7 +373,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
):
super().__init__(
params,
@ -404,7 +381,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
critic_factory,
optimizer_factory,
A2CPolicy,
critic_use_actor_module,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -419,7 +395,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
):
super().__init__(
params,
@ -428,7 +403,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
critic_factory,
optimizer_factory,
PPOPolicy,
critic_use_actor_module,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -443,7 +417,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
):
super().__init__(
params,
@ -452,7 +425,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
critic_factory,
optimizer_factory,
NPGPolicy,
critic_use_actor_module,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -467,7 +439,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
):
super().__init__(
params,
@ -476,7 +447,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
critic_factory,
optimizer_factory,
TRPOPolicy,
critic_use_actor_module,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -619,6 +589,81 @@ class REDQAgentFactory(OffpolicyAgentFactory):
)
class ActorDualCriticsAgentFactory(
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
):
def __init__(
self,
params: TActorDualCriticsParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
@abstractmethod
def _get_policy_class(self) -> type[TPolicy]:
pass
@abstractmethod
def _get_discrete_last_size_use_action_shape(self) -> bool:
pass
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
use_action_shape = self._get_discrete_last_size_use_action_shape()
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
discrete_last_size_use_action_shape=use_action_shape,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
discrete_last_size_use_action_shape=use_action_shape,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim_factory=self.optim_factory,
actor=actor,
critic1=critic1,
critic2=critic2,
),
)
policy_class = self._get_policy_class()
return policy_class(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
@ -680,6 +725,14 @@ class SACAgentFactory(OffpolicyAgentFactory):
)
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
def _get_discrete_last_size_use_action_shape(self) -> bool:
return True
def _get_policy_class(self) -> type[TPolicy]:
return DiscreteSACPolicy
class TD3AgentFactory(OffpolicyAgentFactory):
def __init__(
self,

View File

@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
A2CAgentFactory,
AgentFactory,
DDPGAgentFactory,
DiscreteSACAgentFactory,
DQNAgentFactory,
NPGAgentFactory,
PGAgentFactory,
@ -28,6 +29,9 @@ from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module.actor import (
ActorFactory,
ActorFactoryDefault,
ActorFactoryTransientStorageDecorator,
ActorFuture,
ActorFutureProviderProtocol,
ContinuousActorType,
)
from tianshou.highlevel.module.critic import (
@ -35,11 +39,13 @@ from tianshou.highlevel.module.critic import (
CriticEnsembleFactoryDefault,
CriticFactory,
CriticFactoryDefault,
CriticFactoryReuseActor,
)
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
A2CParams,
DDPGParams,
DiscreteSACParams,
DQNParams,
NPGParams,
PGParams,
@ -263,9 +269,10 @@ class ExperimentBuilder:
return experiment
class _BuilderMixinActorFactory:
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
def __init__(self, continuous_actor_type: ContinuousActorType):
self._continuous_actor_type = continuous_actor_type
self._actor_future = ActorFuture()
self._actor_factory: ActorFactory | None = None
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
@ -286,11 +293,16 @@ class _BuilderMixinActorFactory:
)
return self
def get_actor_future(self) -> ActorFuture:
return self._actor_future
def _get_actor_factory(self) -> ActorFactory:
actor_factory: ActorFactory
if self._actor_factory is None:
return ActorFactoryDefault(self._continuous_actor_type)
actor_factory = ActorFactoryDefault(self._continuous_actor_type)
else:
return self._actor_factory
actor_factory = self._actor_factory
return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future)
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
@ -325,7 +337,8 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
class _BuilderMixinCriticsFactory:
def __init__(self, num_critics: int):
def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
self._actor_future_provider = actor_future_provider
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
@ -336,6 +349,12 @@ class _BuilderMixinCriticsFactory:
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
return self
def _with_critic_factory_use_actor(self, idx: int) -> Self:
self._critic_factories[idx] = CriticFactoryReuseActor(
self._actor_future_provider.get_actor_future(),
)
return self
def _get_critic_factory(self, idx: int) -> CriticFactory:
factory = self._critic_factories[idx]
if factory is None:
@ -345,8 +364,8 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self) -> None:
super().__init__(1)
def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
super().__init__(1, actor_future_provider)
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(0, critic_factory)
@ -361,19 +380,17 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
def __init__(self) -> None:
super().__init__()
self._critic_use_actor_module = False
def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
super().__init__(actor_future_provider)
def with_critic_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor."""
self._critic_use_actor_module = True
return self
return self._with_critic_factory_use_actor(0)
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self) -> None:
super().__init__(2)
def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
super().__init__(2, actor_future_provider)
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
for i in range(len(self._critic_factories)):
@ -388,6 +405,12 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(i, hidden_sizes)
return self
def with_common_critic_factory_use_actor(self) -> Self:
"""Makes all critics use the same network as the actor."""
for i in range(len(self._critic_factories)):
self._with_critic_factory_use_actor(i)
return self
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(0, critic_factory)
return self
@ -399,6 +422,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(0, hidden_sizes)
return self
def with_critic1_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor."""
return self._with_critic_factory_use_actor(0)
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(1, critic_factory)
return self
@ -410,6 +437,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(0, hidden_sizes)
return self
def with_critic2_factory_use_actor(self) -> Self:
"""Makes the second critic use the same network as the actor."""
return self._with_critic_factory_use_actor(1)
class _BuilderMixinCriticEnsembleFactory:
def __init__(self) -> None:
@ -475,7 +506,7 @@ class A2CExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: A2CParams = A2CParams()
self._env_config = None
@ -483,7 +514,6 @@ class A2CExperimentBuilder(
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return A2CAgentFactory(
self._params,
@ -491,7 +521,6 @@ class A2CExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._critic_use_actor_module,
)
@ -508,14 +537,13 @@ class PPOExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: PPOParams = PPOParams()
def with_ppo_params(self, params: PPOParams) -> Self:
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return PPOAgentFactory(
self._params,
@ -523,7 +551,6 @@ class PPOExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._critic_use_actor_module,
)
@ -540,14 +567,13 @@ class NPGExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: NPGParams = NPGParams()
def with_npg_params(self, params: NPGParams) -> Self:
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return NPGAgentFactory(
self._params,
@ -555,7 +581,6 @@ class NPGExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._critic_use_actor_module,
)
@ -572,14 +597,13 @@ class TRPOExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: TRPOParams = TRPOParams()
def with_trpo_params(self, params: TRPOParams) -> Self:
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return TRPOAgentFactory(
self._params,
@ -587,7 +611,6 @@ class TRPOExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._critic_use_actor_module,
)
@ -609,7 +632,6 @@ class DQNExperimentBuilder(
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return DQNAgentFactory(
self._params,
@ -632,14 +654,13 @@ class DDPGExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: DDPGParams = DDPGParams()
def with_ddpg_params(self, params: DDPGParams) -> Self:
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return DDPGAgentFactory(
self._params,
@ -670,7 +691,6 @@ class REDQExperimentBuilder(
self._params = params
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
return REDQAgentFactory(
self._params,
@ -694,7 +714,7 @@ class SACExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self, self)
self._params: SACParams = SACParams()
def with_sac_params(self, params: SACParams) -> Self:
@ -712,6 +732,37 @@ class SACExperimentBuilder(
)
class DiscreteSACExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory,
_BuilderMixinDualCriticFactory,
):
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
_BuilderMixinDualCriticFactory.__init__(self, self)
self._params: DiscreteSACParams = DiscreteSACParams()
def with_sac_params(self, params: DiscreteSACParams) -> Self:
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return DiscreteSACAgentFactory(
self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_critic_factory(1),
self._get_optim_factory(),
)
class TD3ExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousDeterministic,
@ -725,7 +776,7 @@ class TD3ExperimentBuilder(
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self, self)
self._params: TD3Params = TD3Params()
def with_td3_params(self, params: TD3Params) -> Self:

View File

@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Protocol
import torch
from torch import nn
@ -20,6 +22,18 @@ class ContinuousActorType(Enum):
UNSUPPORTED = "unsupported"
@dataclass
class ActorFuture:
"""Container, which, in the future, will hold an actor instance."""
actor: BaseActor | nn.Module | None = None
class ActorFutureProviderProtocol(Protocol):
def get_actor_future(self) -> ActorFuture:
pass
class ActorFactory(ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
@ -175,3 +189,26 @@ class ActorFactoryDiscreteNet(ActorFactory):
hidden_sizes=(),
device=device,
).to(device)
class ActorFactoryTransientStorageDecorator(ActorFactory):
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
self.actor_factory = actor_factory
self._actor_future = actor_future
def __getstate__(self):
d = dict(self.__dict__)
del d["_actor_future"]
return d
def __setstate__(self, state):
self.__dict__ = state
self._actor_future = ActorFuture()
def _tostring_excludes(self):
return [*super()._tostring_excludes(), "_actor_future"]
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
module = self.actor_factory.create_module(envs, device)
self._actor_future.actor = module
return module

View File

@ -4,18 +4,30 @@ from collections.abc import Sequence
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.actor import ActorFuture
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net
from tianshou.utils.string import ToStringMixin
class CriticFactory(ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
""":param envs: the environments
:param device: the torch device
:param use_action: whether to (additionally) expect the action as input
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
:return: the module
"""
def create_module_opt(
self,
@ -24,8 +36,14 @@ class CriticFactory(ToStringMixin, ABC):
use_action: bool,
optim_factory: OptimizerFactory,
lr: float,
discrete_last_size_use_action_shape: bool = False,
) -> ModuleOpt:
module = self.create_module(envs, device, use_action)
module = self.create_module(
envs,
device,
use_action,
discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
)
opt = optim_factory.create_optimizer(module, lr)
return ModuleOpt(module, opt)
@ -38,7 +56,13 @@ class CriticFactoryDefault(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
factory: CriticFactory
env_type = envs.get_type()
match env_type:
@ -48,14 +72,25 @@ class CriticFactoryDefault(CriticFactory):
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
case _:
raise ValueError(f"{env_type} not supported")
return factory.create_module(envs, device, use_action)
return factory.create_module(
envs,
device,
use_action,
discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
)
class CriticFactoryContinuousNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
@ -74,7 +109,13 @@ class CriticFactoryDiscreteNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
@ -84,11 +125,50 @@ class CriticFactoryDiscreteNet(CriticFactory):
activation=nn.Tanh,
device=device,
)
critic = discrete.Critic(net_c, device=device).to(device)
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
init_linear_orthogonal(critic)
return critic
class CriticFactoryReuseActor(CriticFactory):
"""A critic factory which reuses the actor's preprocessing component.
This class is for internal use in experiment builders only.
"""
def __init__(self, actor_future: ActorFuture):
""":param actor_future: the object, which will hold the actor instance later when the critic is to be created"""
self.actor_future = actor_future
def _tostring_excludes(self) -> list[str]:
return ["actor_future"]
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
actor = self.actor_future.actor
if not isinstance(actor, BaseActor):
raise ValueError(
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete():
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
return discrete.Critic(
actor.get_preprocess_net(),
device=device,
last_size=last_size,
).to(device)
elif envs.get_type().is_continuous():
return continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
else:
raise ValueError
class CriticEnsembleFactory:
@abstractmethod
def create_module(

View File

@ -359,6 +359,20 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
return transformers
@dataclass
class DiscreteSACParams(Params, ParamsMixinActorAndDualCritics):
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
estimation_step: int = 1
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
transformers.append(ParamTransformerAutoAlpha("alpha"))
return transformers
@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
discount_factor: float = 0.99