2023-09-28 20:07:52 +02:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from tianshou.utils.net.common import ActorCritic
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModuleOpt:
    """Container bundling a torch module together with an optimizer.

    Attributes:
        module: The wrapped ``torch.nn.Module``.
        optim: The ``torch.optim.Optimizer`` stored alongside ``module``
            (presumably the one updating its parameters; this class does
            not enforce that relationship).
    """

    module: torch.nn.Module
    optim: torch.optim.Optimizer
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ActorCriticModuleOpt:
    """Container bundling an ``ActorCritic`` module together with an optimizer.

    Attributes:
        actor_critic_module: The combined actor-critic module.
        optim: The ``torch.optim.Optimizer`` stored alongside the module
            (presumably the one updating its parameters; this class does
            not enforce that relationship).
    """

    actor_critic_module: ActorCritic
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the ``actor`` attribute of the wrapped actor-critic module."""
        return self.actor_critic_module.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the ``critic`` attribute of the wrapped actor-critic module."""
        return self.actor_critic_module.critic
|