Support discrete SAC in high-level API
* Changed the mechanism for reusing the actor's preprocessing module in critics, avoiding special handling in AgentFactory implementations and improving separation of concerns:
  - Added CriticFactoryReuseActor as the new critic factory
  - Added ActorFactoryTransientStorageDecorator to pass on the actor data
  - Added helper classes ActorFuture, ActorFutureProviderProtocol
  (a sketch of how these pieces interact follows below)
* Added example atari_sac_hl
parent 305b30a6c1
commit 799beb79b4
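The heart of the change is a small "future" object that decouples actor creation from critic creation: the builder wraps whichever actor factory the user configures in ActorFactoryTransientStorageDecorator, which records the created actor in an ActorFuture, and CriticFactoryReuseActor later reads that future to build its critic on top of the actor's preprocessing network. The following is a minimal sketch of this pattern only, using simplified stand-in classes (plain torch modules and a callable in place of the real tianshou factory interfaces); SharedTorsoActor and its torso/head split are illustrative and not part of the commit.

from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ActorFuture:
    """Container that will hold the actor once it has been created."""

    actor: nn.Module | None = None


class ActorFactoryTransientStorageDecorator:
    """Wraps an actor factory and records the created actor in the future."""

    def __init__(self, actor_factory: Callable[[], nn.Module], actor_future: ActorFuture):
        self.actor_factory = actor_factory
        self.actor_future = actor_future

    def create_module(self) -> nn.Module:
        actor = self.actor_factory()
        self.actor_future.actor = actor  # make the actor visible to the critic factory
        return actor


class CriticFactoryReuseActor:
    """Builds a critic head on top of the previously created actor's preprocessing net."""

    def __init__(self, actor_future: ActorFuture):
        self.actor_future = actor_future

    def create_module(self, num_outputs: int) -> nn.Module:
        torso = self.actor_future.actor.torso  # reuse (share) the actor's torso
        return nn.Sequential(torso, nn.Linear(64, num_outputs))


class SharedTorsoActor(nn.Module):
    """Illustrative actor with a separable preprocessing torso."""

    def __init__(self, obs_dim: int, num_actions: int):
        super().__init__()
        self.torso = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())
        self.head = nn.Linear(64, num_actions)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.head(self.torso(obs))


future = ActorFuture()
actor_factory = ActorFactoryTransientStorageDecorator(lambda: SharedTorsoActor(8, 4), future)
critic_factory = CriticFactoryReuseActor(future)

actor = actor_factory.create_module()     # fills the future as a side effect
critic = critic_factory.create_module(4)  # reads the future and reuses the torso
obs = torch.randn(2, 8)
print(actor(obs).shape, critic(obs).shape)  # torch.Size([2, 4]) torch.Size([2, 4])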
examples/atari/atari_sac_hl.py (new file, 105 lines)
@@ -0,0 +1,105 @@
#!/usr/bin/env python3

import os

from jsonargparse import CLI

from examples.atari.atari_network import (
    ActorFactoryAtariDQN,
    FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    DiscreteSACExperimentBuilder,
    ExperimentConfig,
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import DiscreteSACParams
from tianshou.highlevel.params.policy_wrapper import (
    PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: bool = False,
    buffer_size: int = 100000,
    actor_lr: float = 1e-5,
    critic_lr: float = 1e-5,
    gamma: float = 0.99,
    n_step: int = 3,
    tau: float = 0.005,
    alpha: float = 0.05,
    auto_alpha: bool = False,
    alpha_lr: float = 3e-4,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 64,
    hidden_size: int = 512,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO add support in high-level API?
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
):
    log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        update_per_step=update_per_step,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)

    builder = (
        DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_sac_params(
            DiscreteSACParams(
                actor_lr=actor_lr,
                critic1_lr=critic_lr,
                critic2_lr=critic_lr,
                gamma=gamma,
                tau=tau,
                alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
                estimation_step=n_step,
            ),
        )
        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
        .with_common_critic_factory_use_actor()
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
                FeatureNetFactoryDQN(),
                [hidden_size],
                actor_lr,
                icm_lr_scale,
                icm_reward_scale,
                icm_forward_loss_weight,
            ),
        )
    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))
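Note that the example's entry point is wrapped with jsonargparse's CLI, so every parameter of main (task, epoch, auto_alpha, the ICM settings, and so on) becomes a command-line option. For programmatic use, main can presumably also be called directly; a minimal sketch, assuming ExperimentConfig is a dataclass that can be instantiated with its defaults (as the CLI wrapper implies):

from examples.atari.atari_sac_hl import main
from tianshou.highlevel.experiment import ExperimentConfig

# Hypothetical short smoke-test run of the discrete SAC example.
main(ExperimentConfig(), task="PongNoFrameskip-v4", epoch=1, step_per_epoch=1000)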
tianshou/highlevel/agent.py

@@ -22,9 +22,11 @@ from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
    A2CParams,
    DDPGParams,
    DiscreteSACParams,
    DQNParams,
    NPGParams,
    Params,
    ParamsMixinActorAndDualCritics,
    ParamTransformerData,
    PGParams,
    PPOParams,
@@ -39,6 +41,7 @@ from tianshou.policy import (
    A2CPolicy,
    BasePolicy,
    DDPGPolicy,
    DiscreteSACPolicy,
    DQNPolicy,
    NPGPolicy,
    PGPolicy,
@@ -49,13 +52,13 @@ from tianshou.policy import (
    TRPOPolicy,
)
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic, BaseActor
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin

CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -247,7 +250,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
        )


class _ActorCriticMixin:
class _ActorCriticMixin:  # TODO merge
    """Mixin for agents that use an ActorCritic module with a single optimizer."""

    def __init__(
@@ -256,13 +259,11 @@ class _ActorCriticMixin:
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
        critic_use_actor_module: bool,
    ):
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.critic_use_action = critic_use_action
        self.critic_use_actor_module = critic_use_actor_module

    def create_actor_critic_module_opt(
        self,
@@ -271,28 +272,7 @@ class _ActorCriticMixin:
        lr: float,
    ) -> ActorCriticModuleOpt:
        actor = self.actor_factory.create_module(envs, device)
        critic: torch.nn.Module
        if self.critic_use_actor_module:
            if self.critic_use_action:
                raise ValueError(
                    "The options critic_use_actor_module and critic_use_action are mutually exclusive",
                )
            if not isinstance(actor, BaseActor):
                raise ValueError(
                    f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
                )
            if envs.get_type().is_discrete():
                critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
            elif envs.get_type().is_continuous():
                critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
            else:
                raise ValueError
        else:
            critic = self.critic_factory.create_module(
                envs,
                device,
                use_action=self.critic_use_action,
            )
        critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optim_factory.create_optimizer(actor_critic, lr)
        return ActorCriticModuleOpt(actor_critic, optim)
@@ -349,7 +329,6 @@ class ActorCriticAgentFactory(
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        policy_class: type[TPolicy],
        critic_use_actor_module: bool,
    ):
        super().__init__(sampling_config, optim_factory=optimizer_factory)
        _ActorCriticMixin.__init__(
@@ -358,7 +337,6 @@ class ActorCriticAgentFactory(
            critic_factory,
            optimizer_factory,
            critic_use_action=False,
            critic_use_actor_module=critic_use_actor_module,
        )
        self.params = params
        self.policy_class = policy_class
@@ -395,7 +373,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params,
@@ -404,7 +381,6 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
            critic_factory,
            optimizer_factory,
            A2CPolicy,
            critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -419,7 +395,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params,
@@ -428,7 +403,6 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
            critic_factory,
            optimizer_factory,
            PPOPolicy,
            critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -443,7 +417,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params,
@@ -452,7 +425,6 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
            critic_factory,
            optimizer_factory,
            NPGPolicy,
            critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -467,7 +439,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params,
@@ -476,7 +447,6 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
            critic_factory,
            optimizer_factory,
            TRPOPolicy,
            critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
@@ -619,6 +589,81 @@ class REDQAgentFactory(OffpolicyAgentFactory):
        )


class ActorDualCriticsAgentFactory(
    OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
):
    def __init__(
        self,
        params: TActorDualCriticsParams,
        sampling_config: SamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config, optim_factory)
        self.params = params
        self.actor_factory = actor_factory
        self.critic1_factory = critic1_factory
        self.critic2_factory = critic2_factory
        self.optim_factory = optim_factory

    @abstractmethod
    def _get_policy_class(self) -> type[TPolicy]:
        pass

    @abstractmethod
    def _get_discrete_last_size_use_action_shape(self) -> bool:
        pass

    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
        actor = self.actor_factory.create_module_opt(
            envs,
            device,
            self.optim_factory,
            self.params.actor_lr,
        )
        use_action_shape = self._get_discrete_last_size_use_action_shape()
        critic1 = self.critic1_factory.create_module_opt(
            envs,
            device,
            True,
            self.optim_factory,
            self.params.critic1_lr,
            discrete_last_size_use_action_shape=use_action_shape,
        )
        critic2 = self.critic2_factory.create_module_opt(
            envs,
            device,
            True,
            self.optim_factory,
            self.params.critic2_lr,
            discrete_last_size_use_action_shape=use_action_shape,
        )
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic1,
                critic2=critic2,
            ),
        )
        policy_class = self._get_policy_class()
        return policy_class(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )


class SACAgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,
@@ -680,6 +725,14 @@ class SACAgentFactory(OffpolicyAgentFactory):
        )


class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
    def _get_discrete_last_size_use_action_shape(self) -> bool:
        return True

    def _get_policy_class(self) -> type[TPolicy]:
        return DiscreteSACPolicy


class TD3AgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,
tianshou/highlevel/experiment.py

@@ -13,6 +13,7 @@ from tianshou.highlevel.agent import (
    A2CAgentFactory,
    AgentFactory,
    DDPGAgentFactory,
    DiscreteSACAgentFactory,
    DQNAgentFactory,
    NPGAgentFactory,
    PGAgentFactory,
@@ -28,6 +29,9 @@ from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module.actor import (
    ActorFactory,
    ActorFactoryDefault,
    ActorFactoryTransientStorageDecorator,
    ActorFuture,
    ActorFutureProviderProtocol,
    ContinuousActorType,
)
from tianshou.highlevel.module.critic import (
@@ -35,11 +39,13 @@ from tianshou.highlevel.module.critic import (
    CriticEnsembleFactoryDefault,
    CriticFactory,
    CriticFactoryDefault,
    CriticFactoryReuseActor,
)
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
    A2CParams,
    DDPGParams,
    DiscreteSACParams,
    DQNParams,
    NPGParams,
    PGParams,
@@ -263,9 +269,10 @@ class ExperimentBuilder:
        return experiment


class _BuilderMixinActorFactory:
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
    def __init__(self, continuous_actor_type: ContinuousActorType):
        self._continuous_actor_type = continuous_actor_type
        self._actor_future = ActorFuture()
        self._actor_factory: ActorFactory | None = None

    def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
@@ -286,11 +293,16 @@ class _BuilderMixinActorFactory:
        )
        return self

    def get_actor_future(self) -> ActorFuture:
        return self._actor_future

    def _get_actor_factory(self) -> ActorFactory:
        actor_factory: ActorFactory
        if self._actor_factory is None:
            return ActorFactoryDefault(self._continuous_actor_type)
            actor_factory = ActorFactoryDefault(self._continuous_actor_type)
        else:
            return self._actor_factory
            actor_factory = self._actor_factory
        return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future)


class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
@@ -325,7 +337,8 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor


class _BuilderMixinCriticsFactory:
    def __init__(self, num_critics: int):
    def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
        self._actor_future_provider = actor_future_provider
        self._critic_factories: list[CriticFactory | None] = [None] * num_critics

    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
@@ -336,6 +349,12 @@ class _BuilderMixinCriticsFactory:
        self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
        return self

    def _with_critic_factory_use_actor(self, idx: int) -> Self:
        self._critic_factories[idx] = CriticFactoryReuseActor(
            self._actor_future_provider.get_actor_future(),
        )
        return self

    def _get_critic_factory(self, idx: int) -> CriticFactory:
        factory = self._critic_factories[idx]
        if factory is None:
@@ -345,8 +364,8 @@ class _BuilderMixinCriticsFactory:


class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
    def __init__(self) -> None:
        super().__init__(1)
    def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
        super().__init__(1, actor_future_provider)

    def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
        self._with_critic_factory(0, critic_factory)
@@ -361,19 +380,17 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):


class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
    def __init__(self) -> None:
        super().__init__()
        self._critic_use_actor_module = False
    def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
        super().__init__(actor_future_provider)

    def with_critic_factory_use_actor(self) -> Self:
        """Makes the critic use the same network as the actor."""
        self._critic_use_actor_module = True
        return self
        return self._with_critic_factory_use_actor(0)


class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
    def __init__(self) -> None:
        super().__init__(2)
    def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
        super().__init__(2, actor_future_provider)

    def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
        for i in range(len(self._critic_factories)):
@@ -388,6 +405,12 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
            self._with_critic_factory_default(i, hidden_sizes)
        return self

    def with_common_critic_factory_use_actor(self) -> Self:
        """Makes all critics use the same network as the actor."""
        for i in range(len(self._critic_factories)):
            self._with_critic_factory_use_actor(i)
        return self

    def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
        self._with_critic_factory(0, critic_factory)
        return self
@@ -399,6 +422,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
        self._with_critic_factory_default(0, hidden_sizes)
        return self

    def with_critic1_factory_use_actor(self) -> Self:
        """Makes the critic use the same network as the actor."""
        return self._with_critic_factory_use_actor(0)

    def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
        self._with_critic_factory(1, critic_factory)
        return self
@@ -410,6 +437,10 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
        self._with_critic_factory_default(0, hidden_sizes)
        return self

    def with_critic2_factory_use_actor(self) -> Self:
        """Makes the second critic use the same network as the actor."""
        return self._with_critic_factory_use_actor(1)


class _BuilderMixinCriticEnsembleFactory:
    def __init__(self) -> None:
@@ -475,7 +506,7 @@ class A2CExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: A2CParams = A2CParams()
        self._env_config = None

@@ -483,7 +514,6 @@ class A2CExperimentBuilder(
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return A2CAgentFactory(
            self._params,
@@ -491,7 +521,6 @@ class A2CExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._critic_use_actor_module,
        )


@@ -508,14 +537,13 @@ class PPOExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: PPOParams = PPOParams()

    def with_ppo_params(self, params: PPOParams) -> Self:
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return PPOAgentFactory(
            self._params,
@@ -523,7 +551,6 @@ class PPOExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._critic_use_actor_module,
        )


@@ -540,14 +567,13 @@ class NPGExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: NPGParams = NPGParams()

    def with_npg_params(self, params: NPGParams) -> Self:
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return NPGAgentFactory(
            self._params,
@@ -555,7 +581,6 @@ class NPGExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._critic_use_actor_module,
        )


@@ -572,14 +597,13 @@ class TRPOExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: TRPOParams = TRPOParams()

    def with_trpo_params(self, params: TRPOParams) -> Self:
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return TRPOAgentFactory(
            self._params,
@@ -587,7 +611,6 @@ class TRPOExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._critic_use_actor_module,
        )


@@ -609,7 +632,6 @@ class DQNExperimentBuilder(
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return DQNAgentFactory(
            self._params,
@@ -632,14 +654,13 @@ class DDPGExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self)
        self._params: DDPGParams = DDPGParams()

    def with_ddpg_params(self, params: DDPGParams) -> Self:
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return DDPGAgentFactory(
            self._params,
@@ -670,7 +691,6 @@ class REDQExperimentBuilder(
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return REDQAgentFactory(
            self._params,
@@ -694,7 +714,7 @@ class SACExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self, self)
        self._params: SACParams = SACParams()

    def with_sac_params(self, params: SACParams) -> Self:
@@ -712,6 +732,37 @@ class SACExperimentBuilder(
        )


class DiscreteSACExperimentBuilder(
    ExperimentBuilder,
    _BuilderMixinActorFactory,
    _BuilderMixinDualCriticFactory,
):
    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        sampling_config: SamplingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
        _BuilderMixinDualCriticFactory.__init__(self, self)
        self._params: DiscreteSACParams = DiscreteSACParams()

    def with_sac_params(self, params: DiscreteSACParams) -> Self:
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        return DiscreteSACAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_critic_factory(1),
            self._get_optim_factory(),
        )


class TD3ExperimentBuilder(
    ExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousDeterministic,
@@ -725,7 +776,7 @@ class TD3ExperimentBuilder(
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self, self)
        self._params: TD3Params = TD3Params()

    def with_td3_params(self, params: TD3Params) -> Self:
tianshou/highlevel/module/actor.py

@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Protocol

import torch
from torch import nn
@@ -20,6 +22,18 @@ class ContinuousActorType(Enum):
    UNSUPPORTED = "unsupported"


@dataclass
class ActorFuture:
    """Container, which, in the future, will hold an actor instance."""

    actor: BaseActor | nn.Module | None = None


class ActorFutureProviderProtocol(Protocol):
    def get_actor_future(self) -> ActorFuture:
        pass


class ActorFactory(ToStringMixin, ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
@@ -175,3 +189,26 @@ class ActorFactoryDiscreteNet(ActorFactory):
            hidden_sizes=(),
            device=device,
        ).to(device)


class ActorFactoryTransientStorageDecorator(ActorFactory):
    def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
        self.actor_factory = actor_factory
        self._actor_future = actor_future

    def __getstate__(self):
        d = dict(self.__dict__)
        del d["_actor_future"]
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        self._actor_future = ActorFuture()

    def _tostring_excludes(self):
        return [*super()._tostring_excludes(), "_actor_future"]

    def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
        module = self.actor_factory.create_module(envs, device)
        self._actor_future.actor = module
        return module
tianshou/highlevel/module/critic.py

@@ -4,18 +4,30 @@ from collections.abc import Sequence
from torch import nn

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.actor import ActorFuture
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net
from tianshou.utils.string import ToStringMixin


class CriticFactory(ToStringMixin, ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        pass
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
        discrete_last_size_use_action_shape: bool = False,
    ) -> nn.Module:
        """:param envs: the environments
        :param device: the torch device
        :param use_action: whether to (additionally) expect the action as input
        :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
        :return: the module
        """

    def create_module_opt(
        self,
@@ -24,8 +36,14 @@ class CriticFactory(ToStringMixin, ABC):
        use_action: bool,
        optim_factory: OptimizerFactory,
        lr: float,
        discrete_last_size_use_action_shape: bool = False,
    ) -> ModuleOpt:
        module = self.create_module(envs, device, use_action)
        module = self.create_module(
            envs,
            device,
            use_action,
            discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
        )
        opt = optim_factory.create_optimizer(module, lr)
        return ModuleOpt(module, opt)

@@ -38,7 +56,13 @@ class CriticFactoryDefault(CriticFactory):
    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
        discrete_last_size_use_action_shape=False,
    ) -> nn.Module:
        factory: CriticFactory
        env_type = envs.get_type()
        match env_type:
@@ -48,14 +72,25 @@ class CriticFactoryDefault(CriticFactory):
                factory = CriticFactoryDiscreteNet(self.hidden_sizes)
            case _:
                raise ValueError(f"{env_type} not supported")
        return factory.create_module(envs, device, use_action)
        return factory.create_module(
            envs,
            device,
            use_action,
            discrete_last_size_use_action_shape=discrete_last_size_use_action_shape,
        )


class CriticFactoryContinuousNet(CriticFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
        discrete_last_size_use_action_shape=False,
    ) -> nn.Module:
        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
            envs.get_observation_shape(),
@@ -74,7 +109,13 @@ class CriticFactoryDiscreteNet(CriticFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
        discrete_last_size_use_action_shape=False,
    ) -> nn.Module:
        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
            envs.get_observation_shape(),
@@ -84,11 +125,50 @@ class CriticFactoryDiscreteNet(CriticFactory):
            activation=nn.Tanh,
            device=device,
        )
        critic = discrete.Critic(net_c, device=device).to(device)
        last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
        critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
        init_linear_orthogonal(critic)
        return critic


class CriticFactoryReuseActor(CriticFactory):
    """A critic factory which reuses the actor's preprocessing component.

    This class is for internal use in experiment builders only.
    """

    def __init__(self, actor_future: ActorFuture):
        """:param actor_future: the object, which will hold the actor instance later when the critic is to be created"""
        self.actor_future = actor_future

    def _tostring_excludes(self) -> list[str]:
        return ["actor_future"]

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
        discrete_last_size_use_action_shape=False,
    ) -> nn.Module:
        actor = self.actor_future.actor
        if not isinstance(actor, BaseActor):
            raise ValueError(
                f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
            )
        if envs.get_type().is_discrete():
            last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
            return discrete.Critic(
                actor.get_preprocess_net(),
                device=device,
                last_size=last_size,
            ).to(device)
        elif envs.get_type().is_continuous():
            return continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
        else:
            raise ValueError


class CriticEnsembleFactory:
    @abstractmethod
    def create_module(
tianshou/highlevel/params/policy_params.py

@@ -359,6 +359,20 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
        return transformers


@dataclass
class DiscreteSACParams(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        return transformers


@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
    discount_factor: float = 0.99