Simplify agent factories by making better use of base classes

This commit is contained in:
Dominik Jain 2023-10-10 20:07:30 +02:00
parent 799beb79b4
commit c7d0b6b4b2

View File

@ -27,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
NPGParams, NPGParams,
Params, Params,
ParamsMixinActorAndDualCritics, ParamsMixinActorAndDualCritics,
ParamsMixinLearningRateWithScheduler,
ParamTransformerData, ParamTransformerData,
PGParams, PGParams,
PPOParams, PPOParams,
@ -58,6 +59,7 @@ from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params) TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -316,37 +318,43 @@ class PGAgentFactory(OnpolicyAgentFactory):
class ActorCriticAgentFactory( class ActorCriticAgentFactory(
Generic[TParams, TPolicy], Generic[TActorCriticParams, TPolicy],
OnpolicyAgentFactory, OnpolicyAgentFactory,
_ActorCriticMixin,
ABC, ABC,
): ):
def __init__( def __init__(
self, self,
params: TParams, params: TActorCriticParams,
sampling_config: SamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
policy_class: type[TPolicy],
): ):
super().__init__(sampling_config, optim_factory=optimizer_factory) super().__init__(sampling_config, optim_factory=optimizer_factory)
_ActorCriticMixin.__init__(
self,
actor_factory,
critic_factory,
optimizer_factory,
critic_use_action=False,
)
self.params = params self.params = params
self.policy_class = policy_class self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optimizer_factory
self.critic_use_action = False
@abstractmethod @abstractmethod
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: def _get_policy_class(self) -> type[TPolicy]:
pass pass
def create_actor_critic_module_opt(
self,
envs: Environments,
device: TDevice,
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
actor_critic = self._create_actor_critic(envs, device) actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
kwargs = self.params.create_kwargs( kwargs = self.params.create_kwargs(
ParamTransformerData( ParamTransformerData(
envs=envs, envs=envs,
@ -362,95 +370,28 @@ class ActorCriticAgentFactory(
return kwargs return kwargs
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
return self.policy_class(**self._create_kwargs(envs, device)) policy_class = self._get_policy_class()
return policy_class(**self._create_kwargs(envs, device))
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
def __init__( def _get_policy_class(self) -> type[A2CPolicy]:
self, return A2CPolicy
params: A2CParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
A2CPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
def __init__( def _get_policy_class(self) -> type[PPOPolicy]:
self, return PPOPolicy
params: PPOParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
PPOPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
def __init__( def _get_policy_class(self) -> type[NPGPolicy]:
self, return NPGPolicy
params: NPGParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
NPGPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
def __init__( def _get_policy_class(self) -> type[TRPOPolicy]:
self, return TRPOPolicy
params: TRPOParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
TRPOPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
class DQNAgentFactory(OffpolicyAgentFactory): class DQNAgentFactory(OffpolicyAgentFactory):
@ -590,7 +531,9 @@ class REDQAgentFactory(OffpolicyAgentFactory):
class ActorDualCriticsAgentFactory( class ActorDualCriticsAgentFactory(
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC, OffpolicyAgentFactory,
Generic[TActorDualCriticsParams, TPolicy],
ABC,
): ):
def __init__( def __init__(
self, self,
@ -612,9 +555,8 @@ class ActorDualCriticsAgentFactory(
def _get_policy_class(self) -> type[TPolicy]: def _get_policy_class(self) -> type[TPolicy]:
pass pass
@abstractmethod
def _get_discrete_last_size_use_action_shape(self) -> bool: def _get_discrete_last_size_use_action_shape(self) -> bool:
pass return True
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt( actor = self.actor_factory.create_module_opt(
@ -664,131 +606,16 @@ class ActorDualCriticsAgentFactory(
) )
class SACAgentFactory(OffpolicyAgentFactory): class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]):
def __init__( def _get_policy_class(self) -> type[SACPolicy]:
self, return SACPolicy
params: SACParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim_factory=self.optim_factory,
actor=actor,
critic1=critic1,
critic2=critic2,
),
)
return SACPolicy(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]): class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
def _get_discrete_last_size_use_action_shape(self) -> bool: def _get_policy_class(self) -> type[DiscreteSACPolicy]:
return True
def _get_policy_class(self) -> type[TPolicy]:
return DiscreteSACPolicy return DiscreteSACPolicy
class TD3AgentFactory(OffpolicyAgentFactory): class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]):
def __init__( def _get_policy_class(self) -> type[TD3Policy]:
self, return TD3Policy
params: TD3Params,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim_factory=self.optim_factory,
actor=actor,
critic1=critic1,
critic2=critic2,
),
)
return TD3Policy(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)