| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  | import torch | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from torch import nn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Actor(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, layer_num, state_shape, action_shape, | 
					
						
							|  |  |  |                  max_action, device='cpu'): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.device = device | 
					
						
							|  |  |  |         self.model = [ | 
					
						
							|  |  |  |             nn.Linear(np.prod(state_shape), 128), | 
					
						
							|  |  |  |             nn.ReLU(inplace=True)] | 
					
						
							|  |  |  |         for i in range(layer_num): | 
					
						
							|  |  |  |             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] | 
					
						
							|  |  |  |         self.model += [nn.Linear(128, np.prod(action_shape))] | 
					
						
							|  |  |  |         self.model = nn.Sequential(*self.model) | 
					
						
							|  |  |  |         self._max = max_action | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, s, **kwargs): | 
					
						
							|  |  |  |         s = torch.tensor(s, device=self.device, dtype=torch.float) | 
					
						
							|  |  |  |         batch = s.shape[0] | 
					
						
							|  |  |  |         s = s.view(batch, -1) | 
					
						
							|  |  |  |         logits = self.model(s) | 
					
						
							|  |  |  |         logits = self._max * torch.tanh(logits) | 
					
						
							|  |  |  |         return logits, None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-21 17:04:42 +08:00
										 |  |  | class ActorProb(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, layer_num, state_shape, action_shape, | 
					
						
							|  |  |  |                  max_action, device='cpu'): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.device = device | 
					
						
							|  |  |  |         self.model = [ | 
					
						
							|  |  |  |             nn.Linear(np.prod(state_shape), 128), | 
					
						
							|  |  |  |             nn.ReLU(inplace=True)] | 
					
						
							|  |  |  |         for i in range(layer_num): | 
					
						
							|  |  |  |             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] | 
					
						
							|  |  |  |         self.model = nn.Sequential(*self.model) | 
					
						
							|  |  |  |         self.mu = nn.Linear(128, np.prod(action_shape)) | 
					
						
							|  |  |  |         self.sigma = nn.Linear(128, np.prod(action_shape)) | 
					
						
							|  |  |  |         self._max = max_action | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, s, **kwargs): | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |         if not isinstance(s, torch.Tensor): | 
					
						
							|  |  |  |             s = torch.tensor(s, device=self.device, dtype=torch.float) | 
					
						
							| 
									
										
										
										
											2020-03-21 17:04:42 +08:00
										 |  |  |         batch = s.shape[0] | 
					
						
							|  |  |  |         s = s.view(batch, -1) | 
					
						
							|  |  |  |         logits = self.model(s) | 
					
						
							|  |  |  |         mu = self._max * torch.tanh(self.mu(logits)) | 
					
						
							|  |  |  |         sigma = torch.exp(self.sigma(logits)) | 
					
						
							|  |  |  |         return (mu, sigma), None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  | class Critic(nn.Module): | 
					
						
							| 
									
										
										
										
											2020-03-21 17:04:42 +08:00
										 |  |  |     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.device = device | 
					
						
							|  |  |  |         self.model = [ | 
					
						
							|  |  |  |             nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128), | 
					
						
							|  |  |  |             nn.ReLU(inplace=True)] | 
					
						
							|  |  |  |         for i in range(layer_num): | 
					
						
							|  |  |  |             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] | 
					
						
							|  |  |  |         self.model += [nn.Linear(128, 1)] | 
					
						
							|  |  |  |         self.model = nn.Sequential(*self.model) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-21 17:04:42 +08:00
										 |  |  |     def forward(self, s, a=None): | 
					
						
							| 
									
										
										
										
											2020-03-25 14:08:28 +08:00
										 |  |  |         if not isinstance(s, torch.Tensor): | 
					
						
							|  |  |  |             s = torch.tensor(s, device=self.device, dtype=torch.float) | 
					
						
							|  |  |  |         if a is not None and not isinstance(a, torch.Tensor): | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |             a = torch.tensor(a, device=self.device, dtype=torch.float) | 
					
						
							|  |  |  |         batch = s.shape[0] | 
					
						
							|  |  |  |         s = s.view(batch, -1) | 
					
						
							| 
									
										
										
										
											2020-03-21 17:04:42 +08:00
										 |  |  |         if a is None: | 
					
						
							|  |  |  |             logits = self.model(s) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             a = a.view(batch, -1) | 
					
						
							|  |  |  |             logits = self.model(torch.cat([s, a], dim=1)) | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         return logits |