Simplify creation of actor/critic modules together with their optimizers
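After adding a function to create ModuleOpt instances directly from ActorFactory and CriticFactory (create_module_opt):
* several mixins for AgentFactories are no longer needed (deleted: _ActorMixin, _CriticMixin, _ActorAndCriticMixin, _ActorAndDualCriticsMixin)
* the additional ModuleOptFactory abstractions are no longer needed (deleted: ActorModuleOptFactory, CriticModuleOptFactory)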
This commit is contained in:
parent
6bb3abb2f0
commit
1bb52a6a5c
@ -12,13 +12,11 @@ from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.logger import Logger
|
||||
from tianshou.highlevel.module.actor import (
|
||||
ActorFactory,
|
||||
ActorModuleOptFactory,
|
||||
)
|
||||
from tianshou.highlevel.module.core import TDevice
|
||||
from tianshou.highlevel.module.critic import CriticFactory, CriticModuleOptFactory
|
||||
from tianshou.highlevel.module.critic import CriticFactory
|
||||
from tianshou.highlevel.module.module_opt import (
|
||||
ActorCriticModuleOpt,
|
||||
ModuleOpt,
|
||||
)
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.highlevel.params.policy_params import (
|
||||
@ -243,31 +241,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
)
|
||||
|
||||
|
||||
class _ActorMixin:
|
||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
|
||||
|
||||
def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||
return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
|
||||
|
||||
|
||||
class _CriticMixin:
|
||||
def __init__(
|
||||
self,
|
||||
critic_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
critic_use_action: bool,
|
||||
):
|
||||
self.critic_module_opt_factory = CriticModuleOptFactory(
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
critic_use_action,
|
||||
)
|
||||
|
||||
def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||
return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
|
||||
|
||||
|
||||
class _ActorCriticMixin:
|
||||
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
||||
|
||||
@ -319,44 +292,7 @@ class _ActorCriticMixin:
|
||||
return ActorCriticModuleOpt(actor_critic, optim)
|
||||
|
||||
|
||||
class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
|
||||
def __init__(
|
||||
self,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
critic_use_action: bool,
|
||||
):
|
||||
_ActorMixin.__init__(self, actor_factory, optim_factory)
|
||||
_CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)
|
||||
|
||||
|
||||
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
|
||||
def __init__(
|
||||
self,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
critic2_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
critic_use_action: bool,
|
||||
):
|
||||
super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
|
||||
self.critic2_module_opt_factory = CriticModuleOptFactory(
|
||||
critic2_factory,
|
||||
optim_factory,
|
||||
critic_use_action,
|
||||
)
|
||||
|
||||
def create_critic2_module_opt(
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
lr: float,
|
||||
) -> ModuleOpt:
|
||||
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
|
||||
|
||||
|
||||
class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
|
||||
class PGAgentFactory(OnpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: PGParams,
|
||||
@ -365,14 +301,16 @@ class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
_ActorMixin.__init__(self, actor_factory, optim_factory)
|
||||
self.params = params
|
||||
self.actor_factory = actor_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
envs, device, self.optim_factory, self.params.lr,
|
||||
envs,
|
||||
device,
|
||||
self.optim_factory,
|
||||
self.params.lr,
|
||||
)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
@ -526,7 +464,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
||||
)
|
||||
|
||||
|
||||
class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
||||
class DDPGAgentFactory(OffpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: DDPGParams,
|
||||
@ -536,19 +474,25 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
_ActorAndCriticMixin.__init__(
|
||||
self,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
critic_use_action=True,
|
||||
)
|
||||
self.critic_factory = critic_factory
|
||||
self.actor_factory = actor_factory
|
||||
self.params = params
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
||||
critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
self.optim_factory,
|
||||
self.params.actor_lr,
|
||||
)
|
||||
critic = self.critic_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic_lr,
|
||||
)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
@ -569,7 +513,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
||||
)
|
||||
|
||||
|
||||
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
class SACAgentFactory(OffpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: SACParams,
|
||||
@ -580,21 +524,33 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
_ActorAndDualCriticsMixin.__init__(
|
||||
self,
|
||||
actor_factory,
|
||||
critic1_factory,
|
||||
critic2_factory,
|
||||
optim_factory,
|
||||
critic_use_action=True,
|
||||
)
|
||||
self.params = params
|
||||
self.actor_factory = actor_factory
|
||||
self.critic1_factory = critic1_factory
|
||||
self.critic2_factory = critic2_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
self.optim_factory,
|
||||
self.params.actor_lr,
|
||||
)
|
||||
critic1 = self.critic1_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic1_lr,
|
||||
)
|
||||
critic2 = self.critic2_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic2_lr,
|
||||
)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
@ -618,7 +574,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
)
|
||||
|
||||
|
||||
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
class TD3AgentFactory(OffpolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: TD3Params,
|
||||
@ -629,21 +585,33 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
_ActorAndDualCriticsMixin.__init__(
|
||||
self,
|
||||
actor_factory,
|
||||
critic1_factory,
|
||||
critic2_factory,
|
||||
optim_factory,
|
||||
critic_use_action=True,
|
||||
)
|
||||
self.params = params
|
||||
self.actor_factory = actor_factory
|
||||
self.critic1_factory = critic1_factory
|
||||
self.critic2_factory = critic2_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||
actor = self.actor_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
self.optim_factory,
|
||||
self.params.actor_lr,
|
||||
)
|
||||
critic1 = self.critic1_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic1_lr,
|
||||
)
|
||||
critic2 = self.critic2_factory.create_module_opt(
|
||||
envs,
|
||||
device,
|
||||
True,
|
||||
self.optim_factory,
|
||||
self.params.critic2_lr,
|
||||
)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
|
@ -26,7 +26,11 @@ class ActorFactory(ToStringMixin, ABC):
|
||||
pass
|
||||
|
||||
def create_module_opt(
|
||||
self, envs: Environments, device: TDevice, optim_factory: OptimizerFactory, lr: float,
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
optim_factory: OptimizerFactory,
|
||||
lr: float,
|
||||
) -> ModuleOpt:
|
||||
"""Creates the actor module along with its optimizer for the given learning rate.
|
||||
|
||||
@ -171,14 +175,3 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
||||
hidden_sizes=(),
|
||||
device=device,
|
||||
).to(device)
|
||||
|
||||
|
||||
class ActorModuleOptFactory(ToStringMixin):
|
||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||
self.actor_factory = actor_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||
actor = self.actor_factory.create_module(envs, device)
|
||||
opt = self.optim_factory.create_optimizer(actor, lr)
|
||||
return ModuleOpt(actor, opt)
|
||||
|
@ -17,6 +17,18 @@ class CriticFactory(ToStringMixin, ABC):
|
||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||
pass
|
||||
|
||||
def create_module_opt(
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
use_action: bool,
|
||||
optim_factory: OptimizerFactory,
|
||||
lr: float,
|
||||
) -> ModuleOpt:
|
||||
module = self.create_module(envs, device, use_action)
|
||||
opt = optim_factory.create_optimizer(module, lr)
|
||||
return ModuleOpt(module, opt)
|
||||
|
||||
|
||||
class CriticFactoryDefault(CriticFactory):
|
||||
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
|
||||
@ -80,20 +92,3 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
||||
critic = discrete.Critic(net_c, device=device).to(device)
|
||||
init_linear_orthogonal(critic)
|
||||
return critic
|
||||
|
||||
|
||||
class CriticModuleOptFactory(ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
critic_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
use_action: bool,
|
||||
):
|
||||
self.critic_factory = critic_factory
|
||||
self.optim_factory = optim_factory
|
||||
self.use_action = use_action
|
||||
|
||||
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||
critic = self.critic_factory.create_module(envs, device, self.use_action)
|
||||
opt = self.optim_factory.create_optimizer(critic, lr)
|
||||
return ModuleOpt(critic, opt)
|
||||
|
Loading…
x
Reference in New Issue
Block a user