diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 7603454..5886912 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC): @abstractmethod def get_preprocess_net(self) -> nn.Module: pass + + @abstractmethod + def get_output_dim(self) -> int: + pass diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index c69c249..60c3569 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -63,6 +63,9 @@ class Actor(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess + def get_output_dim(self) -> int: + return self.output_dim + def forward( self, obs: np.ndarray | torch.Tensor, @@ -213,6 +216,9 @@ class ActorProb(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess + def get_output_dim(self): + return self.output_dim + def forward( self, obs: np.ndarray | torch.Tensor, diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 083cac5..ccdd69d 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -64,6 +64,9 @@ class Actor(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess + def get_output_dim(self) -> int: + return self.output_dim + def forward( self, obs: np.ndarray | torch.Tensor,