| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | from torch import nn | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | import torch.nn.functional as F | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.data import Batch | 
					
						
							|  |  |  | from tianshou.policy import PGPolicy | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class PPOPolicy(PGPolicy): | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |     r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     :param torch.nn.Module actor: the actor network following the rules in | 
					
						
							|  |  |  |         :class:`~tianshou.policy.BasePolicy`. (s -> logits) | 
					
						
							|  |  |  |     :param torch.nn.Module critic: the critic network. (s -> V(s)) | 
					
						
							|  |  |  |     :param torch.optim.Optimizer optim: the optimizer for actor and critic | 
					
						
							|  |  |  |         network. | 
					
						
							|  |  |  |     :param torch.distributions.Distribution dist_fn: for computing the action. | 
					
						
							|  |  |  |     :param float discount_factor: in [0, 1], defaults to 0.99. | 
					
						
							|  |  |  |     :param float max_grad_norm: clipping gradients in back propagation, | 
					
						
							|  |  |  |         defaults to ``None``. | 
					
						
							|  |  |  |     :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original | 
					
						
							|  |  |  |         paper, defaults to 0.2. | 
					
						
							|  |  |  |     :param float vf_coef: weight for value loss, defaults to 0.5. | 
					
						
							|  |  |  |     :param float ent_coef: weight for entropy loss, defaults to 0.01. | 
					
						
							|  |  |  |     :param action_range: the action range (minimum, maximum). | 
					
						
							|  |  |  |     :type action_range: [float, float] | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |     def __init__(self, actor, critic, optim, dist_fn, | 
					
						
							|  |  |  |                  discount_factor=0.99, | 
					
						
							|  |  |  |                  max_grad_norm=.5, | 
					
						
							|  |  |  |                  eps_clip=.2, | 
					
						
							|  |  |  |                  vf_coef=.5, | 
					
						
							|  |  |  |                  ent_coef=.0, | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |                  action_range=None, | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         super().__init__(None, None, dist_fn, discount_factor) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         self._max_grad_norm = max_grad_norm | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         self._eps_clip = eps_clip | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         self._w_vf = vf_coef | 
					
						
							|  |  |  |         self._w_ent = ent_coef | 
					
						
							|  |  |  |         self._range = action_range | 
					
						
							|  |  |  |         self.actor, self.actor_old = actor, deepcopy(actor) | 
					
						
							|  |  |  |         self.actor_old.eval() | 
					
						
							|  |  |  |         self.critic, self.critic_old = critic, deepcopy(critic) | 
					
						
							|  |  |  |         self.critic_old.eval() | 
					
						
							|  |  |  |         self.optim = optim | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |     def train(self): | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |         """Set the module in training mode, except for the target network.""" | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         self.training = True | 
					
						
							|  |  |  |         self.actor.train() | 
					
						
							|  |  |  |         self.critic.train() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def eval(self): | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |         """Set the module in evaluation mode, except for the target network.""" | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         self.training = False | 
					
						
							|  |  |  |         self.actor.eval() | 
					
						
							|  |  |  |         self.critic.eval() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |     def __call__(self, batch, state=None, model='actor', **kwargs): | 
					
						
							|  |  |  |         """Compute action over the given batch data.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :return: A :class:`~tianshou.data.Batch` which has 4 keys: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             * ``act`` the action. | 
					
						
							|  |  |  |             * ``logits`` the network's raw output. | 
					
						
							|  |  |  |             * ``dist`` the action distribution. | 
					
						
							|  |  |  |             * ``state`` the hidden state. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         More information can be found at | 
					
						
							|  |  |  |         :meth:`~tianshou.policy.BasePolicy.__call__`. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         model = getattr(self, model) | 
					
						
							|  |  |  |         logits, h = model(batch.obs, state=state, info=batch.info) | 
					
						
							|  |  |  |         if isinstance(logits, tuple): | 
					
						
							|  |  |  |             dist = self.dist_fn(*logits) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             dist = self.dist_fn(logits) | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         act = dist.sample() | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         if self._range: | 
					
						
							|  |  |  |             act = act.clamp(self._range[0], self._range[1]) | 
					
						
							|  |  |  |         return Batch(logits=logits, act=act, state=h, dist=dist) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def sync_weight(self): | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |         """Synchronize the weight for the target network.""" | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         self.actor_old.load_state_dict(self.actor.state_dict()) | 
					
						
							|  |  |  |         self.critic_old.load_state_dict(self.critic.state_dict()) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |     def learn(self, batch, batch_size=None, repeat=1, **kwargs): | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         losses, clip_losses, vf_losses, ent_losses = [], [], [], [] | 
					
						
							| 
									
										
										
										
											2020-03-26 11:42:34 +08:00
										 |  |  |         r = batch.returns | 
					
						
							|  |  |  |         batch.returns = (r - r.mean()) / (r.std() + self._eps) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         batch.act = torch.tensor(batch.act) | 
					
						
							|  |  |  |         batch.returns = torch.tensor(batch.returns)[:, None] | 
					
						
							|  |  |  |         for _ in range(repeat): | 
					
						
							|  |  |  |             for b in batch.split(batch_size): | 
					
						
							|  |  |  |                 vs_old, vs__old = self.critic_old(np.concatenate([ | 
					
						
							|  |  |  |                     b.obs, b.obs_next])).split(b.obs.shape[0]) | 
					
						
							|  |  |  |                 dist = self(b).dist | 
					
						
							|  |  |  |                 dist_old = self(b, model='actor_old').dist | 
					
						
							|  |  |  |                 target_v = b.returns.to(vs__old.device) + self._gamma * vs__old | 
					
						
							|  |  |  |                 adv = (target_v - vs_old).detach() | 
					
						
							|  |  |  |                 a = b.act.to(adv.device) | 
					
						
							|  |  |  |                 ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a)) | 
					
						
							|  |  |  |                 surr1 = ratio * adv | 
					
						
							|  |  |  |                 surr2 = ratio.clamp( | 
					
						
							|  |  |  |                     1. - self._eps_clip, 1. + self._eps_clip) * adv | 
					
						
							|  |  |  |                 clip_loss = -torch.min(surr1, surr2).mean() | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |                 clip_losses.append(clip_loss.item()) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |                 vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v) | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |                 vf_losses.append(vf_loss.item()) | 
					
						
							| 
									
										
										
										
											2020-03-26 11:42:34 +08:00
										 |  |  |                 e_loss = dist.entropy().mean() | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |                 ent_losses.append(e_loss.item()) | 
					
						
							| 
									
										
										
										
											2020-03-26 11:42:34 +08:00
										 |  |  |                 loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |                 losses.append(loss.item()) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |                 self.optim.zero_grad() | 
					
						
							|  |  |  |                 loss.backward() | 
					
						
							|  |  |  |                 nn.utils.clip_grad_norm_(list( | 
					
						
							|  |  |  |                     self.actor.parameters()) + list(self.critic.parameters()), | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |                     self._max_grad_norm) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |                 self.optim.step() | 
					
						
							|  |  |  |         self.sync_weight() | 
					
						
							|  |  |  |         return { | 
					
						
							|  |  |  |             'loss': losses, | 
					
						
							|  |  |  |             'loss/clip': clip_losses, | 
					
						
							|  |  |  |             'loss/vf': vf_losses, | 
					
						
							|  |  |  |             'loss/ent': ent_losses, | 
					
						
							|  |  |  |         } |