Simplify critic/agent with optimizer generation
After adding a function to create ModuleOpt instances directly from AgentFactory and CriticFactory, * several mixins for AgentFactories are no longer needed (deleted) * additional abstractions for ModuleOptFactories are no longer needed (deleted)
This commit is contained in:
parent
6bb3abb2f0
commit
1bb52a6a5c
@ -12,13 +12,11 @@ from tianshou.highlevel.env import Environments
|
|||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import Logger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
ActorModuleOptFactory,
|
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.module.core import TDevice
|
from tianshou.highlevel.module.core import TDevice
|
||||||
from tianshou.highlevel.module.critic import CriticFactory, CriticModuleOptFactory
|
from tianshou.highlevel.module.critic import CriticFactory
|
||||||
from tianshou.highlevel.module.module_opt import (
|
from tianshou.highlevel.module.module_opt import (
|
||||||
ActorCriticModuleOpt,
|
ActorCriticModuleOpt,
|
||||||
ModuleOpt,
|
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
@ -243,31 +241,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _ActorMixin:
|
|
||||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
|
||||||
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
|
|
||||||
|
|
||||||
def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
|
||||||
return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
|
|
||||||
|
|
||||||
|
|
||||||
class _CriticMixin:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
critic_use_action: bool,
|
|
||||||
):
|
|
||||||
self.critic_module_opt_factory = CriticModuleOptFactory(
|
|
||||||
critic_factory,
|
|
||||||
optim_factory,
|
|
||||||
critic_use_action,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
|
||||||
return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
|
|
||||||
|
|
||||||
|
|
||||||
class _ActorCriticMixin:
|
class _ActorCriticMixin:
|
||||||
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
||||||
|
|
||||||
@ -319,44 +292,7 @@ class _ActorCriticMixin:
|
|||||||
return ActorCriticModuleOpt(actor_critic, optim)
|
return ActorCriticModuleOpt(actor_critic, optim)
|
||||||
|
|
||||||
|
|
||||||
class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
|
class PGAgentFactory(OnpolicyAgentFactory):
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
critic_use_action: bool,
|
|
||||||
):
|
|
||||||
_ActorMixin.__init__(self, actor_factory, optim_factory)
|
|
||||||
_CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)
|
|
||||||
|
|
||||||
|
|
||||||
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
critic2_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
critic_use_action: bool,
|
|
||||||
):
|
|
||||||
super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
|
|
||||||
self.critic2_module_opt_factory = CriticModuleOptFactory(
|
|
||||||
critic2_factory,
|
|
||||||
optim_factory,
|
|
||||||
critic_use_action,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_critic2_module_opt(
|
|
||||||
self,
|
|
||||||
envs: Environments,
|
|
||||||
device: TDevice,
|
|
||||||
lr: float,
|
|
||||||
) -> ModuleOpt:
|
|
||||||
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
|
|
||||||
|
|
||||||
|
|
||||||
class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: PGParams,
|
params: PGParams,
|
||||||
@ -365,14 +301,16 @@ class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
|
|||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory)
|
super().__init__(sampling_config, optim_factory)
|
||||||
_ActorMixin.__init__(self, actor_factory, optim_factory)
|
|
||||||
self.params = params
|
self.params = params
|
||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
|
||||||
actor = self.actor_factory.create_module_opt(
|
actor = self.actor_factory.create_module_opt(
|
||||||
envs, device, self.optim_factory, self.params.lr,
|
envs,
|
||||||
|
device,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.lr,
|
||||||
)
|
)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
@ -526,7 +464,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
class DDPGAgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: DDPGParams,
|
params: DDPGParams,
|
||||||
@ -536,19 +474,25 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
|||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory)
|
super().__init__(sampling_config, optim_factory)
|
||||||
_ActorAndCriticMixin.__init__(
|
self.critic_factory = critic_factory
|
||||||
self,
|
self.actor_factory = actor_factory
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optim_factory,
|
|
||||||
critic_use_action=True,
|
|
||||||
)
|
|
||||||
self.params = params
|
self.params = params
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||||
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
actor = self.actor_factory.create_module_opt(
|
||||||
critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
|
envs,
|
||||||
|
device,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.actor_lr,
|
||||||
|
)
|
||||||
|
critic = self.critic_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic_lr,
|
||||||
|
)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
@ -569,7 +513,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
class SACAgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: SACParams,
|
params: SACParams,
|
||||||
@ -580,21 +524,33 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory)
|
super().__init__(sampling_config, optim_factory)
|
||||||
_ActorAndDualCriticsMixin.__init__(
|
|
||||||
self,
|
|
||||||
actor_factory,
|
|
||||||
critic1_factory,
|
|
||||||
critic2_factory,
|
|
||||||
optim_factory,
|
|
||||||
critic_use_action=True,
|
|
||||||
)
|
|
||||||
self.params = params
|
self.params = params
|
||||||
|
self.actor_factory = actor_factory
|
||||||
|
self.critic1_factory = critic1_factory
|
||||||
|
self.critic2_factory = critic2_factory
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||||
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
actor = self.actor_factory.create_module_opt(
|
||||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
envs,
|
||||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
device,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.actor_lr,
|
||||||
|
)
|
||||||
|
critic1 = self.critic1_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic1_lr,
|
||||||
|
)
|
||||||
|
critic2 = self.critic2_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic2_lr,
|
||||||
|
)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
@ -618,7 +574,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
class TD3AgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: TD3Params,
|
params: TD3Params,
|
||||||
@ -629,21 +585,33 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory)
|
super().__init__(sampling_config, optim_factory)
|
||||||
_ActorAndDualCriticsMixin.__init__(
|
|
||||||
self,
|
|
||||||
actor_factory,
|
|
||||||
critic1_factory,
|
|
||||||
critic2_factory,
|
|
||||||
optim_factory,
|
|
||||||
critic_use_action=True,
|
|
||||||
)
|
|
||||||
self.params = params
|
self.params = params
|
||||||
|
self.actor_factory = actor_factory
|
||||||
|
self.critic1_factory = critic1_factory
|
||||||
|
self.critic2_factory = critic2_factory
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||||
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
actor = self.actor_factory.create_module_opt(
|
||||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
envs,
|
||||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
device,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.actor_lr,
|
||||||
|
)
|
||||||
|
critic1 = self.critic1_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic1_lr,
|
||||||
|
)
|
||||||
|
critic2 = self.critic2_factory.create_module_opt(
|
||||||
|
envs,
|
||||||
|
device,
|
||||||
|
True,
|
||||||
|
self.optim_factory,
|
||||||
|
self.params.critic2_lr,
|
||||||
|
)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
|
|||||||
@ -26,7 +26,11 @@ class ActorFactory(ToStringMixin, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def create_module_opt(
|
def create_module_opt(
|
||||||
self, envs: Environments, device: TDevice, optim_factory: OptimizerFactory, lr: float,
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
optim_factory: OptimizerFactory,
|
||||||
|
lr: float,
|
||||||
) -> ModuleOpt:
|
) -> ModuleOpt:
|
||||||
"""Creates the actor module along with its optimizer for the given learning rate.
|
"""Creates the actor module along with its optimizer for the given learning rate.
|
||||||
|
|
||||||
@ -171,14 +175,3 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
|||||||
hidden_sizes=(),
|
hidden_sizes=(),
|
||||||
device=device,
|
device=device,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
|
|
||||||
class ActorModuleOptFactory(ToStringMixin):
|
|
||||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
|
||||||
self.actor_factory = actor_factory
|
|
||||||
self.optim_factory = optim_factory
|
|
||||||
|
|
||||||
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
|
||||||
opt = self.optim_factory.create_optimizer(actor, lr)
|
|
||||||
return ModuleOpt(actor, opt)
|
|
||||||
|
|||||||
@ -17,6 +17,18 @@ class CriticFactory(ToStringMixin, ABC):
|
|||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def create_module_opt(
|
||||||
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
device: TDevice,
|
||||||
|
use_action: bool,
|
||||||
|
optim_factory: OptimizerFactory,
|
||||||
|
lr: float,
|
||||||
|
) -> ModuleOpt:
|
||||||
|
module = self.create_module(envs, device, use_action)
|
||||||
|
opt = optim_factory.create_optimizer(module, lr)
|
||||||
|
return ModuleOpt(module, opt)
|
||||||
|
|
||||||
|
|
||||||
class CriticFactoryDefault(CriticFactory):
|
class CriticFactoryDefault(CriticFactory):
|
||||||
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
|
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
|
||||||
@ -80,20 +92,3 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
critic = discrete.Critic(net_c, device=device).to(device)
|
critic = discrete.Critic(net_c, device=device).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
return critic
|
return critic
|
||||||
|
|
||||||
|
|
||||||
class CriticModuleOptFactory(ToStringMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
use_action: bool,
|
|
||||||
):
|
|
||||||
self.critic_factory = critic_factory
|
|
||||||
self.optim_factory = optim_factory
|
|
||||||
self.use_action = use_action
|
|
||||||
|
|
||||||
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
|
||||||
critic = self.critic_factory.create_module(envs, device, self.use_action)
|
|
||||||
opt = self.optim_factory.create_optimizer(critic, lr)
|
|
||||||
return ModuleOpt(critic, opt)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user