Simplify agent factories by making better use of base classes
This commit is contained in:
parent
799beb79b4
commit
c7d0b6b4b2
@ -27,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
|
||||
NPGParams,
|
||||
Params,
|
||||
ParamsMixinActorAndDualCritics,
|
||||
ParamsMixinLearningRateWithScheduler,
|
||||
ParamTransformerData,
|
||||
PGParams,
|
||||
PPOParams,
|
||||
@ -58,6 +59,7 @@ from tianshou.utils.string import ToStringMixin
|
||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||
TParams = TypeVar("TParams", bound=Params)
|
||||
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
|
||||
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||
|
||||
@ -316,37 +318,43 @@ class PGAgentFactory(OnpolicyAgentFactory):
|
||||
|
||||
|
||||
class ActorCriticAgentFactory(
|
||||
Generic[TParams, TPolicy],
|
||||
Generic[TActorCriticParams, TPolicy],
|
||||
OnpolicyAgentFactory,
|
||||
_ActorCriticMixin,
|
||||
ABC,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
params: TParams,
|
||||
params: TActorCriticParams,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
policy_class: type[TPolicy],
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory=optimizer_factory)
|
||||
_ActorCriticMixin.__init__(
|
||||
self,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optimizer_factory,
|
||||
critic_use_action=False,
|
||||
)
|
||||
self.params = params
|
||||
self.policy_class = policy_class
|
||||
self.actor_factory = actor_factory
|
||||
self.critic_factory = critic_factory
|
||||
self.optim_factory = optimizer_factory
|
||||
self.critic_use_action = False
|
||||
|
||||
@abstractmethod
|
||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||
def _get_policy_class(self) -> type[TPolicy]:
|
||||
pass
|
||||
|
||||
def create_actor_critic_module_opt(
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
lr: float,
|
||||
) -> ActorCriticModuleOpt:
|
||||
actor = self.actor_factory.create_module(envs, device)
|
||||
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
||||
return ActorCriticModuleOpt(actor_critic, optim)
|
||||
|
||||
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
|
||||
actor_critic = self._create_actor_critic(envs, device)
|
||||
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
@ -362,95 +370,28 @@ class ActorCriticAgentFactory(
|
||||
return kwargs
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||
return self.policy_class(**self._create_kwargs(envs, device))
|
||||
policy_class = self._get_policy_class()
|
||||
return policy_class(**self._create_kwargs(envs, device))
|
||||
|
||||
|
||||
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
||||
def __init__(
|
||||
self,
|
||||
params: A2CParams,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optimizer_factory,
|
||||
A2CPolicy,
|
||||
)
|
||||
|
||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
def _get_policy_class(self) -> type[A2CPolicy]:
|
||||
return A2CPolicy
|
||||
|
||||
|
||||
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
||||
def __init__(
|
||||
self,
|
||||
params: PPOParams,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optimizer_factory,
|
||||
PPOPolicy,
|
||||
)
|
||||
|
||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
def _get_policy_class(self) -> type[PPOPolicy]:
|
||||
return PPOPolicy
|
||||
|
||||
|
||||
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
|
||||
def __init__(
|
||||
self,
|
||||
params: NPGParams,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optimizer_factory,
|
||||
NPGPolicy,
|
||||
)
|
||||
|
||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
def _get_policy_class(self) -> type[NPGPolicy]:
|
||||
return NPGPolicy
|
||||
|
||||
|
||||
class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
||||
def __init__(
|
||||
self,
|
||||
params: TRPOParams,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optimizer_factory,
|
||||
TRPOPolicy,
|
||||
)
|
||||
|
||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
def _get_policy_class(self) -> type[TRPOPolicy]:
|
||||
return TRPOPolicy
|
||||
|
||||
|
||||
class DQNAgentFactory(OffpolicyAgentFactory):
|
||||
@ -590,7 +531,9 @@ class REDQAgentFactory(OffpolicyAgentFactory):
|
||||
|
||||
|
||||
class ActorDualCriticsAgentFactory(
|
||||
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
|
||||
OffpolicyAgentFactory,
|
||||
Generic[TActorDualCriticsParams, TPolicy],
|
||||
ABC,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@ -612,9 +555,8 @@ class ActorDualCriticsAgentFactory(
|
||||
def _get_policy_class(self) -> type[TPolicy]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
@ -664,131 +606,16 @@ class ActorDualCriticsAgentFactory(
|
||||
)
|
||||
|
||||
|
||||
class SACAgentFactory(OffpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: SACParams,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic1_factory: CriticFactory,
|
||||
critic2_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
self.params = params
|
||||
self.actor_factory = actor_factory
|
||||
self.critic1_factory = critic1_factory
|
||||
self.critic2_factory = critic2_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
self.optim_factory,
|
||||
self.params.actor_lr,
|
||||
)
|
||||
critic1 = self.critic1_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic1_lr,
|
||||
)
|
||||
critic2 = self.critic2_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic2_lr,
|
||||
)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
device=device,
|
||||
optim_factory=self.optim_factory,
|
||||
actor=actor,
|
||||
critic1=critic1,
|
||||
critic2=critic2,
|
||||
),
|
||||
)
|
||||
return SACPolicy(
|
||||
actor=actor.module,
|
||||
actor_optim=actor.optim,
|
||||
critic=critic1.module,
|
||||
critic_optim=critic1.optim,
|
||||
critic2=critic2.module,
|
||||
critic2_optim=critic2.optim,
|
||||
action_space=envs.get_action_space(),
|
||||
observation_space=envs.get_observation_space(),
|
||||
**kwargs,
|
||||
)
|
||||
class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]):
|
||||
def _get_policy_class(self) -> type[SACPolicy]:
|
||||
return SACPolicy
|
||||
|
||||
|
||||
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
|
||||
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
||||
return True
|
||||
|
||||
def _get_policy_class(self) -> type[TPolicy]:
|
||||
def _get_policy_class(self) -> type[DiscreteSACPolicy]:
|
||||
return DiscreteSACPolicy
|
||||
|
||||
|
||||
class TD3AgentFactory(OffpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: TD3Params,
|
||||
sampling_config: SamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic1_factory: CriticFactory,
|
||||
critic2_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
self.params = params
|
||||
self.actor_factory = actor_factory
|
||||
self.critic1_factory = critic1_factory
|
||||
self.critic2_factory = critic2_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
self.optim_factory,
|
||||
self.params.actor_lr,
|
||||
)
|
||||
critic1 = self.critic1_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic1_lr,
|
||||
)
|
||||
critic2 = self.critic2_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic2_lr,
|
||||
)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
device=device,
|
||||
optim_factory=self.optim_factory,
|
||||
actor=actor,
|
||||
critic1=critic1,
|
||||
critic2=critic2,
|
||||
),
|
||||
)
|
||||
return TD3Policy(
|
||||
actor=actor.module,
|
||||
actor_optim=actor.optim,
|
||||
critic=critic1.module,
|
||||
critic_optim=critic1.optim,
|
||||
critic2=critic2.module,
|
||||
critic2_optim=critic2.optim,
|
||||
action_space=envs.get_action_space(),
|
||||
observation_space=envs.get_observation_space(),
|
||||
**kwargs,
|
||||
)
|
||||
class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]):
|
||||
def _get_policy_class(self) -> type[TD3Policy]:
|
||||
return TD3Policy
|
||||
|
Loading…
x
Reference in New Issue
Block a user