Support discrete SAC in high-level API

* Changed machanism for reusing actor's preprocessing module in critics
  to avoid special handling in AgentFactory implementations, improving
  separation of concerns:
    - Added CriticFactoryReuseActor as the new critic factory
    - Added ActorFactoryTransientStorageDecorator to pass on the actor
      data
    - Added helper classes ActorFuture, ActorFutureProviderProtocol
* Add example atari_sac_hl
This commit is contained in:
Dominik Jain 2023-10-10 19:11:49 +02:00
parent 305b30a6c1
commit 799beb79b4
6 changed files with 417 additions and 77 deletions

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3
import os
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
DiscreteSACExperimentBuilder,
ExperimentConfig,
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import DiscreteSACParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: bool = False,
buffer_size: int = 100000,
actor_lr: float = 1e-5,
critic_lr: float = 1e-5,
gamma: float = 0.99,
n_step: int = 3,
tau: float = 0.005,
alpha: float = 0.05,
auto_alpha: bool = False,
alpha_lr: float = 3e-4,
epoch: int = 100,
step_per_epoch: int = 100000,
step_per_collect: int = 10,
update_per_step: float = 0.1,
batch_size: int = 64,
hidden_size: int = 512,
training_num: int = 10,
test_num: int = 10,
frames_stack: int = 4,
save_buffer_name: str | None = None, # TODO add support in high-level API?
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
update_per_step=update_per_step,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=None,
replay_buffer_stack_num=frames_stack,
replay_buffer_ignore_obs_next=True,
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
DiscreteSACParams(
actor_lr=actor_lr,
critic1_lr=critic_lr,
critic2_lr=critic_lr,
gamma=gamma,
tau=tau,
alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
estimation_step=n_step,
),
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
.with_common_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
[hidden_size],
actor_lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
),
)
experiment = builder.build()
experiment.run(log_name)
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))

View File

@ -22,9 +22,11 @@ from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import ( from tianshou.highlevel.params.policy_params import (
A2CParams, A2CParams,
DDPGParams, DDPGParams,
DiscreteSACParams,
DQNParams, DQNParams,
NPGParams, NPGParams,
Params, Params,
ParamsMixinActorAndDualCritics,
ParamTransformerData, ParamTransformerData,
PGParams, PGParams,
PPOParams, PPOParams,
@ -39,6 +41,7 @@ from tianshou.policy import (
A2CPolicy, A2CPolicy,
BasePolicy, BasePolicy,
DDPGPolicy, DDPGPolicy,
DiscreteSACPolicy,
DQNPolicy, DQNPolicy,
NPGPolicy, NPGPolicy,
PGPolicy, PGPolicy,
@ -49,13 +52,13 @@ from tianshou.policy import (
TRPOPolicy, TRPOPolicy,
) )
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.common import ActorCritic, BaseActor
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params) TParams = TypeVar("TParams", bound=Params)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -247,7 +250,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
) )
class _ActorCriticMixin: class _ActorCriticMixin: # TODO merge
"""Mixin for agents that use an ActorCritic module with a single optimizer.""" """Mixin for agents that use an ActorCritic module with a single optimizer."""
def __init__( def __init__(
@ -256,13 +259,11 @@ class _ActorCriticMixin:
critic_factory: CriticFactory, critic_factory: CriticFactory,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
critic_use_action: bool, critic_use_action: bool,
critic_use_actor_module: bool,
): ):
self.actor_factory = actor_factory self.actor_factory = actor_factory
self.critic_factory = critic_factory self.critic_factory = critic_factory
self.optim_factory = optim_factory self.optim_factory = optim_factory
self.critic_use_action = critic_use_action self.critic_use_action = critic_use_action
self.critic_use_actor_module = critic_use_actor_module
def create_actor_critic_module_opt( def create_actor_critic_module_opt(
self, self,
@ -271,28 +272,7 @@ class _ActorCriticMixin:
lr: float, lr: float,
) -> ActorCriticModuleOpt: ) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device) actor = self.actor_factory.create_module(envs, device)
critic: torch.nn.Module critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
if self.critic_use_actor_module:
if self.critic_use_action:
raise ValueError(
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
)
if not isinstance(actor, BaseActor):
raise ValueError(
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete():
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
elif envs.get_type().is_continuous():
critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
else:
raise ValueError
else:
critic = self.critic_factory.create_module(
envs,
device,
use_action=self.critic_use_action,
)
actor_critic = ActorCritic(actor, critic) actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr) optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim) return ActorCriticModuleOpt(actor_critic, optim)
@ -349,7 +329,6 @@ class ActorCriticAgentFactory(
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
policy_class: type[TPolicy], policy_class: type[TPolicy],
critic_use_actor_module: bool,
): ):
super().__init__(sampling_config, optim_factory=optimizer_factory) super().__init__(sampling_config, optim_factory=optimizer_factory)
_ActorCriticMixin.__init__( _ActorCriticMixin.__init__(
@ -358,7 +337,6 @@ class ActorCriticAgentFactory(
critic_factory, critic_factory,
optimizer_factory, optimizer_factory,
critic_use_action=False, critic_use_action=False,
critic_use_actor_module=critic_use_actor_module,
) )
self.params = params self.params = params
self.policy_class = policy_class self.policy_class = policy_class
@ -395,7 +373,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
): ):
super().__init__( super().__init__(
params, params,
@ -404,7 +381,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
critic_factory, critic_factory,
optimizer_factory, optimizer_factory,
A2CPolicy, A2CPolicy,
critic_use_actor_module,
) )
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -419,7 +395,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
): ):
super().__init__( super().__init__(
params, params,
@ -428,7 +403,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
critic_factory, critic_factory,
optimizer_factory, optimizer_factory,
PPOPolicy, PPOPolicy,
critic_use_actor_module,
) )
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -443,7 +417,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
): ):
super().__init__( super().__init__(
params, params,
@ -452,7 +425,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
critic_factory, critic_factory,
optimizer_factory, optimizer_factory,
NPGPolicy, NPGPolicy,
critic_use_actor_module,
) )
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -467,7 +439,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
critic_use_actor_module: bool,
): ):
super().__init__( super().__init__(
params, params,
@ -476,7 +447,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
critic_factory, critic_factory,
optimizer_factory, optimizer_factory,
TRPOPolicy, TRPOPolicy,
critic_use_actor_module,
) )
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@ -619,6 +589,81 @@ class REDQAgentFactory(OffpolicyAgentFactory):
) )
class ActorDualCriticsAgentFactory(
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
):
def __init__(
self,
params: TActorDualCriticsParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
@abstractmethod
def _get_policy_class(self) -> type[TPolicy]:
pass
@abstractmethod
def _get_discrete_last_size_use_action_shape(self) -> bool:
pass
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
use_action_shape = self._get_discrete_last_size_use_action_shape()
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
discrete_last_size_use_action_shape=use_action_shape,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
discrete_last_size_use_action_shape=use_action_shape,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim_factory=self.optim_factory,
actor=actor,
critic1=critic1,
critic2=critic2,
),
)
policy_class = self._get_policy_class()
return policy_class(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
class SACAgentFactory(OffpolicyAgentFactory): class SACAgentFactory(OffpolicyAgentFactory):
def __init__( def __init__(
self, self,
@ -680,6 +725,14 @@ class SACAgentFactory(OffpolicyAgentFactory):
) )
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
def _get_discrete_last_size_use_action_shape(self) -> bool:
return True
def _get_policy_class(self) -> type[TPolicy]:
return DiscreteSACPolicy
class TD3AgentFactory(OffpolicyAgentFactory): class TD3AgentFactory(OffpolicyAgentFactory):
def __init__( def __init__(
self, self,

View File

@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
A2CAgentFactory, A2CAgentFactory,
AgentFactory, AgentFactory,
DDPGAgentFactory, DDPGAgentFactory,
DiscreteSACAgentFactory,
DQNAgentFactory, DQNAgentFactory,
NPGAgentFactory, NPGAgentFactory,
PGAgentFactory, PGAgentFactory,
@ -28,6 +29,9 @@ from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module.actor import ( from tianshou.highlevel.module.actor import (
ActorFactory, ActorFactory,
ActorFactoryDefault, ActorFactoryDefault,
ActorFactoryTransientStorageDecorator,
ActorFuture,
ActorFutureProviderProtocol,
ContinuousActorType, ContinuousActorType,
) )
from tianshou.highlevel.module.critic import ( from tianshou.highlevel.module.critic import (
@ -35,11 +39,13 @@ from tianshou.highlevel.module.critic import (
CriticEnsembleFactoryDefault, CriticEnsembleFactoryDefault,
CriticFactory, CriticFactory,
CriticFactoryDefault, CriticFactoryDefault,
CriticFactoryReuseActor,
) )
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import ( from tianshou.highlevel.params.policy_params import (
A2CParams, A2CParams,
DDPGParams, DDPGParams,
DiscreteSACParams,
DQNParams, DQNParams,
NPGParams, NPGParams,
PGParams, PGParams,
@ -263,9 +269,10 @@ class ExperimentBuilder:
return experiment return experiment
class _BuilderMixinActorFactory: class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
def __init__(self, continuous_actor_type: ContinuousActorType): def __init__(self, continuous_actor_type: ContinuousActorType):
self._continuous_actor_type = continuous_actor_type self._continuous_actor_type = continuous_actor_type
self._actor_future = ActorFuture()
self._actor_factory: ActorFactory | None = None self._actor_factory: ActorFactory | None = None
def with_actor_factory(self, actor_factory: ActorFactory) -> Self: def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
@ -286,11 +293,16 @@ class _BuilderMixinActorFactory:
) )
return self return self
def get_actor_future(self) -> ActorFuture:
return self._actor_future
def _get_actor_factory(self) -> ActorFactory: def _get_actor_factory(self) -> ActorFactory:
actor_factory: ActorFactory
if self._actor_factory is None: if self._actor_factory is None:
return ActorFactoryDefault(self._continuous_actor_type) actor_factory = ActorFactoryDefault(self._continuous_actor_type)
else: else:
return self._actor_factory actor_factory = self._actor_factory
return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future)
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
@ -325,7 +337,8 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
class _BuilderMixinCriticsFactory: class _BuilderMixinCriticsFactory:
def __init__(self, num_critics: int): def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
self._actor_future_provider = actor_future_provider
self._critic_factories: list[CriticFactory | None] = [None] * num_critics self._critic_factories: list[CriticFactory | None] = [None] * num_critics
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self: def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
@ -336,6 +349,12 @@ class _BuilderMixinCriticsFactory:
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes) self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
return self return self
def _with_critic_factory_use_actor(self, idx: int) -> Self:
self._critic_factories[idx] = CriticFactoryReuseActor(
self._actor_future_provider.get_actor_future(),
)
return self
def _get_critic_factory(self, idx: int) -> CriticFactory: def _get_critic_factory(self, idx: int) -> CriticFactory:
factory = self._critic_factories[idx] factory = self._critic_factories[idx]
if factory is None: if factory is None:
@ -345,8 +364,8 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self) -> None: def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
super().__init__(1) super().__init__(1, actor_future_provider)
def with_critic_factory(self, critic_factory: CriticFactory) -> Self: def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(0, critic_factory) self._with_critic_factory(0, critic_factory)
@ -361,19 +380,17 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory): class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
def __init__(self) -> None: def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
super().__init__() super().__init__(actor_future_provider)
self._critic_use_actor_module = False
def with_critic_factory_use_actor(self) -> Self: def with_critic_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor.""" """Makes the critic use the same network as the actor."""
self._critic_use_actor_module = True return self._with_critic_factory_use_actor(0)
return self
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self) -> None: def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
super().__init__(2) super().__init__(2, actor_future_provider)
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self: def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
for i in range(len(self._critic_factories)): for i in range(len(self._critic_factories)):
@ -388,6 +405,12 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(i, hidden_sizes) self._with_critic_factory_default(i, hidden_sizes)
return self return self
def with_common_critic_factory_use_actor(self) -> Self:
"""Makes all critics use the same network as the actor."""
for i in range(len(self._critic_factories)):
self._with_critic_factory_use_actor(i)
return self
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self: def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(0, critic_factory) self._with_critic_factory(0, critic_factory)
return self return self
@ -399,6 +422,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
return self return self
def with_critic1_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor."""
return self._with_critic_factory_use_actor(0)
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self: def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(1, critic_factory) self._with_critic_factory(1, critic_factory)
return self return self
@ -410,6 +437,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
return self return self
def with_critic2_factory_use_actor(self) -> Self:
"""Makes the second critic use the same network as the actor."""
return self._with_critic_factory_use_actor(1)
class _BuilderMixinCriticEnsembleFactory: class _BuilderMixinCriticEnsembleFactory:
def __init__(self) -> None: def __init__(self) -> None:
@ -475,7 +506,7 @@ class A2CExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: A2CParams = A2CParams() self._params: A2CParams = A2CParams()
self._env_config = None self._env_config = None
@ -483,7 +514,6 @@ class A2CExperimentBuilder(
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return A2CAgentFactory( return A2CAgentFactory(
self._params, self._params,
@ -491,7 +521,6 @@ class A2CExperimentBuilder(
self._get_actor_factory(), self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(0),
self._get_optim_factory(), self._get_optim_factory(),
self._critic_use_actor_module,
) )
@ -508,14 +537,13 @@ class PPOExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: PPOParams = PPOParams() self._params: PPOParams = PPOParams()
def with_ppo_params(self, params: PPOParams) -> Self: def with_ppo_params(self, params: PPOParams) -> Self:
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return PPOAgentFactory( return PPOAgentFactory(
self._params, self._params,
@ -523,7 +551,6 @@ class PPOExperimentBuilder(
self._get_actor_factory(), self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(0),
self._get_optim_factory(), self._get_optim_factory(),
self._critic_use_actor_module,
) )
@ -540,14 +567,13 @@ class NPGExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: NPGParams = NPGParams() self._params: NPGParams = NPGParams()
def with_npg_params(self, params: NPGParams) -> Self: def with_npg_params(self, params: NPGParams) -> Self:
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return NPGAgentFactory( return NPGAgentFactory(
self._params, self._params,
@ -555,7 +581,6 @@ class NPGExperimentBuilder(
self._get_actor_factory(), self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(0),
self._get_optim_factory(), self._get_optim_factory(),
self._critic_use_actor_module,
) )
@ -572,14 +597,13 @@ class TRPOExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: TRPOParams = TRPOParams() self._params: TRPOParams = TRPOParams()
def with_trpo_params(self, params: TRPOParams) -> Self: def with_trpo_params(self, params: TRPOParams) -> Self:
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return TRPOAgentFactory( return TRPOAgentFactory(
self._params, self._params,
@ -587,7 +611,6 @@ class TRPOExperimentBuilder(
self._get_actor_factory(), self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(0),
self._get_optim_factory(), self._get_optim_factory(),
self._critic_use_actor_module,
) )
@ -609,7 +632,6 @@ class DQNExperimentBuilder(
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return DQNAgentFactory( return DQNAgentFactory(
self._params, self._params,
@ -632,14 +654,13 @@ class DDPGExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
self._params: DDPGParams = DDPGParams() self._params: DDPGParams = DDPGParams()
def with_ddpg_params(self, params: DDPGParams) -> Self: def with_ddpg_params(self, params: DDPGParams) -> Self:
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return DDPGAgentFactory( return DDPGAgentFactory(
self._params, self._params,
@ -670,7 +691,6 @@ class REDQExperimentBuilder(
self._params = params self._params = params
return self return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return REDQAgentFactory( return REDQAgentFactory(
self._params, self._params,
@ -694,7 +714,7 @@ class SACExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self)
self._params: SACParams = SACParams() self._params: SACParams = SACParams()
def with_sac_params(self, params: SACParams) -> Self: def with_sac_params(self, params: SACParams) -> Self:
@ -712,6 +732,37 @@ class SACExperimentBuilder(
) )
class DiscreteSACExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory,
_BuilderMixinDualCriticFactory,
):
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
_BuilderMixinDualCriticFactory.__init__(self, self)
self._params: DiscreteSACParams = DiscreteSACParams()
def with_sac_params(self, params: DiscreteSACParams) -> Self:
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return DiscreteSACAgentFactory(
self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_critic_factory(1),
self._get_optim_factory(),
)
class TD3ExperimentBuilder( class TD3ExperimentBuilder(
ExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinActorFactory_ContinuousDeterministic,
@ -725,7 +776,7 @@ class TD3ExperimentBuilder(
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self)
self._params: TD3Params = TD3Params() self._params: TD3Params = TD3Params()
def with_td3_params(self, params: TD3Params) -> Self: def with_td3_params(self, params: TD3Params) -> Self:

View File

@ -1,6 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Protocol
import torch import torch
from torch import nn from torch import nn
@ -20,6 +22,18 @@ class ContinuousActorType(Enum):
UNSUPPORTED = "unsupported" UNSUPPORTED = "unsupported"
@dataclass
class ActorFuture:
"""Container, which, in the future, will hold an actor instance."""
actor: BaseActor | nn.Module | None = None
class ActorFutureProviderProtocol(Protocol):
def get_actor_future(self) -> ActorFuture:
pass
class ActorFactory(ToStringMixin, ABC): class ActorFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
@ -175,3 +189,26 @@ class ActorFactoryDiscreteNet(ActorFactory):
hidden_sizes=(), hidden_sizes=(),
device=device, device=device,
).to(device) ).to(device)
class ActorFactoryTransientStorageDecorator(ActorFactory):
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
self.actor_factory = actor_factory
self._actor_future = actor_future
def __getstate__(self):
d = dict(self.__dict__)
del d["_actor_future"]
return d
def __setstate__(self, state):
self.__dict__ = state
self._actor_future = ActorFuture()
def _tostring_excludes(self):
return [*super()._tostring_excludes(), "_actor_future"]
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
module = self.actor_factory.create_module(envs, device)
self._actor_future.actor = module
return module

View File

@ -4,18 +4,30 @@ from collections.abc import Sequence
from torch import nn from torch import nn
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.actor import ActorFuture
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
class CriticFactory(ToStringMixin, ABC): class CriticFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: def create_module(
pass self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
""":param envs: the environments
:param device: the torch device
:param use_action: whether to (additionally) expect the action as input
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
:return: the module
"""
def create_module_opt( def create_module_opt(
self, self,
@ -24,8 +36,14 @@ class CriticFactory(ToStringMixin, ABC):
use_action: bool, use_action: bool,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
lr: float, lr: float,
discrete_last_size_use_action_shape: bool = False,
) -> ModuleOpt: ) -> ModuleOpt:
module = self.create_module(envs, device, use_action) module = self.create_module(
envs,
device,
use_action,
discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
)
opt = optim_factory.create_optimizer(module, lr) opt = optim_factory.create_optimizer(module, lr)
return ModuleOpt(module, opt) return ModuleOpt(module, opt)
@ -38,7 +56,13 @@ class CriticFactoryDefault(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
factory: CriticFactory factory: CriticFactory
env_type = envs.get_type() env_type = envs.get_type()
match env_type: match env_type:
@ -48,14 +72,25 @@ class CriticFactoryDefault(CriticFactory):
factory = CriticFactoryDiscreteNet(self.hidden_sizes) factory = CriticFactoryDiscreteNet(self.hidden_sizes)
case _: case _:
raise ValueError(f"{env_type} not supported") raise ValueError(f"{env_type} not supported")
return factory.create_module(envs, device, use_action) return factory.create_module(
envs,
device,
use_action,
discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
)
class CriticFactoryContinuousNet(CriticFactory): class CriticFactoryContinuousNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0 action_shape = envs.get_action_shape() if use_action else 0
net_c = Net( net_c = Net(
envs.get_observation_shape(), envs.get_observation_shape(),
@ -74,7 +109,13 @@ class CriticFactoryDiscreteNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0 action_shape = envs.get_action_shape() if use_action else 0
net_c = Net( net_c = Net(
envs.get_observation_shape(), envs.get_observation_shape(),
@ -84,11 +125,50 @@ class CriticFactoryDiscreteNet(CriticFactory):
activation=nn.Tanh, activation=nn.Tanh,
device=device, device=device,
) )
critic = discrete.Critic(net_c, device=device).to(device) last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
init_linear_orthogonal(critic) init_linear_orthogonal(critic)
return critic return critic
class CriticFactoryReuseActor(CriticFactory):
"""A critic factory which reuses the actor's preprocessing component.
This class is for internal use in experiment builders only.
"""
def __init__(self, actor_future: ActorFuture):
""":param actor_future: the object, which will hold the actor instance later when the critic is to be created"""
self.actor_future = actor_future
def _tostring_excludes(self) -> list[str]:
return ["actor_future"]
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
) -> nn.Module:
actor = self.actor_future.actor
if not isinstance(actor, BaseActor):
raise ValueError(
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete():
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
return discrete.Critic(
actor.get_preprocess_net(),
device=device,
last_size=last_size,
).to(device)
elif envs.get_type().is_continuous():
return continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
else:
raise ValueError
class CriticEnsembleFactory: class CriticEnsembleFactory:
@abstractmethod @abstractmethod
def create_module( def create_module(

View File

@ -359,6 +359,20 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
return transformers return transformers
@dataclass
class DiscreteSACParams(Params, ParamsMixinActorAndDualCritics):
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
estimation_step: int = 1
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
transformers.append(ParamTransformerAutoAlpha("alpha"))
return transformers
@dataclass @dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler): class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
discount_factor: float = 0.99 discount_factor: float = 0.99