Add abstract base class BaseActor with an abstract method get_preprocess_net for the high-level API
This commit is contained in:
parent
cd79cf8661
commit
e0e7349b0a
@ -595,3 +595,9 @@ def get_dict_state_decorator(
|
|||||||
return new_net_class
|
return new_net_class
|
||||||
|
|
||||||
return decorator_fn, new_state_shape
|
return decorator_fn, new_state_shape
|
||||||
|
|
||||||
|
|
||||||
|
class BaseActor(nn.Module, ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
|
pass
|
||||||
|
@ -6,13 +6,13 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.utils.net.common import MLP
|
from tianshou.utils.net.common import MLP, BaseActor
|
||||||
|
|
||||||
SIGMA_MIN = -20
|
SIGMA_MIN = -20
|
||||||
SIGMA_MAX = 2
|
SIGMA_MAX = 2
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
class Actor(BaseActor):
|
||||||
"""Simple actor network.
|
"""Simple actor network.
|
||||||
|
|
||||||
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
|
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
|
||||||
@ -60,6 +60,9 @@ class Actor(nn.Module):
|
|||||||
)
|
)
|
||||||
self.max_action = max_action
|
self.max_action = max_action
|
||||||
|
|
||||||
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
|
return self.preprocess
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
obs: np.ndarray | torch.Tensor,
|
obs: np.ndarray | torch.Tensor,
|
||||||
@ -147,7 +150,7 @@ class Critic(nn.Module):
|
|||||||
return self.last(logits)
|
return self.last(logits)
|
||||||
|
|
||||||
|
|
||||||
class ActorProb(nn.Module):
|
class ActorProb(BaseActor):
|
||||||
"""Simple actor network (output with a Gauss distribution).
|
"""Simple actor network (output with a Gauss distribution).
|
||||||
|
|
||||||
:param preprocess_net: a self-defined preprocess_net which output a
|
:param preprocess_net: a self-defined preprocess_net which output a
|
||||||
@ -207,6 +210,9 @@ class ActorProb(nn.Module):
|
|||||||
self.max_action = max_action
|
self.max_action = max_action
|
||||||
self._unbounded = unbounded
|
self._unbounded = unbounded
|
||||||
|
|
||||||
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
|
return self.preprocess
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
obs: np.ndarray | torch.Tensor,
|
obs: np.ndarray | torch.Tensor,
|
||||||
|
@ -7,10 +7,10 @@ import torch.nn.functional as F
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.data import Batch, to_torch
|
from tianshou.data import Batch, to_torch
|
||||||
from tianshou.utils.net.common import MLP
|
from tianshou.utils.net.common import MLP, BaseActor
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
class Actor(BaseActor):
|
||||||
"""Simple actor network.
|
"""Simple actor network.
|
||||||
|
|
||||||
Will create an actor operated in discrete action space with structure of
|
Will create an actor operated in discrete action space with structure of
|
||||||
@ -61,6 +61,9 @@ class Actor(nn.Module):
|
|||||||
)
|
)
|
||||||
self.softmax_output = softmax_output
|
self.softmax_output = softmax_output
|
||||||
|
|
||||||
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
|
return self.preprocess
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
obs: np.ndarray | torch.Tensor,
|
obs: np.ndarray | torch.Tensor,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user