Add base class BaseActor with method get_preprocess_net for high-level API

This commit is contained in:
Dominik Jain 2023-09-28 20:08:55 +02:00
parent cd79cf8661
commit e0e7349b0a
3 changed files with 20 additions and 5 deletions

View File

@ -595,3 +595,9 @@ def get_dict_state_decorator(
return new_net_class
return decorator_fn, new_state_shape
class BaseActor(nn.Module, ABC):
@abstractmethod
def get_preprocess_net(self) -> nn.Module:
pass

View File

@ -6,13 +6,13 @@ import numpy as np
import torch
from torch import nn
from tianshou.utils.net.common import MLP
from tianshou.utils.net.common import MLP, BaseActor
SIGMA_MIN = -20
SIGMA_MAX = 2
class Actor(nn.Module):
class Actor(BaseActor):
"""Simple actor network.
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
@ -60,6 +60,9 @@ class Actor(nn.Module):
)
self.max_action = max_action
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def forward(
self,
obs: np.ndarray | torch.Tensor,
@ -147,7 +150,7 @@ class Critic(nn.Module):
return self.last(logits)
class ActorProb(nn.Module):
class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution).
:param preprocess_net: a self-defined preprocess_net which output a
@ -207,6 +210,9 @@ class ActorProb(nn.Module):
self.max_action = max_action
self._unbounded = unbounded
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def forward(
self,
obs: np.ndarray | torch.Tensor,

View File

@ -7,10 +7,10 @@ import torch.nn.functional as F
from torch import nn
from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP
from tianshou.utils.net.common import MLP, BaseActor
class Actor(nn.Module):
class Actor(BaseActor):
"""Simple actor network.
Will create an actor operated in discrete action space with structure of
@ -61,6 +61,9 @@ class Actor(nn.Module):
)
self.softmax_output = softmax_output
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def forward(
self,
obs: np.ndarray | torch.Tensor,