import torch
import numpy as np
from torch import nn

from tianshou.data import to_torch


class Actor(nn.Module):
    """For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(self, preprocess_net, action_shape,
                 max_action, device='cpu'):
        super().__init__()
        self.preprocess = preprocess_net
        self.last = nn.Linear(128, np.prod(action_shape))
        self._max = max_action

    def forward(self, s, state=None, info={}):
        logits, h = self.preprocess(s, state)
        logits = self._max * torch.tanh(self.last(logits))
        return logits, h


class Critic(nn.Module):
    """For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(self, preprocess_net, device='cpu'):
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.last = nn.Linear(128, 1)

    def forward(self, s, a=None, **kwargs):
        s = to_torch(s, device=self.device, dtype=torch.float32)
        s = s.flatten(1)
        if a is not None:
            a = to_torch(a, device=self.device, dtype=torch.float32)
            a = a.flatten(1)
            s = torch.cat([s, a], dim=1)
        logits, h = self.preprocess(s)
        logits = self.last(logits)
        return logits
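

# Illustrative usage sketch (an editorial addition, not part of the original
# module): ``Actor``/``Critic`` assume ``preprocess_net`` maps observations to
# 128-dimensional features and returns a ``(features, state)`` tuple.  The
# dummy MLP below is a hypothetical stand-in used only to document that
# contract; this helper is never called at import time.
def _demo_actor_critic():
    class _DummyPreprocess(nn.Module):
        """Minimal observation encoder: obs -> 128-dim features."""

        def __init__(self, input_dim):
            super().__init__()
            self.model = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())

        def forward(self, s, state=None):
            s = to_torch(s, device='cpu', dtype=torch.float32)
            return self.model(s.flatten(1)), state

    state_dim, action_dim = 8, 2
    actor = Actor(_DummyPreprocess(state_dim), (action_dim,), max_action=1.0)
    critic = Critic(_DummyPreprocess(state_dim + action_dim))
    obs = np.random.rand(4, state_dim)  # a batch of 4 observations
    act, _ = actor(obs)                 # -> shape (4, action_dim), in [-1, 1]
    q_value = critic(obs, act)          # -> shape (4, 1)
    return act, q_value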


class ActorProb(nn.Module):
    """For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(self, preprocess_net, action_shape,
                 max_action, device='cpu', unbounded=False):
        super().__init__()
        self.preprocess = preprocess_net
        self.device = device
        self.mu = nn.Linear(128, np.prod(action_shape))
        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
        self._max = max_action
        self._unbounded = unbounded

    def forward(self, s, state=None, **kwargs):
        logits, h = self.preprocess(s, state)
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self._max * torch.tanh(mu)
        shape = [1] * len(mu.shape)
        shape[1] = -1
        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
        return (mu, sigma), None
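

# Illustrative sketch (an editorial addition, not from the original file):
# the ``(mu, sigma)`` pair produced by ``ActorProb`` is typically consumed by
# a Gaussian policy, e.g. ``torch.distributions.Normal``.  ``preprocess_net``
# follows the same 128-dim feature contract as in the demo above; this helper
# is hypothetical and never called at import time.
def _demo_actor_prob(preprocess_net, action_shape, obs):
    from torch.distributions import Normal

    actor = ActorProb(preprocess_net, action_shape, max_action=1.0)
    (mu, sigma), _ = actor(obs)
    dist = Normal(mu, sigma)
    action = dist.sample()                    # stochastic action sample
    log_prob = dist.log_prob(action).sum(-1)  # per-sample log-likelihood
    return action, log_prob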


class RecurrentActorProb(nn.Module):
    """For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(self, layer_num, state_shape, action_shape,
                 max_action, device='cpu'):
        super().__init__()
        self.device = device
        self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
                          num_layers=layer_num, batch_first=True)
        self.mu = nn.Linear(128, np.prod(action_shape))
        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))

    def forward(self, s, **kwargs):
        s = to_torch(s, device=self.device, dtype=torch.float32)
        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor has one more dimension during training
        # than during evaluation.
        if len(s.shape) == 2:
            s = s.unsqueeze(-2)
        logits, _ = self.nn(s)
        logits = logits[:, -1]
        mu = self.mu(logits)
        shape = [1] * len(mu.shape)
        shape[1] = -1
        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
        return (mu, sigma), None


class RecurrentCritic(nn.Module):
    """For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
        super().__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = device
        self.nn = nn.LSTM(input_size=np.prod(state_shape), hidden_size=128,
                          num_layers=layer_num, batch_first=True)
        self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)

    def forward(self, s, a=None):
        s = to_torch(s, device=self.device, dtype=torch.float32)
        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor has one more dimension during training
        # than during evaluation.
        assert len(s.shape) == 3
        self.nn.flatten_parameters()
        s, (h, c) = self.nn(s)
        s = s[:, -1]
        if a is not None:
            if not isinstance(a, torch.Tensor):
                a = torch.tensor(a, device=self.device, dtype=torch.float32)
            s = torch.cat([s, a], dim=1)
        s = self.fc2(s)
        return s
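

# Minimal smoke test for the recurrent networks (an illustrative editorial
# addition, not part of the original module).  It follows the input
# conventions noted in the comments above: a 2-dim [bsz, dim] observation for
# the actor at evaluation time, and a 3-dim [bsz, len, dim] sequence for the
# critic.
if __name__ == "__main__":
    bsz, seq_len, state_dim, action_dim = 4, 3, 8, 2
    actor = RecurrentActorProb(layer_num=1, state_shape=(state_dim,),
                               action_shape=(action_dim,), max_action=1.0)
    (mu, sigma), _ = actor(np.random.rand(bsz, state_dim))
    print(mu.shape, sigma.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

    critic = RecurrentCritic(layer_num=1, state_shape=(state_dim,),
                             action_shape=(action_dim,))
    q_value = critic(np.random.rand(bsz, seq_len, state_dim),
                     np.random.rand(bsz, action_dim))
    print(q_value.shape)  # torch.Size([4, 1])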