| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from typing import Any, Dict, List, Type | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | import torch.nn.functional as F | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from torch import nn | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | from torch.distributions import kl_divergence | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.data import Batch, ReplayBuffer | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from tianshou.policy import A2CPolicy | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class NPGPolicy(A2CPolicy): | 
					
						
							|  |  |  |     """Implementation of Natural Policy Gradient.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     :param torch.nn.Module actor: the actor network following the rules in | 
					
						
							|  |  |  |         :class:`~tianshou.policy.BasePolicy`. (s -> logits) | 
					
						
							|  |  |  |     :param torch.nn.Module critic: the critic network. (s -> V(s)) | 
					
						
							|  |  |  |     :param torch.optim.Optimizer optim: the optimizer for actor and critic network. | 
					
						
							|  |  |  |     :param dist_fn: distribution class for computing the action. | 
					
						
							|  |  |  |     :type dist_fn: Type[torch.distributions.Distribution] | 
					
						
							|  |  |  |     :param bool advantage_normalization: whether to do per mini-batch advantage | 
					
						
							|  |  |  |         normalization. Default to True. | 
					
						
							|  |  |  |     :param int optim_critic_iters: Number of times to optimize critic network per | 
					
						
							|  |  |  |         update. Default to 5. | 
					
						
							|  |  |  |     :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation. | 
					
						
							|  |  |  |         Default to 0.95. | 
					
						
							|  |  |  |     :param bool reward_normalization: normalize estimated values to have std close to | 
					
						
							|  |  |  |         1. Default to False. | 
					
						
							|  |  |  |     :param int max_batchsize: the maximum size of the batch when computing GAE, | 
					
						
							|  |  |  |         depends on the size of available memory and the memory cost of the | 
					
						
							|  |  |  |         model; should be as large as possible within the memory constraint. | 
					
						
							|  |  |  |         Default to 256. | 
					
						
							|  |  |  |     :param bool action_scaling: whether to map actions from range [-1, 1] to range | 
					
						
							|  |  |  |         [action_spaces.low, action_spaces.high]. Default to True. | 
					
						
							|  |  |  |     :param str action_bound_method: method to bound action to range [-1, 1], can be | 
					
						
							|  |  |  |         either "clip" (for simply clipping the action), "tanh" (for applying tanh | 
					
						
							|  |  |  |         squashing) for now, or empty string for no bounding. Default to "clip". | 
					
						
							|  |  |  |     :param Optional[gym.Space] action_space: env's action space, mandatory if you want | 
					
						
							|  |  |  |         to use option "action_scaling" or "action_bound_method". Default to None. | 
					
						
							|  |  |  |     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in | 
					
						
							|  |  |  |         optimizer in each policy.update(). Default to None (no lr_scheduler). | 
					
						
							| 
									
										
										
										
											2021-04-27 21:22:39 +08:00
										 |  |  |     :param bool deterministic_eval: whether to use deterministic action instead of | 
					
						
							|  |  |  |         stochastic action sampled by the policy. Default to False. | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         actor: torch.nn.Module, | 
					
						
							|  |  |  |         critic: torch.nn.Module, | 
					
						
							|  |  |  |         optim: torch.optim.Optimizer, | 
					
						
							|  |  |  |         dist_fn: Type[torch.distributions.Distribution], | 
					
						
							|  |  |  |         advantage_normalization: bool = True, | 
					
						
							|  |  |  |         optim_critic_iters: int = 5, | 
					
						
							|  |  |  |         actor_step_size: float = 0.5, | 
					
						
							|  |  |  |         **kwargs: Any, | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         super().__init__(actor, critic, optim, dist_fn, **kwargs) | 
					
						
							|  |  |  |         del self._weight_vf, self._weight_ent, self._grad_norm | 
					
						
							|  |  |  |         self._norm_adv = advantage_normalization | 
					
						
							|  |  |  |         self._optim_critic_iters = optim_critic_iters | 
					
						
							|  |  |  |         self._step_size = actor_step_size | 
					
						
							|  |  |  |         # adjusts Hessian-vector product calculation for numerical stability | 
					
						
							|  |  |  |         self._damping = 0.1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def process_fn( | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |     ) -> Batch: | 
					
						
							| 
									
										
										
										
											2021-08-20 09:58:44 -04:00
										 |  |  |         batch = super().process_fn(batch, buffer, indices) | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |         old_log_prob = [] | 
					
						
							|  |  |  |         with torch.no_grad(): | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |             for minibatch in batch.split(self._batch, shuffle=False, merge_last=True): | 
					
						
							|  |  |  |                 old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act)) | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |         batch.logp_old = torch.cat(old_log_prob, dim=0) | 
					
						
							|  |  |  |         if self._norm_adv: | 
					
						
							|  |  |  |             batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() | 
					
						
							|  |  |  |         return batch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def learn(  # type: ignore | 
					
						
							|  |  |  |         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any | 
					
						
							|  |  |  |     ) -> Dict[str, List[float]]: | 
					
						
							|  |  |  |         actor_losses, vf_losses, kls = [], [], [] | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         for _ in range(repeat): | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |             for minibatch in batch.split(batch_size, merge_last=True): | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |                 # optimize actor | 
					
						
							|  |  |  |                 # direction: calculate villia gradient | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |                 dist = self(minibatch).dist | 
					
						
							|  |  |  |                 log_prob = dist.log_prob(minibatch.act) | 
					
						
							| 
									
										
										
										
											2021-04-21 16:31:20 +08:00
										 |  |  |                 log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |                 actor_loss = -(log_prob * minibatch.adv).mean() | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |                 flat_grads = self._get_flat_grad( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                     actor_loss, self.actor, retain_graph=True | 
					
						
							|  |  |  |                 ).detach() | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 # direction: calculate natural gradient | 
					
						
							|  |  |  |                 with torch.no_grad(): | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |                     old_dist = self(minibatch).dist | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 kl = kl_divergence(old_dist, dist).mean() | 
					
						
							|  |  |  |                 # calculate first order gradient of kl with respect to theta | 
					
						
							|  |  |  |                 flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) | 
					
						
							|  |  |  |                 search_direction = -self._conjugate_gradients( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                     flat_grads, flat_kl_grad, nsteps=10 | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 # step | 
					
						
							|  |  |  |                 with torch.no_grad(): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                     flat_params = torch.cat( | 
					
						
							|  |  |  |                         [param.data.view(-1) for param in self.actor.parameters()] | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |                     new_flat_params = flat_params + self._step_size * search_direction | 
					
						
							|  |  |  |                     self._set_from_flat_params(self.actor, new_flat_params) | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |                     new_dist = self(minibatch).dist | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |                     kl = kl_divergence(old_dist, new_dist).mean() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # optimize citirc | 
					
						
							|  |  |  |                 for _ in range(self._optim_critic_iters): | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |                     value = self.critic(minibatch.obs).flatten() | 
					
						
							|  |  |  |                     vf_loss = F.mse_loss(minibatch.returns, value) | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |                     self.optim.zero_grad() | 
					
						
							|  |  |  |                     vf_loss.backward() | 
					
						
							|  |  |  |                     self.optim.step() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 actor_losses.append(actor_loss.item()) | 
					
						
							|  |  |  |                 vf_losses.append(vf_loss.item()) | 
					
						
							|  |  |  |                 kls.append(kl.item()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return { | 
					
						
							|  |  |  |             "loss/actor": actor_losses, | 
					
						
							|  |  |  |             "loss/vf": vf_losses, | 
					
						
							|  |  |  |             "kl": kls, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: | 
					
						
							|  |  |  |         """Matrix vector product.""" | 
					
						
							|  |  |  |         # caculate second order gradient of kl with respect to theta | 
					
						
							|  |  |  |         kl_v = (flat_kl_grad * v).sum() | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, | 
					
						
							|  |  |  |                                                 retain_graph=True).detach() | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |         return flat_kl_grad_grad + v * self._damping | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _conjugate_gradients( | 
					
						
							|  |  |  |         self, | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |         minibatch: torch.Tensor, | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |         flat_kl_grad: torch.Tensor, | 
					
						
							|  |  |  |         nsteps: int = 10, | 
					
						
							|  |  |  |         residual_tol: float = 1e-10 | 
					
						
							|  |  |  |     ) -> torch.Tensor: | 
					
						
							| 
									
										
										
										
											2022-01-30 00:53:56 +08:00
										 |  |  |         x = torch.zeros_like(minibatch) | 
					
						
							|  |  |  |         r, p = minibatch.clone(), minibatch.clone() | 
					
						
							|  |  |  |         # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0. | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |         # Change if doing warm start. | 
					
						
							|  |  |  |         rdotr = r.dot(r) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         for _ in range(nsteps): | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |             z = self._MVP(p, flat_kl_grad) | 
					
						
							|  |  |  |             alpha = rdotr / p.dot(z) | 
					
						
							|  |  |  |             x += alpha * p | 
					
						
							|  |  |  |             r -= alpha * z | 
					
						
							|  |  |  |             new_rdotr = r.dot(r) | 
					
						
							|  |  |  |             if new_rdotr < residual_tol: | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |             p = r + new_rdotr / rdotr * p | 
					
						
							|  |  |  |             rdotr = new_rdotr | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_flat_grad( | 
					
						
							|  |  |  |         self, y: torch.Tensor, model: nn.Module, **kwargs: Any | 
					
						
							|  |  |  |     ) -> torch.Tensor: | 
					
						
							|  |  |  |         grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore | 
					
						
							|  |  |  |         return torch.cat([grad.reshape(-1) for grad in grads]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _set_from_flat_params( | 
					
						
							|  |  |  |         self, model: nn.Module, flat_params: torch.Tensor | 
					
						
							|  |  |  |     ) -> nn.Module: | 
					
						
							|  |  |  |         prev_ind = 0 | 
					
						
							|  |  |  |         for param in model.parameters(): | 
					
						
							|  |  |  |             flat_size = int(np.prod(list(param.size()))) | 
					
						
							|  |  |  |             param.data.copy_( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                 flat_params[prev_ind:prev_ind + flat_size].view(param.size()) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-04-21 09:52:15 +08:00
										 |  |  |             prev_ind += flat_size | 
					
						
							|  |  |  |         return model |