| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  | import torch.nn.functional as F | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from tianshou.data import Batch | 
					
						
							|  |  |  | from tianshou.policy import BasePolicy | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  | class DQNPolicy(BasePolicy): | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |     """docstring for DQNPolicy""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  |     def __init__(self, model, optim, | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |                  discount_factor=0.99, | 
					
						
							|  |  |  |                  estimation_step=1, | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |                  use_target_network=True): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.model = model | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         self.optim = optim | 
					
						
							|  |  |  |         self.eps = 0 | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  |         assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]' | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |         self._gamma = discount_factor | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         assert estimation_step > 0, 'estimation_step should greater than 0' | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |         self._n_step = estimation_step | 
					
						
							|  |  |  |         self._target = use_target_network | 
					
						
							|  |  |  |         if use_target_network: | 
					
						
							|  |  |  |             self.model_old = deepcopy(self.model) | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |             self.model_old.eval() | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |     def set_eps(self, eps): | 
					
						
							|  |  |  |         self.eps = eps | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def train(self): | 
					
						
							|  |  |  |         self.training = True | 
					
						
							|  |  |  |         self.model.train() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def eval(self): | 
					
						
							|  |  |  |         self.training = False | 
					
						
							|  |  |  |         self.model.eval() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def sync_weight(self): | 
					
						
							|  |  |  |         if self._target: | 
					
						
							|  |  |  |             self.model_old.load_state_dict(self.model.state_dict()) | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def process_fn(self, batch, buffer, indice): | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         returns = np.zeros_like(indice) | 
					
						
							|  |  |  |         gammas = np.zeros_like(indice) + self._n_step | 
					
						
							|  |  |  |         for n in range(self._n_step - 1, -1, -1): | 
					
						
							|  |  |  |             now = (indice + n) % len(buffer) | 
					
						
							|  |  |  |             gammas[buffer.done[now] > 0] = n | 
					
						
							|  |  |  |             returns[buffer.done[now] > 0] = 0 | 
					
						
							|  |  |  |             returns = buffer.rew[now] + self._gamma * returns | 
					
						
							|  |  |  |         terminal = (indice + self._n_step - 1) % len(buffer) | 
					
						
							|  |  |  |         if self._target: | 
					
						
							|  |  |  |             # target_Q = Q_old(s_, argmax(Q_new(s_, *))) | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |             a = self(buffer[terminal], input='obs_next', eps=0).act | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |             target_q = self( | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |                 buffer[terminal], model='model_old', input='obs_next').logits | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |             if isinstance(target_q, torch.Tensor): | 
					
						
							|  |  |  |                 target_q = target_q.detach().cpu().numpy() | 
					
						
							|  |  |  |             target_q = target_q[np.arange(len(a)), a] | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |             target_q = self(buffer[terminal], input='obs_next').logits | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |             if isinstance(target_q, torch.Tensor): | 
					
						
							|  |  |  |                 target_q = target_q.detach().cpu().numpy() | 
					
						
							|  |  |  |             target_q = target_q.max(axis=1) | 
					
						
							|  |  |  |         target_q[gammas != self._n_step] = 0 | 
					
						
							|  |  |  |         returns += (self._gamma ** gammas) * target_q | 
					
						
							|  |  |  |         batch.update(returns=returns) | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |         return batch | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |     def __call__(self, batch, state=None, | 
					
						
							|  |  |  |                  model='model', input='obs', eps=None): | 
					
						
							|  |  |  |         model = getattr(self, model) | 
					
						
							|  |  |  |         obs = getattr(batch, input) | 
					
						
							|  |  |  |         q, h = model(obs, state=state, info=batch.info) | 
					
						
							|  |  |  |         act = q.max(dim=1)[1].detach().cpu().numpy() | 
					
						
							|  |  |  |         # add eps to act | 
					
						
							|  |  |  |         if eps is None: | 
					
						
							|  |  |  |             eps = self.eps | 
					
						
							|  |  |  |         for i in range(len(q)): | 
					
						
							|  |  |  |             if np.random.rand() < eps: | 
					
						
							|  |  |  |                 act[i] = np.random.randint(q.shape[1]) | 
					
						
							|  |  |  |         return Batch(logits=q, act=act, state=h) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |     def learn(self, batch, batch_size=None, repeat=1): | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         self.optim.zero_grad() | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |         q = self(batch).logits | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         q = q[np.arange(len(q)), batch.act] | 
					
						
							|  |  |  |         r = batch.returns | 
					
						
							|  |  |  |         if isinstance(r, np.ndarray): | 
					
						
							|  |  |  |             r = torch.tensor(r, device=q.device, dtype=q.dtype) | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  |         loss = F.mse_loss(q, r) | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         loss.backward() | 
					
						
							|  |  |  |         self.optim.step() | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         return {'loss': loss.detach().cpu().numpy()} |