Add method get_output_dim to BaseActor
This commit is contained in:
parent
c7d0b6b4b2
commit
213e08a846
@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def get_preprocess_net(self) -> nn.Module:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_output_dim(self) -> int:
|
||||
pass
|
||||
|
@ -63,6 +63,9 @@ class Actor(BaseActor):
|
||||
def get_preprocess_net(self) -> nn.Module:
|
||||
return self.preprocess
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
return self.output_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
@ -213,6 +216,9 @@ class ActorProb(BaseActor):
|
||||
def get_preprocess_net(self) -> nn.Module:
|
||||
return self.preprocess
|
||||
|
||||
def get_output_dim(self):
|
||||
return self.output_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
|
@ -64,6 +64,9 @@ class Actor(BaseActor):
|
||||
def get_preprocess_net(self) -> nn.Module:
|
||||
return self.preprocess
|
||||
|
||||
def get_output_dim(self) -> int:
|
||||
return self.output_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
|
Loading…
x
Reference in New Issue
Block a user