Add method get_output_dim to BaseActor
This commit is contained in:
		
							parent
							
								
									c7d0b6b4b2
								
							
						
					
					
						commit
						213e08a846
					
				@ -604,3 +604,7 @@ class BaseActor(nn.Module, ABC):
 | 
				
			|||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def get_preprocess_net(self) -> nn.Module:
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def get_output_dim(self) -> int:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
				
			|||||||
@ -63,6 +63,9 @@ class Actor(BaseActor):
 | 
				
			|||||||
    def get_preprocess_net(self) -> nn.Module:
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
        return self.preprocess
 | 
					        return self.preprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_output_dim(self) -> int:
 | 
				
			||||||
 | 
					        return self.output_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        obs: np.ndarray | torch.Tensor,
 | 
					        obs: np.ndarray | torch.Tensor,
 | 
				
			||||||
@ -213,6 +216,9 @@ class ActorProb(BaseActor):
 | 
				
			|||||||
    def get_preprocess_net(self) -> nn.Module:
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
        return self.preprocess
 | 
					        return self.preprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_output_dim(self):
 | 
				
			||||||
 | 
					        return self.output_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        obs: np.ndarray | torch.Tensor,
 | 
					        obs: np.ndarray | torch.Tensor,
 | 
				
			||||||
 | 
				
			|||||||
@ -64,6 +64,9 @@ class Actor(BaseActor):
 | 
				
			|||||||
    def get_preprocess_net(self) -> nn.Module:
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
        return self.preprocess
 | 
					        return self.preprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_output_dim(self) -> int:
 | 
				
			||||||
 | 
					        return self.output_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        obs: np.ndarray | torch.Tensor,
 | 
					        obs: np.ndarray | torch.Tensor,
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user