SAC implementation update (#212)
- replace DiagGuassian with Independent(Normal) (pytorch has already supported this) - detach alpha from autograd - add value/alpha to result (more informational) - revert #204 to fix #211 Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
		
							parent
							
								
									b86d78766b
								
							
						
					
					
						commit
						16d8e9b051
					
				| @ -1,7 +1,6 @@ | |||||||
| # Bipedal-Hardcore-SAC | # Bipedal-Hardcore-SAC | ||||||
| 
 | 
 | ||||||
| - Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below) | - Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below) | ||||||
| - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward) | - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward) | ||||||
| - Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed). |  | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | |||||||
| @ -24,6 +24,8 @@ def get_args(): | |||||||
|     parser.add_argument('--gamma', type=float, default=0.99) |     parser.add_argument('--gamma', type=float, default=0.99) | ||||||
|     parser.add_argument('--tau', type=float, default=0.005) |     parser.add_argument('--tau', type=float, default=0.005) | ||||||
|     parser.add_argument('--alpha', type=float, default=0.1) |     parser.add_argument('--alpha', type=float, default=0.1) | ||||||
|  |     parser.add_argument('--auto_alpha', type=int, default=1) | ||||||
|  |     parser.add_argument('--alpha_lr', type=float, default=3e-4) | ||||||
|     parser.add_argument('--epoch', type=int, default=100) |     parser.add_argument('--epoch', type=int, default=100) | ||||||
|     parser.add_argument('--step-per-epoch', type=int, default=10000) |     parser.add_argument('--step-per-epoch', type=int, default=10000) | ||||||
|     parser.add_argument('--collect-per-step', type=int, default=10) |     parser.add_argument('--collect-per-step', type=int, default=10) | ||||||
| @ -46,7 +48,7 @@ def get_args(): | |||||||
| class EnvWrapper(object): | class EnvWrapper(object): | ||||||
|     """Env wrapper for reward scale, action repeat and action noise""" |     """Env wrapper for reward scale, action repeat and action noise""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3): |     def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0): | ||||||
|         self._env = gym.make(task) |         self._env = gym.make(task) | ||||||
|         self.action_repeat = action_repeat |         self.action_repeat = action_repeat | ||||||
|         self.reward_scale = reward_scale |         self.reward_scale = reward_scale | ||||||
| @ -109,6 +111,12 @@ def test_sac_bipedal(args=get_args()): | |||||||
|     critic2 = Critic(net_c2, args.device).to(args.device) |     critic2 = Critic(net_c2, args.device).to(args.device) | ||||||
|     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) |     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) | ||||||
| 
 | 
 | ||||||
|  |     if args.auto_alpha: | ||||||
|  |         target_entropy = -np.prod(env.action_space.shape) | ||||||
|  |         log_alpha = torch.zeros(1, requires_grad=True, device=args.device) | ||||||
|  |         alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) | ||||||
|  |         args.alpha = (target_entropy, log_alpha, alpha_optim) | ||||||
|  | 
 | ||||||
|     policy = SACPolicy( |     policy = SACPolicy( | ||||||
|         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, |         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, | ||||||
|         args.tau, args.gamma, args.alpha, |         args.tau, args.gamma, args.alpha, | ||||||
|  | |||||||
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 40 KiB After Width: | Height: | Size: 46 KiB | 
| @ -5,11 +5,11 @@ import pprint | |||||||
| import argparse | import argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| from torch.utils.tensorboard import SummaryWriter | from torch.utils.tensorboard import SummaryWriter | ||||||
|  | from torch.distributions import Independent, Normal | ||||||
| 
 | 
 | ||||||
| from tianshou.policy import PPOPolicy | from tianshou.policy import PPOPolicy | ||||||
| from tianshou.env import DummyVectorEnv | from tianshou.env import DummyVectorEnv | ||||||
| from tianshou.utils.net.common import Net | from tianshou.utils.net.common import Net | ||||||
| from tianshou.policy.dist import DiagGaussian |  | ||||||
| from tianshou.trainer import onpolicy_trainer | from tianshou.trainer import onpolicy_trainer | ||||||
| from tianshou.data import Collector, ReplayBuffer | from tianshou.data import Collector, ReplayBuffer | ||||||
| from tianshou.utils.net.continuous import ActorProb, Critic | from tianshou.utils.net.continuous import ActorProb, Critic | ||||||
| @ -84,7 +84,11 @@ def test_ppo(args=get_args()): | |||||||
|             torch.nn.init.zeros_(m.bias) |             torch.nn.init.zeros_(m.bias) | ||||||
|     optim = torch.optim.Adam(list( |     optim = torch.optim.Adam(list( | ||||||
|         actor.parameters()) + list(critic.parameters()), lr=args.lr) |         actor.parameters()) + list(critic.parameters()), lr=args.lr) | ||||||
|     dist = DiagGaussian | 
 | ||||||
|  |     # replace DiagGuassian with Independent(Normal) which is equivalent | ||||||
|  |     # pass *logits to be consistent with policy.forward | ||||||
|  |     def dist(*logits): | ||||||
|  |         return Independent(Normal(*logits), 1) | ||||||
|     policy = PPOPolicy( |     policy = PPOPolicy( | ||||||
|         actor, critic, optim, dist, args.gamma, |         actor, critic, optim, dist, args.gamma, | ||||||
|         max_grad_norm=args.max_grad_norm, |         max_grad_norm=args.max_grad_norm, | ||||||
|  | |||||||
| @ -1,11 +0,0 @@ | |||||||
| import torch |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class DiagGaussian(torch.distributions.Normal): |  | ||||||
|     """Diagonal Gaussian distribution.""" |  | ||||||
| 
 |  | ||||||
|     def log_prob(self, actions): |  | ||||||
|         return super().log_prob(actions).sum(-1, keepdim=True) |  | ||||||
| 
 |  | ||||||
|     def entropy(self): |  | ||||||
|         return super().entropy().sum(-1) |  | ||||||
| @ -2,9 +2,9 @@ import torch | |||||||
| import numpy as np | import numpy as np | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from typing import Dict, Tuple, Union, Optional | from typing import Dict, Tuple, Union, Optional | ||||||
|  | from torch.distributions import Normal, Independent | ||||||
| 
 | 
 | ||||||
| from tianshou.policy import DDPGPolicy | from tianshou.policy import DDPGPolicy | ||||||
| from tianshou.policy.dist import DiagGaussian |  | ||||||
| from tianshou.data import Batch, to_torch_as, ReplayBuffer | from tianshou.data import Batch, to_torch_as, ReplayBuffer | ||||||
| from tianshou.exploration import BaseNoise | from tianshou.exploration import BaseNoise | ||||||
| 
 | 
 | ||||||
| @ -47,23 +47,26 @@ class SACPolicy(DDPGPolicy): | |||||||
|         explanation. |         explanation. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, |     def __init__( | ||||||
|                  actor: torch.nn.Module, |         self, | ||||||
|                  actor_optim: torch.optim.Optimizer, |         actor: torch.nn.Module, | ||||||
|                  critic1: torch.nn.Module, |         actor_optim: torch.optim.Optimizer, | ||||||
|                  critic1_optim: torch.optim.Optimizer, |         critic1: torch.nn.Module, | ||||||
|                  critic2: torch.nn.Module, |         critic1_optim: torch.optim.Optimizer, | ||||||
|                  critic2_optim: torch.optim.Optimizer, |         critic2: torch.nn.Module, | ||||||
|                  tau: float = 0.005, |         critic2_optim: torch.optim.Optimizer, | ||||||
|                  gamma: float = 0.99, |         tau: float = 0.005, | ||||||
|                  alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer] |         gamma: float = 0.99, | ||||||
|                  or float = 0.2, |         alpha: Union[ | ||||||
|                  action_range: Optional[Tuple[float, float]] = None, |             float, Tuple[float, torch.Tensor, torch.optim.Optimizer] | ||||||
|                  reward_normalization: bool = False, |         ] = 0.2, | ||||||
|                  ignore_done: bool = False, |         action_range: Optional[Tuple[float, float]] = None, | ||||||
|                  estimation_step: int = 1, |         reward_normalization: bool = False, | ||||||
|                  exploration_noise: Optional[BaseNoise] = None, |         ignore_done: bool = False, | ||||||
|                  **kwargs) -> None: |         estimation_step: int = 1, | ||||||
|  |         exploration_noise: Optional[BaseNoise] = None, | ||||||
|  |         **kwargs | ||||||
|  |     ) -> None: | ||||||
|         super().__init__(None, None, None, None, tau, gamma, exploration_noise, |         super().__init__(None, None, None, None, tau, gamma, exploration_noise, | ||||||
|                          action_range, reward_normalization, ignore_done, |                          action_range, reward_normalization, ignore_done, | ||||||
|                          estimation_step, **kwargs) |                          estimation_step, **kwargs) | ||||||
| @ -75,14 +78,12 @@ class SACPolicy(DDPGPolicy): | |||||||
|         self.critic2_old.eval() |         self.critic2_old.eval() | ||||||
|         self.critic2_optim = critic2_optim |         self.critic2_optim = critic2_optim | ||||||
| 
 | 
 | ||||||
|         self._automatic_alpha_tuning = not isinstance(alpha, float) |         self._is_auto_alpha = False | ||||||
|         if self._automatic_alpha_tuning: |         if isinstance(alpha, tuple): | ||||||
|             self._target_entropy = alpha[0] |             self._is_auto_alpha = True | ||||||
|             assert(alpha[1].shape == torch.Size([1]) |             self._target_entropy, self._log_alpha, self._alpha_optim = alpha | ||||||
|                    and alpha[1].requires_grad) |             assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad | ||||||
|             self._log_alpha = alpha[1] |             self._alpha = self._log_alpha.detach().exp() | ||||||
|             self._alpha_optim = alpha[2] |  | ||||||
|             self._alpha = self._log_alpha.exp() |  | ||||||
|         else: |         else: | ||||||
|             self._alpha = alpha |             self._alpha = alpha | ||||||
| 
 | 
 | ||||||
| @ -111,12 +112,13 @@ class SACPolicy(DDPGPolicy): | |||||||
|         obs = getattr(batch, input) |         obs = getattr(batch, input) | ||||||
|         logits, h = self.actor(obs, state=state, info=batch.info) |         logits, h = self.actor(obs, state=state, info=batch.info) | ||||||
|         assert isinstance(logits, tuple) |         assert isinstance(logits, tuple) | ||||||
|         dist = DiagGaussian(*logits) |         dist = Independent(Normal(*logits), 1) | ||||||
|         x = dist.rsample() |         x = dist.rsample() | ||||||
|         y = torch.tanh(x) |         y = torch.tanh(x) | ||||||
|         act = y * self._action_scale + self._action_bias |         act = y * self._action_scale + self._action_bias | ||||||
|         y = self._action_scale * (1 - y.pow(2)) + self.__eps |         y = self._action_scale * (1 - y.pow(2)) + self.__eps | ||||||
|         log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True) |         log_prob = dist.log_prob(x).unsqueeze(-1) | ||||||
|  |         log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) | ||||||
|         if self._noise is not None and self.training and explorating: |         if self._noise is not None and self.training and explorating: | ||||||
|             act += to_torch_as(self._noise(act.shape), act) |             act += to_torch_as(self._noise(act.shape), act) | ||||||
|         act = act.clamp(self._range[0], self._range[1]) |         act = act.clamp(self._range[0], self._range[1]) | ||||||
| @ -167,13 +169,13 @@ class SACPolicy(DDPGPolicy): | |||||||
|         actor_loss.backward() |         actor_loss.backward() | ||||||
|         self.actor_optim.step() |         self.actor_optim.step() | ||||||
| 
 | 
 | ||||||
|         if self._automatic_alpha_tuning: |         if self._is_auto_alpha: | ||||||
|             log_prob = (obs_result.log_prob + self._target_entropy).detach() |             log_prob = obs_result.log_prob.detach() + self._target_entropy | ||||||
|             alpha_loss = -(self._log_alpha * log_prob).mean() |             alpha_loss = -(self._log_alpha * log_prob).mean() | ||||||
|             self._alpha_optim.zero_grad() |             self._alpha_optim.zero_grad() | ||||||
|             alpha_loss.backward() |             alpha_loss.backward() | ||||||
|             self._alpha_optim.step() |             self._alpha_optim.step() | ||||||
|             self._alpha = self._log_alpha.exp() |             self._alpha = self._log_alpha.detach().exp() | ||||||
| 
 | 
 | ||||||
|         self.sync_weight() |         self.sync_weight() | ||||||
| 
 | 
 | ||||||
| @ -182,6 +184,7 @@ class SACPolicy(DDPGPolicy): | |||||||
|             'loss/critic1': critic1_loss.item(), |             'loss/critic1': critic1_loss.item(), | ||||||
|             'loss/critic2': critic2_loss.item(), |             'loss/critic2': critic2_loss.item(), | ||||||
|         } |         } | ||||||
|         if self._automatic_alpha_tuning: |         if self._is_auto_alpha: | ||||||
|             result['loss/alpha'] = alpha_loss.item() |             result['loss/alpha'] = alpha_loss.item() | ||||||
|  |             result['v/alpha'] = self._alpha.item() | ||||||
|         return result |         return result | ||||||
|  | |||||||
| @ -77,13 +77,13 @@ def offpolicy_trainer( | |||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|     test_in_train = test_in_train and train_collector.policy == policy |     test_in_train = test_in_train and train_collector.policy == policy | ||||||
|     for epoch in range(1, 1 + max_epoch): |     for epoch in range(1, 1 + max_epoch): | ||||||
|  |         # train | ||||||
|  |         policy.train() | ||||||
|  |         if train_fn: | ||||||
|  |             train_fn(epoch) | ||||||
|         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', |         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', | ||||||
|                        **tqdm_config) as t: |                        **tqdm_config) as t: | ||||||
|             while t.n < t.total: |             while t.n < t.total: | ||||||
|                 # collect |  | ||||||
|                 if train_fn: |  | ||||||
|                     train_fn(epoch) |  | ||||||
|                 policy.eval() |  | ||||||
|                 result = train_collector.collect(n_step=collect_per_step) |                 result = train_collector.collect(n_step=collect_per_step) | ||||||
|                 data = {} |                 data = {} | ||||||
|                 if test_in_train and stop_fn and stop_fn(result['rew']): |                 if test_in_train and stop_fn and stop_fn(result['rew']): | ||||||
| @ -100,10 +100,9 @@ def offpolicy_trainer( | |||||||
|                             start_time, train_collector, test_collector, |                             start_time, train_collector, test_collector, | ||||||
|                             test_result['rew']) |                             test_result['rew']) | ||||||
|                     else: |                     else: | ||||||
|  |                         policy.train() | ||||||
|                         if train_fn: |                         if train_fn: | ||||||
|                             train_fn(epoch) |                             train_fn(epoch) | ||||||
|                 # train |  | ||||||
|                 policy.train() |  | ||||||
|                 for i in range(update_per_step * min( |                 for i in range(update_per_step * min( | ||||||
|                         result['n/st'] // collect_per_step, t.total - t.n)): |                         result['n/st'] // collect_per_step, t.total - t.n)): | ||||||
|                     global_step += collect_per_step |                     global_step += collect_per_step | ||||||
|  | |||||||
| @ -77,13 +77,13 @@ def onpolicy_trainer( | |||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|     test_in_train = test_in_train and train_collector.policy == policy |     test_in_train = test_in_train and train_collector.policy == policy | ||||||
|     for epoch in range(1, 1 + max_epoch): |     for epoch in range(1, 1 + max_epoch): | ||||||
|  |         # train | ||||||
|  |         policy.train() | ||||||
|  |         if train_fn: | ||||||
|  |             train_fn(epoch) | ||||||
|         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', |         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', | ||||||
|                        **tqdm_config) as t: |                        **tqdm_config) as t: | ||||||
|             while t.n < t.total: |             while t.n < t.total: | ||||||
|                 # collect |  | ||||||
|                 if train_fn: |  | ||||||
|                     train_fn(epoch) |  | ||||||
|                 policy.eval() |  | ||||||
|                 result = train_collector.collect(n_episode=collect_per_step) |                 result = train_collector.collect(n_episode=collect_per_step) | ||||||
|                 data = {} |                 data = {} | ||||||
|                 if test_in_train and stop_fn and stop_fn(result['rew']): |                 if test_in_train and stop_fn and stop_fn(result['rew']): | ||||||
| @ -100,10 +100,9 @@ def onpolicy_trainer( | |||||||
|                             start_time, train_collector, test_collector, |                             start_time, train_collector, test_collector, | ||||||
|                             test_result['rew']) |                             test_result['rew']) | ||||||
|                     else: |                     else: | ||||||
|  |                         policy.train() | ||||||
|                         if train_fn: |                         if train_fn: | ||||||
|                             train_fn(epoch) |                             train_fn(epoch) | ||||||
|                 # train |  | ||||||
|                 policy.train() |  | ||||||
|                 losses = policy.update( |                 losses = policy.update( | ||||||
|                     0, train_collector.buffer, batch_size, repeat_per_collect) |                     0, train_collector.buffer, batch_size, repeat_per_collect) | ||||||
|                 train_collector.reset_buffer() |                 train_collector.reset_buffer() | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user