Simplify agent factories by making better use of base classes
This commit is contained in:
parent
799beb79b4
commit
c7d0b6b4b2
@ -27,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
NPGParams,
|
NPGParams,
|
||||||
Params,
|
Params,
|
||||||
ParamsMixinActorAndDualCritics,
|
ParamsMixinActorAndDualCritics,
|
||||||
|
ParamsMixinLearningRateWithScheduler,
|
||||||
ParamTransformerData,
|
ParamTransformerData,
|
||||||
PGParams,
|
PGParams,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
@ -58,6 +59,7 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||||
TParams = TypeVar("TParams", bound=Params)
|
TParams = TypeVar("TParams", bound=Params)
|
||||||
|
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
|
||||||
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
|
|
||||||
@ -316,37 +318,43 @@ class PGAgentFactory(OnpolicyAgentFactory):
|
|||||||
|
|
||||||
|
|
||||||
class ActorCriticAgentFactory(
|
class ActorCriticAgentFactory(
|
||||||
Generic[TParams, TPolicy],
|
Generic[TActorCriticParams, TPolicy],
|
||||||
OnpolicyAgentFactory,
|
OnpolicyAgentFactory,
|
||||||
_ActorCriticMixin,
|
|
||||||
ABC,
|
ABC,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: TParams,
|
params: TActorCriticParams,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
policy_class: type[TPolicy],
|
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory=optimizer_factory)
|
super().__init__(sampling_config, optim_factory=optimizer_factory)
|
||||||
_ActorCriticMixin.__init__(
|
|
||||||
self,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optimizer_factory,
|
|
||||||
critic_use_action=False,
|
|
||||||
)
|
|
||||||
self.params = params
|
self.params = params
|
||||||
self.policy_class = policy_class
|
self.actor_factory = actor_factory
|
||||||
|
self.critic_factory = critic_factory
|
||||||
|
self.optim_factory = optimizer_factory
|
||||||
|
self.critic_use_action = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
def _get_policy_class(self) -> type[TPolicy]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def create_actor_critic_module_opt(
|
||||||
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
lr: float,
|
||||||
|
) -> ActorCriticModuleOpt:
|
||||||
|
actor = self.actor_factory.create_module(envs, device)
|
||||||
|
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
|
||||||
|
actor_critic = ActorCritic(actor, critic)
|
||||||
|
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
||||||
|
return ActorCriticModuleOpt(actor_critic, optim)
|
||||||
|
|
||||||
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
|
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
|
||||||
actor_critic = self._create_actor_critic(envs, device)
|
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
@ -362,95 +370,28 @@ class ActorCriticAgentFactory(
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
return self.policy_class(**self._create_kwargs(envs, device))
|
policy_class = self._get_policy_class()
|
||||||
|
return policy_class(**self._create_kwargs(envs, device))
|
||||||
|
|
||||||
|
|
||||||
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
||||||
def __init__(
|
def _get_policy_class(self) -> type[A2CPolicy]:
|
||||||
self,
|
return A2CPolicy
|
||||||
params: A2CParams,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optimizer_factory: OptimizerFactory,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
sampling_config,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optimizer_factory,
|
|
||||||
A2CPolicy,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
|
||||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
|
||||||
|
|
||||||
|
|
||||||
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
||||||
def __init__(
|
def _get_policy_class(self) -> type[PPOPolicy]:
|
||||||
self,
|
return PPOPolicy
|
||||||
params: PPOParams,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optimizer_factory: OptimizerFactory,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
sampling_config,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optimizer_factory,
|
|
||||||
PPOPolicy,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
|
||||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
|
||||||
|
|
||||||
|
|
||||||
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
|
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
|
||||||
def __init__(
|
def _get_policy_class(self) -> type[NPGPolicy]:
|
||||||
self,
|
return NPGPolicy
|
||||||
params: NPGParams,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optimizer_factory: OptimizerFactory,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
sampling_config,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optimizer_factory,
|
|
||||||
NPGPolicy,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
|
||||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
|
||||||
|
|
||||||
|
|
||||||
class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
||||||
def __init__(
|
def _get_policy_class(self) -> type[TRPOPolicy]:
|
||||||
self,
|
return TRPOPolicy
|
||||||
params: TRPOParams,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optimizer_factory: OptimizerFactory,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
params,
|
|
||||||
sampling_config,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optimizer_factory,
|
|
||||||
TRPOPolicy,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
|
||||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
|
||||||
|
|
||||||
|
|
||||||
class DQNAgentFactory(OffpolicyAgentFactory):
|
class DQNAgentFactory(OffpolicyAgentFactory):
|
||||||
@ -590,7 +531,9 @@ class REDQAgentFactory(OffpolicyAgentFactory):
|
|||||||
|
|
||||||
|
|
||||||
class ActorDualCriticsAgentFactory(
|
class ActorDualCriticsAgentFactory(
|
||||||
OffpolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC,
|
OffpolicyAgentFactory,
|
||||||
|
Generic[TActorDualCriticsParams, TPolicy],
|
||||||
|
ABC,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -612,9 +555,8 @@ class ActorDualCriticsAgentFactory(
|
|||||||
def _get_policy_class(self) -> type[TPolicy]:
|
def _get_policy_class(self) -> type[TPolicy]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
||||||
pass
|
return True
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
actor = self.actor_factory.create_module_opt(
|
actor = self.actor_factory.create_module_opt(
|
||||||
@ -664,131 +606,16 @@ class ActorDualCriticsAgentFactory(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory):
|
class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]):
|
||||||
def __init__(
|
def _get_policy_class(self) -> type[SACPolicy]:
|
||||||
self,
|
return SACPolicy
|
||||||
params: SACParams,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic1_factory: CriticFactory,
|
|
||||||
critic2_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
):
|
|
||||||
super().__init__(sampling_config, optim_factory)
|
|
||||||
self.params = params
|
|
||||||
self.actor_factory = actor_factory
|
|
||||||
self.critic1_factory = critic1_factory
|
|
||||||
self.critic2_factory = critic2_factory
|
|
||||||
self.optim_factory = optim_factory
|
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
|
||||||
actor = self.actor_factory.create_module_opt(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
self.optim_factory,
|
|
||||||
self.params.actor_lr,
|
|
||||||
)
|
|
||||||
critic1 = self.critic1_factory.create_module_opt(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
True,
|
|
||||||
self.optim_factory,
|
|
||||||
self.params.critic1_lr,
|
|
||||||
)
|
|
||||||
critic2 = self.critic2_factory.create_module_opt(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
True,
|
|
||||||
self.optim_factory,
|
|
||||||
self.params.critic2_lr,
|
|
||||||
)
|
|
||||||
kwargs = self.params.create_kwargs(
|
|
||||||
ParamTransformerData(
|
|
||||||
envs=envs,
|
|
||||||
device=device,
|
|
||||||
optim_factory=self.optim_factory,
|
|
||||||
actor=actor,
|
|
||||||
critic1=critic1,
|
|
||||||
critic2=critic2,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return SACPolicy(
|
|
||||||
actor=actor.module,
|
|
||||||
actor_optim=actor.optim,
|
|
||||||
critic=critic1.module,
|
|
||||||
critic_optim=critic1.optim,
|
|
||||||
critic2=critic2.module,
|
|
||||||
critic2_optim=critic2.optim,
|
|
||||||
action_space=envs.get_action_space(),
|
|
||||||
observation_space=envs.get_observation_space(),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
|
class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]):
|
||||||
def _get_discrete_last_size_use_action_shape(self) -> bool:
|
def _get_policy_class(self) -> type[DiscreteSACPolicy]:
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_policy_class(self) -> type[TPolicy]:
|
|
||||||
return DiscreteSACPolicy
|
return DiscreteSACPolicy
|
||||||
|
|
||||||
|
|
||||||
class TD3AgentFactory(OffpolicyAgentFactory):
|
class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]):
|
||||||
def __init__(
|
def _get_policy_class(self) -> type[TD3Policy]:
|
||||||
self,
|
return TD3Policy
|
||||||
params: TD3Params,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic1_factory: CriticFactory,
|
|
||||||
critic2_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
):
|
|
||||||
super().__init__(sampling_config, optim_factory)
|
|
||||||
self.params = params
|
|
||||||
self.actor_factory = actor_factory
|
|
||||||
self.critic1_factory = critic1_factory
|
|
||||||
self.critic2_factory = critic2_factory
|
|
||||||
self.optim_factory = optim_factory
|
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
|
||||||
actor = self.actor_factory.create_module_opt(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
self.optim_factory,
|
|
||||||
self.params.actor_lr,
|
|
||||||
)
|
|
||||||
critic1 = self.critic1_factory.create_module_opt(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
True,
|
|
||||||
self.optim_factory,
|
|
||||||
self.params.critic1_lr,
|
|
||||||
)
|
|
||||||
critic2 = self.critic2_factory.create_module_opt(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
True,
|
|
||||||
self.optim_factory,
|
|
||||||
self.params.critic2_lr,
|
|
||||||
)
|
|
||||||
kwargs = self.params.create_kwargs(
|
|
||||||
ParamTransformerData(
|
|
||||||
envs=envs,
|
|
||||||
device=device,
|
|
||||||
optim_factory=self.optim_factory,
|
|
||||||
actor=actor,
|
|
||||||
critic1=critic1,
|
|
||||||
critic2=critic2,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return TD3Policy(
|
|
||||||
actor=actor.module,
|
|
||||||
actor_optim=actor.optim,
|
|
||||||
critic=critic1.module,
|
|
||||||
critic_optim=critic1.optim,
|
|
||||||
critic2=critic2.module,
|
|
||||||
critic2_optim=critic2.optim,
|
|
||||||
action_space=envs.get_action_space(),
|
|
||||||
observation_space=envs.get_observation_space(),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user