Add method get_output_dim to BaseActor
This commit is contained in:
		
							parent
							
								
									c7d0b6b4b2
								
							
						
					
					
						commit
						213e08a846
					
				@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_preprocess_net(self) -> nn.Module:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_output_dim(self) -> int:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
@ -63,6 +63,9 @@ class Actor(BaseActor):
 | 
			
		||||
    def get_preprocess_net(self) -> nn.Module:
 | 
			
		||||
        return self.preprocess
 | 
			
		||||
 | 
			
		||||
    def get_output_dim(self) -> int:
 | 
			
		||||
        return self.output_dim
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        obs: np.ndarray | torch.Tensor,
 | 
			
		||||
@ -213,6 +216,9 @@ class ActorProb(BaseActor):
 | 
			
		||||
    def get_preprocess_net(self) -> nn.Module:
 | 
			
		||||
        return self.preprocess
 | 
			
		||||
 | 
			
		||||
    def get_output_dim(self):
 | 
			
		||||
        return self.output_dim
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        obs: np.ndarray | torch.Tensor,
 | 
			
		||||
 | 
			
		||||
@ -64,6 +64,9 @@ class Actor(BaseActor):
 | 
			
		||||
    def get_preprocess_net(self) -> nn.Module:
 | 
			
		||||
        return self.preprocess
 | 
			
		||||
 | 
			
		||||
    def get_output_dim(self) -> int:
 | 
			
		||||
        return self.output_dim
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        obs: np.ndarray | torch.Tensor,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user