import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, List, Type, Union, Optional

from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


class A2CPolicy(PGPolicy):
    """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for the actor and
        critic networks.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param float discount_factor: in [0, 1]. Default to 0.99.
    :param float vf_coef: weight for value loss. Default to 0.5.
    :param float ent_coef: weight for entropy loss. Default to 0.01.
    :param float max_grad_norm: the maximum norm for gradient clipping in
        back propagation. Default to None (no clipping).
    :param float gae_lambda: in [0, 1], param for Generalized Advantage
        Estimation. Default to 0.95.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param int max_batchsize: the maximum size of the batch when computing
        GAE; depends on the size of available memory and the memory cost of
        the model; should be as large as possible within the memory
        constraint. Default to 256.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more
        detailed explanation.
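
    A minimal construction sketch (the modules, shapes, and hyper-parameters
    below are illustrative assumptions; the exact signatures of ``Net``,
    ``Actor``, and ``Critic`` depend on the tianshou version)::

        import torch
        from tianshou.policy import A2CPolicy
        from tianshou.utils.net.common import Net
        from tianshou.utils.net.discrete import Actor, Critic

        # shared feature extractor for a hypothetical 4-dim obs, 2-action task
        net = Net(state_shape=4, hidden_sizes=[64, 64])
        actor = Actor(net, action_shape=2)
        critic = Critic(net)
        optim = torch.optim.Adam(
            list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
        policy = A2CPolicy(
            actor, critic, optim, dist_fn=torch.distributions.Categorical)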
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        discount_factor: float = 0.99,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: Optional[float] = None,
        gae_lambda: float = 0.95,
        reward_normalization: bool = False,
        max_batchsize: int = 256,
        **kwargs: Any
    ) -> None:
        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
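        # The actor/critic pair below takes the place of PGPolicy's single
        # model, hence None is passed as the model argument above.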
        self.actor = actor
        self.critic = critic
        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
        self._lambda = gae_lambda
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._grad_norm = max_grad_norm
        self._batch = max_batchsize
        self._rew_norm = reward_normalization

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
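        # Return targets follow Generalized Advantage Estimation
        # (arXiv:1506.02438):
        #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #     A_t = sum_l (gamma * lambda)^l * delta_{t+l}
        # The boundary values lambda = 0 and 1 take a shortcut that skips
        # the batched critic evaluation below.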
        if self._lambda in [0.0, 1.0]:
            return self.compute_episodic_return(
                batch, buffer, indice,
                None, gamma=self._gamma, gae_lambda=self._lambda)
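        # Evaluate V(s_{t+1}) without gradients, in chunks of at most
        # self._batch transitions, to bound peak memory usage.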
        v_ = []
        with torch.no_grad():
            for b in batch.split(self._batch, shuffle=False, merge_last=True):
                v_.append(to_numpy(self.critic(b.obs_next)))
        v_ = np.concatenate(v_, axis=0)
        return self.compute_episodic_return(
            batch, buffer, indice, v_,
            gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm)

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 4 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        logits, h = self.actor(batch.obs, state=state, info=batch.info)
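        # By tianshou convention, a continuous actor returns a (mu, sigma)
        # tuple that is unpacked into dist_fn, while a discrete actor returns
        # a single logits/probs tensor.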
        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
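        # Per-minibatch A2C objective:
        #     loss = actor_loss + vf_coef * vf_loss - ent_coef * entropy
        # where actor_loss = -E[log pi(a|s) * A] and the advantage
        # A = (returns - V(s)) is detached from the critic's graph.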
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                dist = self(b).dist
                v = self.critic(b.obs).flatten()
                a = to_torch_as(b.act, v)
                r = to_torch_as(b.returns, v)
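                # Reshape log_prob to (action_dims, batch_size) so that it
                # broadcasts against the 1-D advantage (r - v) below.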
                log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
                a_loss = -(log_prob * (r - v).detach()).mean()
                vf_loss = F.mse_loss(r, v)  # type: ignore
                ent_loss = dist.entropy().mean()
                loss = a_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss
                loss.backward()
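                # Optionally clip the global gradient norm of actor and critic
                # parameters before stepping the shared optimizer.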
                if self._grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        list(self.actor.parameters()) + list(self.critic.parameters()),
                        max_norm=self._grad_norm,
                    )
                self.optim.step()
                actor_losses.append(a_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        return {
            "loss": losses,
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }