Add method get_output_dim to BaseActor

This commit is contained in:
Dominik Jain 2023-10-11 15:29:47 +02:00
parent c7d0b6b4b2
commit 213e08a846
3 changed files with 13 additions and 0 deletions

View File

@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC):
@abstractmethod
def get_preprocess_net(self) -> nn.Module:
pass
@abstractmethod
def get_output_dim(self) -> int:
pass

View File

@ -63,6 +63,9 @@ class Actor(BaseActor):
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def get_output_dim(self) -> int:
return self.output_dim
def forward(
self,
obs: np.ndarray | torch.Tensor,
@ -213,6 +216,9 @@ class ActorProb(BaseActor):
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def get_output_dim(self):
return self.output_dim
def forward(
self,
obs: np.ndarray | torch.Tensor,

View File

@ -64,6 +64,9 @@ class Actor(BaseActor):
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def get_output_dim(self) -> int:
return self.output_dim
def forward(
self,
obs: np.ndarray | torch.Tensor,