From 213e08a8464cb1059e3dfd6dcf80ae2b2dc61915 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 11 Oct 2023 15:29:47 +0200 Subject: [PATCH] Add method get_output_dim to BaseActor --- tianshou/utils/net/common.py | 4 ++++ tianshou/utils/net/continuous.py | 6 ++++++ tianshou/utils/net/discrete.py | 3 +++ 3 files changed, 13 insertions(+) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 7603454..5886912 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC): @abstractmethod def get_preprocess_net(self) -> nn.Module: pass + + @abstractmethod + def get_output_dim(self) -> int: + pass diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index c69c249..60c3569 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -63,6 +63,9 @@ class Actor(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess + def get_output_dim(self) -> int: + return self.output_dim + def forward( self, obs: np.ndarray | torch.Tensor, @@ -213,6 +216,9 @@ class ActorProb(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess + def get_output_dim(self): + return self.output_dim + def forward( self, obs: np.ndarray | torch.Tensor, diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 083cac5..ccdd69d 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -64,6 +64,9 @@ class Actor(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess + def get_output_dim(self) -> int: + return self.output_dim + def forward( self, obs: np.ndarray | torch.Tensor,