Dominik Jain 6bb3abb2f0 Support PG/Reinforce in high-level API
* Add example mujoco_reinforce_hl
* Extended functionality of ActorFactory to support creation of ModuleOpt
2023-10-18 20:44:17 +02:00

26 lines
505 B
Python

from dataclasses import dataclass
import torch
from tianshou.utils.net.common import ActorCritic
@dataclass
class ModuleOpt:
    """Container bundling a torch module together with an optimizer.

    The pairing lets callers pass a network and the optimizer that is
    presumably configured over its parameters as a single object
    (NOTE(review): the link between the two is not enforced here).
    """

    # The neural network component.
    module: torch.nn.Module
    # Optimizer associated with ``module``.
    optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
    """Container bundling an ``ActorCritic`` module with an optimizer.

    Exposes the wrapped module's ``actor`` and ``critic`` sub-modules as
    read-only properties for convenience. A single ``optim`` is stored;
    presumably it covers both actor and critic parameters — confirm
    against the code that builds it.
    """

    # Combined module; expected to provide ``actor`` and ``critic`` attributes.
    actor_critic_module: ActorCritic
    # Optimizer associated with ``actor_critic_module``.
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """The ``actor`` sub-module of the wrapped ``ActorCritic``."""
        combined = self.actor_critic_module
        return combined.actor

    @property
    def critic(self) -> torch.nn.Module:
        """The ``critic`` sub-module of the wrapped ``ActorCritic``."""
        combined = self.actor_critic_module
        return combined.critic