import warnings
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy


class DDPGPolicy(BasePolicy):
    """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic_optim: the optimizer for critic network.
    :param float tau: param for soft update of the target network. Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param BaseNoise exploration_noise: the exploration noise added to the action.
        Default to ``GaussianNoise(sigma=0.1)``.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_space.low, action_space.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action) or empty string for no bounding.
        Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
        optimizer in each policy.update(). Default to None (no lr_scheduler).

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

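    # A minimal construction sketch, assuming the ``Net``/``Actor``/``Critic``
    # helpers from ``tianshou.utils.net.common`` and ``tianshou.utils.net.continuous``
    # and a continuous-action gym ``env``:
    #
    #     net_a = Net(state_shape, hidden_sizes=[128, 128])
    #     actor = Actor(net_a, action_shape, max_action=max_action)
    #     net_c = Net(state_shape, action_shape, hidden_sizes=[128, 128], concat=True)
    #     critic = Critic(net_c)
    #     policy = DDPGPolicy(
    #         actor, torch.optim.Adam(actor.parameters(), lr=1e-3),
    #         critic, torch.optim.Adam(critic.parameters(), lr=1e-3),
    #         tau=0.005, gamma=0.99,
    #         exploration_noise=GaussianNoise(sigma=0.1),
    #         estimation_step=1, action_space=env.action_space,
    #     )
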
    def __init__(
        self,
        actor: Optional[torch.nn.Module],
        actor_optim: Optional[torch.optim.Optimizer],
        critic: Optional[torch.nn.Module],
        critic_optim: Optional[torch.optim.Optimizer],
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        reward_normalization: bool = False,
        estimation_step: int = 1,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            **kwargs
        )
        assert action_bound_method != "tanh", "tanh mapping is not supported " \
            "in policies where action is used as input of critic, because " \
            "raw action in range (-inf, inf) will cause instability in training"
        if actor is not None and actor_optim is not None:
            self.actor: torch.nn.Module = actor
            self.actor_old = deepcopy(actor)
            self.actor_old.eval()
            self.actor_optim: torch.optim.Optimizer = actor_optim
        if critic is not None and critic_optim is not None:
            self.critic: torch.nn.Module = critic
            self.critic_old = deepcopy(critic)
            self.critic_old.eval()
            self.critic_optim: torch.optim.Optimizer = critic_optim
        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self.tau = tau
        assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
        self._gamma = gamma
        self._noise = exploration_noise
        # using OUNoise instead of GaussianNoise makes only a small difference
        # self.noise = OUNoise()
        self._rew_norm = reward_normalization
        self._n_step = estimation_step

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise."""
        self._noise = noise

    def train(self, mode: bool = True) -> "DDPGPolicy":
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.actor.train(mode)
        self.critic.train(mode)
        return self

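    # Soft (Polyak) update: w_target <- tau * w_online + (1 - tau) * w_target,
    # performed by ``BasePolicy.soft_update``.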
    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self.soft_update(self.actor_old, self.actor, self.tau)
        self.soft_update(self.critic_old, self.critic, self.tau)

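    # The n-step return bootstraps from the target networks:
    # Q'(s_{t+n}, mu'(s_{t+n})).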
    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        target_q = self.critic_old(
            batch.obs_next,
            self(batch, model='actor_old', input='obs_next').act
        )
        return target_q

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        batch = self.compute_nstep_return(
            batch, buffer, indices, self._target_q, self._gamma, self._n_step,
            self._rew_norm
        )
        return batch

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
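        # ``model`` selects ``self.actor`` or ``self.actor_old``; ``input`` selects
        # which observation key (e.g. "obs" or "obs_next") is fed to that network.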
        model = getattr(self, model)
        obs = batch[input]
        actions, hidden = model(obs, state=state, info=batch.info)
        return Batch(act=actions, state=hidden)

    @staticmethod
    def _mse_optimizer(
        batch: Batch, critic: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A simple wrapper for updating the critic network."""
        weight = getattr(batch, "weight", 1.0)
        current_q = critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        # critic_loss = F.mse_loss(current_q1, target_q)
        critic_loss = (td.pow(2) * weight).mean()
        optimizer.zero_grad()
        critic_loss.backward()
        optimizer.step()
        return td, critic_loss

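    # One update step: regress the critic onto the n-step TD target, follow the
    # deterministic policy gradient by maximizing Q(s, actor(s)) (hence the
    # negated mean), then soft-update the target networks. ``batch.weight`` is
    # set so a prioritized replay buffer can update its priorities from the TD
    # error.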
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic
        td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
        batch.weight = td  # prio-buffer
        # actor
        actor_loss = -self.critic(batch.obs, self(batch).act).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        self.sync_weight()
        return {
            "loss/actor": actor_loss.item(),
            "loss/critic": critic_loss.item(),
        }

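    # Called during data collection to add exploration noise to the
    # deterministic action produced by forward().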
    def exploration_noise(self, act: Union[np.ndarray, Batch],
                          batch: Batch) -> Union[np.ndarray, Batch]:
        if self._noise is None:
            return act
        if isinstance(act, np.ndarray):
            return act + self._noise(act.shape)
        warnings.warn("Cannot add exploration noise to non-numpy-array action.")
        return act