Simplify agent factories by making better use of base classes

This commit is contained in:
Dominik Jain 2023-10-10 20:07:30 +02:00
parent 799beb79b4
commit c7d0b6b4b2

View File

@ -27,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
NPGParams,
Params,
ParamsMixinActorAndDualCritics,
ParamsMixinLearningRateWithScheduler,
ParamTransformerData,
PGParams,
PPOParams,
@ -58,6 +59,7 @@ from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -316,37 +318,43 @@ class PGAgentFactory(OnpolicyAgentFactory):
class ActorCriticAgentFactory(
Generic[TParams, TPolicy],
Generic[TActorCriticParams, TPolicy],
OnpolicyAgentFactory,
_ActorCriticMixin,
ABC,
):
def __init__(
self,
params: TParams,
params: TActorCriticParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
policy_class: type[TPolicy],
):
super().__init__(sampling_config, optim_factory=optimizer_factory)
_ActorCriticMixin.__init__(
self,
actor_factory,
critic_factory,
optimizer_factory,
critic_use_action=False,
)
self.params = params
self.policy_class = policy_class
self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optimizer_factory
self.critic_use_action = False
@abstractmethod
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
def _get_policy_class(self) -> type[TPolicy]:
pass
def create_actor_critic_module_opt(
self,
envs: Environments,
device: TDevice,
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
actor_critic = self._create_actor_critic(envs, device)
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
@ -362,95 +370,28 @@ class ActorCriticAgentFactory(
return kwargs
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
return self.policy_class(**self._create_kwargs(envs, device))
policy_class = self._get_policy_class()
return policy_class(**self._create_kwargs(envs, device))
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
def __init__(
self,
params: A2CParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
A2CPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
def _get_policy_class(self) -> type[A2CPolicy]:
return A2CPolicy
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
def __init__(
self,
params: PPOParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
PPOPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
def _get_policy_class(self) -> type[PPOPolicy]:
return PPOPolicy
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
def __init__(
self,
params: NPGParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
NPGPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
def _get_policy_class(self) -> type[NPGPolicy]:
return NPGPolicy
class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
def __init__(
self,
params: TRPOParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
):
super().__init__(
params,
sampling_config,
actor_factory,
critic_factory,
optimizer_factory,
TRPOPolicy,
)
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
def _get_policy_class(self) -> type[TRPOPolicy]:
return TRPOPolicy
class DQNAgentFactory(OffpolicyAgentFactory):
@ -590,7 +531,9 @@ class REDQAgentFactory(OffpolicyAgentFactory):
class ActorDualCriticsAgentFactory(
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
OffpolicyAgentFactory,
Generic[TActorDualCriticsParams, TPolicy],
ABC,
):
def __init__(
self,
@ -612,9 +555,8 @@ class ActorDualCriticsAgentFactory(
def _get_policy_class(self) -> type[TPolicy]:
pass
@abstractmethod
def _get_discrete_last_size_use_action_shape(self) -> bool:
pass
return True
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt(
@ -664,131 +606,16 @@ class ActorDualCriticsAgentFactory(
)
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
params: SACParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim_factory=self.optim_factory,
actor=actor,
critic1=critic1,
critic2=critic2,
),
)
return SACPolicy(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]):
def _get_policy_class(self) -> type[SACPolicy]:
return SACPolicy
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
def _get_discrete_last_size_use_action_shape(self) -> bool:
return True
def _get_policy_class(self) -> type[TPolicy]:
def _get_policy_class(self) -> type[DiscreteSACPolicy]:
return DiscreteSACPolicy
class TD3AgentFactory(OffpolicyAgentFactory):
def __init__(
self,
params: TD3Params,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim_factory=self.optim_factory,
actor=actor,
critic1=critic1,
critic2=critic2,
),
)
return TD3Policy(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]):
def _get_policy_class(self) -> type[TD3Policy]:
return TD3Policy