60 lines
1.8 KiB
Python

from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
@dataclass
class ModuleOpt:
    """Container bundling a torch module together with an optimizer."""

    # The network held by this container.
    module: torch.nn.Module
    # The optimizer stored alongside the module.
    optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
    """Container bundling an ``ActorCritic`` module together with an optimizer.

    The ``actor`` and ``critic`` attributes of the wrapped module are
    re-exposed as read-only properties for convenient access.
    """

    # The combined actor-critic module.
    actor_critic_module: ActorCritic
    # The optimizer stored alongside the combined module.
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the ``actor`` attribute of the wrapped actor-critic module."""
        wrapped = self.actor_critic_module
        return wrapped.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the ``critic`` attribute of the wrapped actor-critic module."""
        wrapped = self.actor_critic_module
        return wrapped.critic
class ActorModuleOptFactory(ToStringMixin):
    """Factory producing an actor module paired with an optimizer for it."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        """Store the factories used later by :meth:`create_module_opt`.

        :param actor_factory: factory used to create the actor module.
        :param optim_factory: factory used to create the optimizer for the actor.
        """
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the actor module and an optimizer for it.

        :param envs: the environments, forwarded to the actor factory.
        :param device: the device, forwarded to the actor factory.
        :param lr: value forwarded to the optimizer factory
            (presumably the learning rate -- confirm against ``OptimizerFactory``).
        :return: a :class:`ModuleOpt` holding the new actor and its optimizer.
        """
        actor_module = self.actor_factory.create_module(envs, device)
        # The optimizer is built from the freshly created module.
        actor_optim = self.optim_factory.create_optimizer(actor_module, lr)
        return ModuleOpt(module=actor_module, optim=actor_optim)
class CriticModuleOptFactory(ToStringMixin):
    """Factory producing a critic module paired with an optimizer for it."""

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        """Store the factories and the flag used later by :meth:`create_module_opt`.

        :param critic_factory: factory used to create the critic module.
        :param optim_factory: factory used to create the optimizer for the critic.
        :param use_action: flag passed through to ``critic_factory.create_module``
            (presumably whether the critic also receives the action as input --
            confirm against ``CriticFactory``).
        """
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic module and an optimizer for it.

        :param envs: the environments, forwarded to the critic factory.
        :param device: the device, forwarded to the critic factory.
        :param lr: value forwarded to the optimizer factory
            (presumably the learning rate -- confirm against ``OptimizerFactory``).
        :return: a :class:`ModuleOpt` holding the new critic and its optimizer.
        """
        critic_module = self.critic_factory.create_module(envs, device, self.use_action)
        # The optimizer is built from the freshly created module.
        critic_optim = self.optim_factory.create_optimizer(critic_module, lr)
        return ModuleOpt(module=critic_module, optim=critic_optim)