2023-09-28 20:07:52 +02:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from tianshou.utils.net.common import ActorCritic
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModuleOpt:
    """Pair a torch module with the optimizer associated with it.

    Attributes:
        module: the wrapped :class:`torch.nn.Module`.
        optim: the optimizer paired with ``module`` (presumably built over
            ``module.parameters()`` -- this container does not check that).
    """

    # Declaration order matters: it fixes the positional order of __init__.
    module: torch.nn.Module
    optim: torch.optim.Optimizer
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ActorCriticModuleOpt:
    """Bundle an :class:`ActorCritic` instance with a single optimizer.

    The actor and critic sub-modules of the wrapped instance are exposed as
    read-only convenience properties.

    Attributes:
        actor_critic_module: the combined actor/critic module.
        optim: the optimizer paired with ``actor_critic_module`` (presumably
            constructed over its parameters -- not checked here).
    """

    # Declaration order matters: it fixes the positional order of __init__.
    actor_critic_module: ActorCritic
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the ``actor`` attribute of the wrapped ActorCritic."""
        wrapped = self.actor_critic_module
        return wrapped.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the ``critic`` attribute of the wrapped ActorCritic."""
        wrapped = self.actor_critic_module
        return wrapped.critic
|