import torch
import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PGPolicy(BasePolicy):
    """Implementation of Vanilla Policy Gradient (REINFORCE).

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim optimizer for the model.
    :param torch.distributions.Distribution dist_fn: the distribution class
        used to build the action distribution from the model output.
    :param float discount_factor: the discount factor, in [0, 1].
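
    A minimal construction sketch (``Net`` is a stand-in for any user-defined
    module mapping observations to logits; it is not provided by this file)::

        net = Net(state_shape, action_shape)
        optim = torch.optim.Adam(net.parameters(), lr=1e-3)
        policy = PGPolicy(net, optim, torch.distributions.Categorical,
                          discount_factor=0.99)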
    """

    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                 discount_factor=0.99, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        self.dist_fn = dist_fn
        # machine epsilon, used to avoid division by zero when the returns
        # are normalized in learn()
        self._eps = np.finfo(np.float32).eps.item()
        assert 0 <= discount_factor <= 1, 'discount factor should be in [0, 1]'
        self._gamma = discount_factor

    def process_fn(self, batch, buffer, indice):
        r"""Compute the discounted return for each frame:

        .. math::
            G_t = \sum_{i=t}^T \gamma^{i-t} r_i

        where :math:`T` is the terminal time step and
        :math:`\gamma \in [0, 1]` is the discount factor.
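
        For example, with :math:`\gamma = 0.5` and a single episode with
        rewards ``[1, 1, 1]``, the returns are ``[1.75, 1.5, 1.0]``
        (computed backwards: :math:`1`, then :math:`1 + 0.5 \cdot 1 = 1.5`,
        then :math:`1 + 0.5 \cdot 1.5 = 1.75`).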
        """
        batch.returns = self._vanilla_returns(batch)
        # batch.returns = self._vectorized_returns(batch)
        return batch

    def __call__(self, batch, state=None, **kwargs):
        """Compute the action over the given batch of data.

        :return: A :class:`~tianshou.data.Batch` with 4 keys:

            * ``act`` the sampled action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        More information can be found at
        :meth:`~tianshou.policy.BasePolicy.__call__`.
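
        A usage sketch (``obs`` stands for a batched observation array; the
        names here are illustrative)::

            result = policy(Batch(obs=obs, info={}))
            action = result.act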
        """
        logits, h = self.model(batch.obs, state=state, info=batch.info)
        if isinstance(logits, tuple):
            # the model returned a tuple of distribution parameters,
            # e.g. (mu, sigma) for a Gaussian policy
            dist = self.dist_fn(*logits)
        else:
            # a single tensor, e.g. logits for a Categorical distribution
            dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch, batch_size=None, repeat=1, **kwargs):
        """Perform policy-gradient updates over the batch.

        :return: a dict with key ``loss``, a list with one loss value per
            gradient step.
        """
        losses = []
        # normalize the returns to zero mean and unit variance; self._eps
        # guards against division by zero
        r = batch.returns
        batch.returns = (r - r.mean()) / (r.std() + self._eps)
        for _ in range(repeat):
            for b in batch.split(batch_size):
                self.optim.zero_grad()
                dist = self(b).dist
                a = torch.tensor(b.act, device=dist.logits.device)
                r = torch.tensor(b.returns, device=dist.logits.device)
                # REINFORCE loss: negative log-probability of the taken
                # actions, weighted by the (normalized) returns
                loss = -(dist.log_prob(a) * r).sum()
                loss.backward()
                self.optim.step()
                losses.append(loss.item())
        return {'loss': losses}

    def _vanilla_returns(self, batch):
        # copy so that the computed returns do not overwrite batch.rew
        returns = batch.rew.copy()
        last = 0
        # iterate backwards through time; a done flag marks an episode
        # boundary and resets the running discounted sum
        for i in range(len(returns) - 1, -1, -1):
            if not batch.done[i]:
                returns[i] += self._gamma * last
            last = returns[i]
        return returns

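    # Worked example for either return helper (assuming gamma = 0.5):
    #     rew = [1, 1, 1, 1], done = [0, 1, 0, 0]
    # yields returns = [1.5, 1.0, 1.5, 1.0]; the done flag at index 1 keeps
    # the discounted sum from leaking across the episode boundary.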
    def _vectorized_returns(self, batch):
        # according to my tests, it is slower than _vanilla_returns
        # import scipy.signal
        convolve = np.convolve
        # convolve = scipy.signal.convolve
        # work on the time-reversed rewards so that discounted sums become
        # a convolution with the vector of gamma powers
        rew = batch.rew[::-1]
        batch_size = len(rew)
        gammas = self._gamma ** np.arange(batch_size)
        # discounted cumulative sums, ignoring episode boundaries for now
        c = convolve(rew, gammas)[:batch_size]
        # indices of terminal steps (in the reversed ordering)
        T = np.where(batch.done[::-1])[0]
        d = np.zeros_like(rew)
        # correction terms that cancel the rewards leaking across episode
        # boundaries
        d[T] += c[T] - rew[T]
        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
        return (c - convolve(d, gammas)[:batch_size])[::-1]
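

# A quick self-check sketch (not part of the original module): the two return
# helpers should agree on random episode data, e.g.
#
#     rew = np.random.rand(100).astype(np.float32)
#     done = np.random.rand(100) < 0.1
#     batch = Batch(rew=rew, done=done)
#     policy = PGPolicy(net, optim)  # net/optim as in the class docstring
#     assert np.allclose(policy._vanilla_returns(batch),
#                        policy._vectorized_returns(batch))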