Add method get_output_dim to BaseActor
This commit is contained in:
parent
c7d0b6b4b2
commit
213e08a846
@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_preprocess_net(self) -> nn.Module:
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_output_dim(self) -> int:
|
||||||
|
pass
|
||||||
|
@ -63,6 +63,9 @@ class Actor(BaseActor):
|
|||||||
def get_preprocess_net(self) -> nn.Module:
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
return self.preprocess
|
return self.preprocess
|
||||||
|
|
||||||
|
def get_output_dim(self) -> int:
|
||||||
|
return self.output_dim
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
obs: np.ndarray | torch.Tensor,
|
obs: np.ndarray | torch.Tensor,
|
||||||
@ -213,6 +216,9 @@ class ActorProb(BaseActor):
|
|||||||
def get_preprocess_net(self) -> nn.Module:
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
return self.preprocess
|
return self.preprocess
|
||||||
|
|
||||||
|
def get_output_dim(self):
|
||||||
|
return self.output_dim
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
obs: np.ndarray | torch.Tensor,
|
obs: np.ndarray | torch.Tensor,
|
||||||
|
@ -64,6 +64,9 @@ class Actor(BaseActor):
|
|||||||
def get_preprocess_net(self) -> nn.Module:
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
return self.preprocess
|
return self.preprocess
|
||||||
|
|
||||||
|
def get_output_dim(self) -> int:
|
||||||
|
return self.output_dim
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
obs: np.ndarray | torch.Tensor,
|
obs: np.ndarray | torch.Tensor,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user