59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
|
from dataclasses import dataclass
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from tianshou.highlevel.env import Environments
|
||
|
from tianshou.highlevel.module.actor import ActorFactory
|
||
|
from tianshou.highlevel.module.core import TDevice
|
||
|
from tianshou.highlevel.module.critic import CriticFactory
|
||
|
from tianshou.highlevel.optim import OptimizerFactory
|
||
|
from tianshou.utils.net.common import ActorCritic
|
||
|
|
||
|
|
||
|
@dataclass
class ModuleOpt:
    """Pair a torch module with the optimizer that was created for it.

    Instances are returned by ``ActorModuleOptFactory.create_module_opt`` and
    ``CriticModuleOptFactory.create_module_opt``, so callers get the network and
    its optimizer back as a single value.
    """

    # the wrapped network
    module: torch.nn.Module
    # optimizer associated with ``module`` (the factories in this file build it
    # from ``module`` via ``OptimizerFactory.create_optimizer``)
    optim: torch.optim.Optimizer
|
||
|
|
||
|
|
||
|
@dataclass
class ActorCriticModuleOpt:
    """Pair an ``ActorCritic`` module with a single optimizer.

    Unlike ``ModuleOpt``, the wrapped module bundles both an actor and a critic.
    ``optim`` is one optimizer, presumably over both sub-networks' parameters;
    it is not constructed here, so confirm this at the creation site.
    """

    # combined module; the ``actor`` / ``critic`` properties delegate to it
    actor_critic_module: ActorCritic
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the actor sub-module of ``actor_critic_module``."""
        return self.actor_critic_module.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the critic sub-module of ``actor_critic_module``."""
        return self.actor_critic_module.critic
|
||
|
|
||
|
|
||
|
class ActorModuleOptFactory:
    """Build an actor network together with an optimizer for it."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        # creates the actor module
        self.actor_factory = actor_factory
        # creates an optimizer for a given module
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the actor and its optimizer.

        :param envs: environments forwarded to ``ActorFactory.create_module``
        :param device: device forwarded to ``ActorFactory.create_module``
        :param lr: learning rate forwarded to ``OptimizerFactory.create_optimizer``
        :return: a ``ModuleOpt`` holding the new actor and its optimizer
        """
        net = self.actor_factory.create_module(envs, device)
        return ModuleOpt(module=net, optim=self.optim_factory.create_optimizer(net, lr))
|
||
|
|
||
|
|
||
|
class CriticModuleOptFactory:
    """Build a critic network together with an optimizer for it."""

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        # creates the critic module
        self.critic_factory = critic_factory
        # creates an optimizer for a given module
        self.optim_factory = optim_factory
        # passed straight through to ``CriticFactory.create_module``; presumably
        # selects whether the critic also takes the action as input — confirm there
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic and its optimizer.

        :param envs: environments forwarded to ``CriticFactory.create_module``
        :param device: device forwarded to ``CriticFactory.create_module``
        :param lr: learning rate forwarded to ``OptimizerFactory.create_optimizer``
        :return: a ``ModuleOpt`` holding the new critic and its optimizer
        """
        net = self.critic_factory.create_module(envs, device, self.use_action)
        return ModuleOpt(module=net, optim=self.optim_factory.create_optimizer(net, lr))
|