from dataclasses import dataclass

import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin


@dataclass
class ModuleOpt:
    """A torch module paired with the optimizer created for it."""

    module: torch.nn.Module
    optim: torch.optim.Optimizer


@dataclass
class ActorCriticModuleOpt:
    """An :class:`ActorCritic` module paired with a single optimizer.

    The ``actor`` and ``critic`` properties are read-only shortcuts to the
    corresponding attributes of ``actor_critic_module``.
    """

    actor_critic_module: ActorCritic
    # NOTE(review): presumably optimizes the parameters of both actor and
    # critic (one optimizer for the combined module) -- confirm at call sites.
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the ``actor`` attribute of ``actor_critic_module``."""
        return self.actor_critic_module.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the ``critic`` attribute of ``actor_critic_module``."""
        return self.actor_critic_module.critic


class ActorModuleOptFactory(ToStringMixin):
    """Builds an actor module and its optimizer, returned as a :class:`ModuleOpt`."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory) -> None:
        """
        :param actor_factory: factory used to create the actor module.
        :param optim_factory: factory used to create the optimizer for the actor.
        """
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the actor for ``envs`` on ``device`` together with its optimizer.

        :param envs: the environments the actor is created for.
        :param device: the device passed on to the actor factory.
        :param lr: the learning rate passed on to the optimizer factory.
        :return: a :class:`ModuleOpt` holding the new actor and its optimizer.
        """
        actor = self.actor_factory.create_module(envs, device)
        opt = self.optim_factory.create_optimizer(actor, lr)
        return ModuleOpt(actor, opt)


class CriticModuleOptFactory(ToStringMixin):
    """Builds a critic module and its optimizer, returned as a :class:`ModuleOpt`."""

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ) -> None:
        """
        :param critic_factory: factory used to create the critic module.
        :param optim_factory: factory used to create the optimizer for the critic.
        :param use_action: forwarded unchanged to ``critic_factory.create_module``;
            presumably selects whether the critic also receives the action as
            input (i.e. a Q-function rather than a V-function) -- confirm
            against ``CriticFactory``.
        """
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic for ``envs`` on ``device`` together with its optimizer.

        :param envs: the environments the critic is created for.
        :param device: the device passed on to the critic factory.
        :param lr: the learning rate passed on to the optimizer factory.
        :return: a :class:`ModuleOpt` holding the new critic and its optimizer.
        """
        critic = self.critic_factory.create_module(envs, device, self.use_action)
        opt = self.optim_factory.create_optimizer(critic, lr)
        return ModuleOpt(critic, opt)