from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch.distributions import Independent, Normal

from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy


class SACPolicy(DDPGPolicy):
    """Implementation of Soft Actor-Critic. arXiv:1812.05905.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param float tau: parameter for the soft update of the target network.
        Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient. Default to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param BaseNoise exploration_noise: add noise to actions for exploration.
        Default to None. This is useful when solving hard-exploration problems.
    :param bool deterministic_eval: whether to use deterministic action (mean
        of Gaussian policy) instead of stochastic action sampled by the policy.
        Default to True.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_space.low, action_space.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action) or empty string for no bounding.
        Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
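
    A minimal sketch of enabling automatic entropy tuning. Here ``actor``,
    ``critic1``, ``critic2``, their optimizers and ``env`` are placeholders assumed
    to be built elsewhere, and the learning rate is illustrative::

        import torch

        target_entropy = -float(env.action_space.shape[0])
        log_alpha = torch.zeros(1, requires_grad=True)
        alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
        policy = SACPolicy(
            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
            alpha=(target_entropy, log_alpha, alpha_optim),
            action_space=env.action_space,
        )

    With a tuple ``alpha``, :meth:`learn` additionally reports ``loss/alpha`` and
    the current ``alpha`` value in its result dict.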
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        reward_normalization: bool = False,
        estimation_step: int = 1,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            None, None, None, None, tau, gamma, exploration_noise,
            reward_normalization, estimation_step, **kwargs
        )
        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim

        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
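        # alpha may be a fixed entropy coefficient or a tuple
        # (target_entropy, log_alpha, alpha_optim) enabling automatic tuning;
        # log_alpha is optimized rather than alpha itself so that
        # alpha = exp(log_alpha) stays positive throughout training.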
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        self._deterministic_eval = deterministic_eval
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self, mode: bool = True) -> "SACPolicy":
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self

    def sync_weight(self) -> None:
        self.soft_update(self.critic1_old, self.critic1, self.tau)
        self.soft_update(self.critic2_old, self.critic2, self.tau)

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, hidden = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            act = logits[0]
        else:
            act = dist.rsample()
        log_prob = dist.log_prob(act).unsqueeze(-1)
        # apply correction for tanh squashing when computing the log-prob from the
        # Gaussian; see Eq. 21 in appendix C of the original SAC paper
        # (arXiv 1801.01290) for the derivation.
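        # With pre-squash sample u ~ Normal(mu, sigma) and action a = scale * tanh(u),
        # the change of variables gives
        #   log pi(a|s) = log N(u; mu, sigma) - sum_i log(scale * (1 - tanh(u_i)^2)),
        # which is what the subtraction below computes; self.__eps guards against
        # log(0) when tanh(u) saturates at +/-1.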
					
						
        if self.action_scaling and self.action_space is not None:
            action_scale = to_torch_as(
                (self.action_space.high - self.action_space.low) / 2.0, act
            )
        else:
            action_scale = 1.0  # type: ignore
        squashed_action = torch.tanh(act)
        log_prob = log_prob - torch.log(
            action_scale * (1 - squashed_action.pow(2)) + self.__eps
        ).sum(-1, keepdim=True)
        return Batch(
            logits=logits,
            act=squashed_action,
            state=hidden,
            dist=dist,
            log_prob=log_prob
        )

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs: s_{t+n}
        obs_next_result = self(batch, input="obs_next")
        act_ = obs_next_result.act
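        # soft value of s_{t+n}: the minimum over both target critics minus the
        # entropy term, V(s') = min(Q1'(s', a'), Q2'(s', a')) - alpha * log pi(a'|s'),
        # where a' is freshly sampled from the current policy rather than replayed;
        # the n-step return machinery inherited from DDPGPolicy turns this into the
        # critic regression target.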
        target_q = torch.min(
            self.critic1_old(batch.obs_next, act_),
            self.critic2_old(batch.obs_next, act_),
        ) - self._alpha * obs_next_result.log_prob
        return target_q

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(
            batch, self.critic1, self.critic1_optim
        )
        td2, critic2_loss = self._mse_optimizer(
            batch, self.critic2, self.critic2_optim
        )
        batch.weight = (td1 + td2) / 2.0  # prio-buffer
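        # td1/td2 hold per-sample TD errors; storing their mean in batch.weight lets
        # a prioritized replay buffer refresh its priorities after this update step.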

        # actor
        obs_result = self(batch)
        act = obs_result.act
        current_q1a = self.critic1(batch.obs, act).flatten()
        current_q2a = self.critic2(batch.obs, act).flatten()
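        # actor objective: minimize E[alpha * log pi(a|s) - min(Q1(s, a), Q2(s, a))],
        # with the action obtained via rsample() so gradients flow through the actor.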
        actor_loss = (
            self._alpha * obs_result.log_prob.flatten() -
            torch.min(current_q1a, current_q2a)
        ).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
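            # temperature objective: minimizing
            # -E[log_alpha * (log pi(a|s) + target_entropy)] raises alpha when the
            # policy entropy falls below the target and lowers it otherwise.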
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result