Add base class BaseActor with method get_preprocess_net for high-level API
This commit is contained in:
		
							parent
							
								
									cd79cf8661
								
							
						
					
					
						commit
						e0e7349b0a
					
				@ -595,3 +595,9 @@ def get_dict_state_decorator(
 | 
				
			|||||||
        return new_net_class
 | 
					        return new_net_class
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return decorator_fn, new_state_shape
 | 
					    return decorator_fn, new_state_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BaseActor(nn.Module, ABC):
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
				
			|||||||
@ -6,13 +6,13 @@ import numpy as np
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torch import nn
 | 
					from torch import nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tianshou.utils.net.common import MLP
 | 
					from tianshou.utils.net.common import MLP, BaseActor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SIGMA_MIN = -20
 | 
					SIGMA_MIN = -20
 | 
				
			||||||
SIGMA_MAX = 2
 | 
					SIGMA_MAX = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Actor(nn.Module):
 | 
					class Actor(BaseActor):
 | 
				
			||||||
    """Simple actor network.
 | 
					    """Simple actor network.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
 | 
					    It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
 | 
				
			||||||
@ -60,6 +60,9 @@ class Actor(nn.Module):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        self.max_action = max_action
 | 
					        self.max_action = max_action
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
 | 
					        return self.preprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        obs: np.ndarray | torch.Tensor,
 | 
					        obs: np.ndarray | torch.Tensor,
 | 
				
			||||||
@ -147,7 +150,7 @@ class Critic(nn.Module):
 | 
				
			|||||||
        return self.last(logits)
 | 
					        return self.last(logits)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ActorProb(nn.Module):
 | 
					class ActorProb(BaseActor):
 | 
				
			||||||
    """Simple actor network (output with a Gauss distribution).
 | 
					    """Simple actor network (output with a Gauss distribution).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    :param preprocess_net: a self-defined preprocess_net which output a
 | 
					    :param preprocess_net: a self-defined preprocess_net which output a
 | 
				
			||||||
@ -207,6 +210,9 @@ class ActorProb(nn.Module):
 | 
				
			|||||||
        self.max_action = max_action
 | 
					        self.max_action = max_action
 | 
				
			||||||
        self._unbounded = unbounded
 | 
					        self._unbounded = unbounded
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
 | 
					        return self.preprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        obs: np.ndarray | torch.Tensor,
 | 
					        obs: np.ndarray | torch.Tensor,
 | 
				
			||||||
 | 
				
			|||||||
@ -7,10 +7,10 @@ import torch.nn.functional as F
 | 
				
			|||||||
from torch import nn
 | 
					from torch import nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tianshou.data import Batch, to_torch
 | 
					from tianshou.data import Batch, to_torch
 | 
				
			||||||
from tianshou.utils.net.common import MLP
 | 
					from tianshou.utils.net.common import MLP, BaseActor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Actor(nn.Module):
 | 
					class Actor(BaseActor):
 | 
				
			||||||
    """Simple actor network.
 | 
					    """Simple actor network.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Will create an actor operated in discrete action space with structure of
 | 
					    Will create an actor operated in discrete action space with structure of
 | 
				
			||||||
@ -61,6 +61,9 @@ class Actor(nn.Module):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        self.softmax_output = softmax_output
 | 
					        self.softmax_output = softmax_output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_preprocess_net(self) -> nn.Module:
 | 
				
			||||||
 | 
					        return self.preprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        obs: np.ndarray | torch.Tensor,
 | 
					        obs: np.ndarray | torch.Tensor,
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user