60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
from dataclasses import dataclass
|
|
|
|
import torch
|
|
|
|
from tianshou.highlevel.env import Environments
|
|
from tianshou.highlevel.module.actor import ActorFactory
|
|
from tianshou.highlevel.module.core import TDevice
|
|
from tianshou.highlevel.module.critic import CriticFactory
|
|
from tianshou.highlevel.optim import OptimizerFactory
|
|
from tianshou.utils.net.common import ActorCritic
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
|
|
@dataclass()
class ModuleOpt:
    """Container bundling a torch module with an optimizer.

    Instances are plain records: no validation is performed and no
    relationship between the two fields is enforced here.
    """

    # The network / module held by this record.
    module: torch.nn.Module
    # The optimizer stored next to the module (presumably the one that
    # updates ``module``'s parameters -- confirm against the creating code).
    optim: torch.optim.Optimizer
|
|
|
|
|
|
@dataclass()
class ActorCriticModuleOpt:
    """Container bundling an ``ActorCritic`` module with an optimizer.

    The ``actor`` and ``critic`` properties are read-only shortcuts to the
    attributes of the same name on ``actor_critic_module``.
    """

    # Combined module exposing ``.actor`` and ``.critic`` attributes.
    actor_critic_module: ActorCritic
    # The optimizer stored alongside the combined module.
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the ``actor`` attribute of the wrapped actor-critic module."""
        combined = self.actor_critic_module
        return combined.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the ``critic`` attribute of the wrapped actor-critic module."""
        combined = self.actor_critic_module
        return combined.critic
|
|
|
|
|
|
class ActorModuleOptFactory(ToStringMixin):
    """Factory that builds an actor module and an optimizer for it.

    The module comes from the given actor factory and the optimizer from the
    given optimizer factory; both are returned together as a ``ModuleOpt``.
    """

    def __init__(
        self,
        actor_factory: ActorFactory,
        optim_factory: OptimizerFactory,
    ):
        """Store the two factories used by :meth:`create_module_opt`.

        :param actor_factory: factory used to create the actor module
        :param optim_factory: factory used to create the optimizer
        """
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the actor module and its optimizer.

        :param envs: passed through to ``actor_factory.create_module``
        :param device: passed through to ``actor_factory.create_module``
        :param lr: passed to ``optim_factory.create_optimizer`` together with
            the created module (presumably the learning rate)
        :return: a ``ModuleOpt`` holding the created module and optimizer
        """
        actor_module = self.actor_factory.create_module(envs, device)
        optimizer = self.optim_factory.create_optimizer(actor_module, lr)
        return ModuleOpt(module=actor_module, optim=optimizer)
|
|
|
|
|
|
class CriticModuleOptFactory(ToStringMixin):
    """Factory that builds a critic module and an optimizer for it.

    The module comes from the given critic factory and the optimizer from the
    given optimizer factory; both are returned together as a ``ModuleOpt``.
    """

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        """Store the factories and the flag used by :meth:`create_module_opt`.

        :param critic_factory: factory used to create the critic module
        :param optim_factory: factory used to create the optimizer
        :param use_action: flag forwarded unchanged to
            ``critic_factory.create_module`` (presumably whether the critic
            takes the action as an input -- confirm against ``CriticFactory``)
        """
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic module and its optimizer.

        :param envs: passed through to ``critic_factory.create_module``
        :param device: passed through to ``critic_factory.create_module``
        :param lr: passed to ``optim_factory.create_optimizer`` together with
            the created module (presumably the learning rate)
        :return: a ``ModuleOpt`` holding the created module and optimizer
        """
        critic_module = self.critic_factory.create_module(envs, device, self.use_action)
        optimizer = self.optim_factory.create_optimizer(critic_module, lr)
        return ModuleOpt(module=critic_module, optim=optimizer)
|