Add base class BaseActor with method get_preprocess_net for high-level API
This commit is contained in:
parent
cd79cf8661
commit
e0e7349b0a
@ -595,3 +595,9 @@ def get_dict_state_decorator(
|
||||
return new_net_class
|
||||
|
||||
return decorator_fn, new_state_shape
|
||||
|
||||
|
||||
# Abstract base class for actor networks, added for the high-level API.
# It is an nn.Module combined with ABC; concrete actors must implement
# get_preprocess_net.
class BaseActor(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
# Abstract accessor: implementations return an nn.Module (the actors changed
# in this commit return their `self.preprocess` attribute).
def get_preprocess_net(self) -> nn.Module:
|
||||
# No base behaviour; the body exists only to satisfy the syntax.
pass
|
||||
|
@ -6,13 +6,13 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tianshou.utils.net.common import MLP
|
||||
from tianshou.utils.net.common import MLP, BaseActor
|
||||
|
||||
# NOTE(review): presumably the lower/upper bounds applied to a sigma-related
# output of ActorProb; the code that reads these constants is outside this
# diff — confirm before relying on this description.
SIGMA_MIN = -20
|
||||
SIGMA_MAX = 2
|
||||
|
||||
|
||||
class Actor(nn.Module):
|
||||
class Actor(BaseActor):
|
||||
"""Simple actor network.
|
||||
|
||||
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
|
||||
@ -60,6 +60,9 @@ class Actor(nn.Module):
|
||||
)
|
||||
self.max_action = max_action
|
||||
|
||||
# Implements BaseActor.get_preprocess_net for the continuous-action Actor:
# returns the module stored in self.preprocess.
def get_preprocess_net(self) -> nn.Module:
|
||||
return self.preprocess
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
@ -147,7 +150,7 @@ class Critic(nn.Module):
|
||||
return self.last(logits)
|
||||
|
||||
|
||||
class ActorProb(nn.Module):
|
||||
class ActorProb(BaseActor):
|
||||
"""Simple actor network (output with a Gauss distribution).
|
||||
|
||||
:param preprocess_net: a self-defined preprocess_net which output a
|
||||
@ -207,6 +210,9 @@ class ActorProb(nn.Module):
|
||||
self.max_action = max_action
|
||||
self._unbounded = unbounded
|
||||
|
||||
# Implements BaseActor.get_preprocess_net for ActorProb:
# returns the module stored in self.preprocess.
def get_preprocess_net(self) -> nn.Module:
|
||||
return self.preprocess
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
|
@ -7,10 +7,10 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.utils.net.common import MLP
|
||||
from tianshou.utils.net.common import MLP, BaseActor
|
||||
|
||||
|
||||
class Actor(nn.Module):
|
||||
class Actor(BaseActor):
|
||||
"""Simple actor network.
|
||||
|
||||
Will create an actor operated in discrete action space with structure of
|
||||
@ -61,6 +61,9 @@ class Actor(nn.Module):
|
||||
)
|
||||
self.softmax_output = softmax_output
|
||||
|
||||
# Implements BaseActor.get_preprocess_net for the discrete-action Actor:
# returns the module stored in self.preprocess.
def get_preprocess_net(self) -> nn.Module:
|
||||
return self.preprocess
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
|
Loading…
x
Reference in New Issue
Block a user