Simplify critic/agent with optimizer generation

After adding a function to create ModuleOpt instances directly from
AgentFactory and CriticFactory,
  * several mixins for AgentFactories are no longer needed (deleted)
  * additional abstractions for ModuleOptFactories are no longer needed (deleted)
This commit is contained in:
Dominik Jain 2023-10-10 13:12:25 +02:00
parent 6bb3abb2f0
commit 1bb52a6a5c
3 changed files with 87 additions and 131 deletions

View File

@ -12,13 +12,11 @@ from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module.actor import (
ActorFactory,
ActorModuleOptFactory,
)
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory, CriticModuleOptFactory
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.module.module_opt import (
ActorCriticModuleOpt,
ModuleOpt,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_params import (
@ -243,31 +241,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
)
class _ActorMixin:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
class _CriticMixin:
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
self.critic_module_opt_factory = CriticModuleOptFactory(
critic_factory,
optim_factory,
critic_use_action,
)
def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
class _ActorCriticMixin:
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
@ -319,44 +292,7 @@ class _ActorCriticMixin:
return ActorCriticModuleOpt(actor_critic, optim)
class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
def __init__(
self,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
_ActorMixin.__init__(self, actor_factory, optim_factory)
_CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
def __init__(
self,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
self.critic2_module_opt_factory = CriticModuleOptFactory(
critic2_factory,
optim_factory,
critic_use_action,
)
def create_critic2_module_opt(
self,
envs: Environments,
device: TDevice,
lr: float,
) -> ModuleOpt:
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
class PGAgentFactory(OnpolicyAgentFactory):
def __init__(
self,
params: PGParams,
@ -365,14 +301,16 @@ class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
_ActorMixin.__init__(self, actor_factory, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
actor = self.actor_factory.create_module_opt(
envs, device, self.optim_factory, self.params.lr,
envs,
device,
self.optim_factory,
self.params.lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
@ -526,7 +464,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
)
class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
class DDPGAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
params: DDPGParams,
@ -536,19 +474,25 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
_ActorAndCriticMixin.__init__(
self,
actor_factory,
critic_factory,
optim_factory,
critic_use_action=True,
)
self.critic_factory = critic_factory
self.actor_factory = actor_factory
self.params = params
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic = self.critic_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
@ -569,7 +513,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
)
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
params: SACParams,
@ -580,21 +524,33 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
critic1_factory,
critic2_factory,
optim_factory,
critic_use_action=True,
)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
@ -618,7 +574,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
)
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
class TD3AgentFactory(OffpolicyAgentFactory):
def __init__(
self,
params: TD3Params,
@ -629,21 +585,33 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
critic1_factory,
critic2_factory,
optim_factory,
critic_use_action=True,
)
self.params = params
self.actor_factory = actor_factory
self.critic1_factory = critic1_factory
self.critic2_factory = critic2_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
actor = self.actor_factory.create_module_opt(
envs,
device,
self.optim_factory,
self.params.actor_lr,
)
critic1 = self.critic1_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic1_lr,
)
critic2 = self.critic2_factory.create_module_opt(
envs,
device,
True,
self.optim_factory,
self.params.critic2_lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,

View File

@ -26,7 +26,11 @@ class ActorFactory(ToStringMixin, ABC):
pass
def create_module_opt(
self, envs: Environments, device: TDevice, optim_factory: OptimizerFactory, lr: float,
self,
envs: Environments,
device: TDevice,
optim_factory: OptimizerFactory,
lr: float,
) -> ModuleOpt:
"""Creates the actor module along with its optimizer for the given learning rate.
@ -171,14 +175,3 @@ class ActorFactoryDiscreteNet(ActorFactory):
hidden_sizes=(),
device=device,
).to(device)
class ActorModuleOptFactory(ToStringMixin):
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)

View File

@ -17,6 +17,18 @@ class CriticFactory(ToStringMixin, ABC):
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
def create_module_opt(
self,
envs: Environments,
device: TDevice,
use_action: bool,
optim_factory: OptimizerFactory,
lr: float,
) -> ModuleOpt:
module = self.create_module(envs, device, use_action)
opt = optim_factory.create_optimizer(module, lr)
return ModuleOpt(module, opt)
class CriticFactoryDefault(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
@ -80,20 +92,3 @@ class CriticFactoryDiscreteNet(CriticFactory):
critic = discrete.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
class CriticModuleOptFactory(ToStringMixin):
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)