From e0e7349b0a0fd69a4d2351b7018728583f734604 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 28 Sep 2023 20:08:55 +0200 Subject: [PATCH] Add base class BaseActor with method get_preprocess_net for high-level API --- tianshou/utils/net/common.py | 6 ++++++ tianshou/utils/net/continuous.py | 12 +++++++++--- tianshou/utils/net/discrete.py | 7 +++++-- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 631ed7b..4c263ed 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -595,3 +595,9 @@ def get_dict_state_decorator( return new_net_class return decorator_fn, new_state_shape + + +class BaseActor(nn.Module, ABC): + @abstractmethod + def get_preprocess_net(self) -> nn.Module: + pass diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index d0eb31c..f6009a4 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -6,13 +6,13 @@ import numpy as np import torch from torch import nn -from tianshou.utils.net.common import MLP +from tianshou.utils.net.common import MLP, BaseActor SIGMA_MIN = -20 SIGMA_MAX = 2 -class Actor(nn.Module): +class Actor(BaseActor): """Simple actor network. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. @@ -60,6 +60,9 @@ class Actor(nn.Module): ) self.max_action = max_action + def get_preprocess_net(self) -> nn.Module: + return self.preprocess + def forward( self, obs: np.ndarray | torch.Tensor, @@ -147,7 +150,7 @@ class Critic(nn.Module): return self.last(logits) -class ActorProb(nn.Module): +class ActorProb(BaseActor): """Simple actor network (output with a Gauss distribution). :param preprocess_net: a self-defined preprocess_net which output a @@ -207,6 +210,9 @@ class ActorProb(nn.Module): self.max_action = max_action self._unbounded = unbounded + def get_preprocess_net(self) -> nn.Module: + return self.preprocess + def forward( self, obs: np.ndarray | torch.Tensor, diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 75a34a5..a5bc6af 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,10 +7,10 @@ import torch.nn.functional as F from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP +from tianshou.utils.net.common import MLP, BaseActor -class Actor(nn.Module): +class Actor(BaseActor): """Simple actor network. Will create an actor operated in discrete action space with structure of @@ -61,6 +61,9 @@ class Actor(nn.Module): ) self.softmax_output = softmax_output + def get_preprocess_net(self) -> nn.Module: + return self.preprocess + def forward( self, obs: np.ndarray | torch.Tensor,