SAC implementation update (#212)

- replace DiagGaussian with Independent(Normal), which PyTorch already provides (see the sketch below)
- detach alpha from the autograd graph
- add v/alpha to the training result dict (more informative)
- revert #204 to fix #211

Co-authored-by: Trinkle23897 <463003665@qq.com>
Author: danagi, 2020-09-12 08:44:50 +08:00, committed by GitHub
parent b86d78766b
commit 16d8e9b051
8 changed files with 62 additions and 61 deletions
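
A minimal standalone sketch (not part of this patch) of why Independent(Normal(loc, scale), 1) is a drop-in replacement for the removed DiagGaussian: it sums the per-dimension log-densities of a diagonal Gaussian. The only difference is the output shape: Independent returns (batch,) rather than (batch, 1), which is why the SAC forward pass below adds .unsqueeze(-1).

    import torch
    from torch.distributions import Independent, Normal

    loc, scale = torch.zeros(32, 4), torch.ones(32, 4)
    dist = Independent(Normal(loc, scale), 1)   # batch_shape = (32,), event_shape = (4,)
    a = dist.sample()                           # (32, 4)
    lp = dist.log_prob(a)                       # (32,): per-dimension log-probs summed
    lp_ref = Normal(loc, scale).log_prob(a).sum(-1, keepdim=True)  # old DiagGaussian behaviour
    assert torch.allclose(lp.unsqueeze(-1), lp_ref)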


@@ -1,7 +1,6 @@
 # Bipedal-Hardcore-SAC
-- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
+- Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
 - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
 - Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed).

-![](results/sac/BipedalHardcore.png)
 ![](results/sac/BipedalHardcore.png)
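
For context, and not taken from this repository: "remove the done flag penalty" refers to masking the -100 reward that BipedalWalkerHardcore returns when the walker falls. A minimal wrapper along these lines could do it (the class name and the threshold check are illustrative):

    import gym

    class NoFallPenalty(gym.Wrapper):
        """Keep the termination signal but drop the -100 falling penalty."""
        def step(self, action):
            obs, rew, done, info = self.env.step(action)
            if done and rew <= -100:
                rew = 0.0  # episode still ends; only the penalty is removed
            return obs, rew, done, info

    env = NoFallPenalty(gym.make("BipedalWalkerHardcore-v3"))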


@@ -24,6 +24,8 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.1)
+    parser.add_argument('--auto_alpha', type=int, default=1)
+    parser.add_argument('--alpha_lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -46,7 +48,7 @@ def get_args():

 class EnvWrapper(object):
     """Env wrapper for reward scale, action repeat and action noise"""
-    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
+    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
         self._env = gym.make(task)
         self.action_repeat = action_repeat
         self.reward_scale = reward_scale
@@ -109,6 +111,12 @@ def test_sac_bipedal(args=get_args()):
     critic2 = Critic(net_c2, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
+
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
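
For reference (not part of the patch), the tuple assigned to args.alpha above follows the SAC-v2 heuristic target_entropy = -dim(A), and the tuple form is what switches SACPolicy into automatic entropy tuning later in this diff. With BipedalWalkerHardcore-v3's 4-dimensional action space the values work out as in this hypothetical snippet:

    import numpy as np
    import torch

    target_entropy = -np.prod((4,))                 # -4, i.e. -dim(A) for a 4-dim action space
    log_alpha = torch.zeros(1, requires_grad=True)  # alpha = exp(log_alpha) starts at 1.0
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
    alpha = (target_entropy, log_alpha, alpha_optim)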

(Binary image file changed, not shown: 40 KiB before, 46 KiB after.)


@@ -5,11 +5,11 @@ import pprint
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal

 from tianshou.policy import PPOPolicy
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
-from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -84,7 +84,11 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
-    dist = DiagGaussian
+
+    # replace DiagGaussian with Independent(Normal) which is equivalent
+    # pass *logits to be consistent with policy.forward
+    def dist(*logits):
+        return Independent(Normal(*logits), 1)
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
         max_grad_norm=args.max_grad_norm,


@@ -1,11 +0,0 @@
-import torch
-
-
-class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian distribution."""
-
-    def log_prob(self, actions):
-        return super().log_prob(actions).sum(-1, keepdim=True)
-
-    def entropy(self):
-        return super().entropy().sum(-1)


@@ -2,9 +2,9 @@ import torch
 import numpy as np
 from copy import deepcopy
 from typing import Dict, Tuple, Union, Optional
+from torch.distributions import Normal, Independent

 from tianshou.policy import DDPGPolicy
-from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
 from tianshou.exploration import BaseNoise

@@ -47,23 +47,26 @@ class SACPolicy(DDPGPolicy):
     explanation.
     """

-    def __init__(self,
-                 actor: torch.nn.Module,
-                 actor_optim: torch.optim.Optimizer,
-                 critic1: torch.nn.Module,
-                 critic1_optim: torch.optim.Optimizer,
-                 critic2: torch.nn.Module,
-                 critic2_optim: torch.optim.Optimizer,
-                 tau: float = 0.005,
-                 gamma: float = 0.99,
-                 alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
-                 or float = 0.2,
-                 action_range: Optional[Tuple[float, float]] = None,
-                 reward_normalization: bool = False,
-                 ignore_done: bool = False,
-                 estimation_step: int = 1,
-                 exploration_noise: Optional[BaseNoise] = None,
-                 **kwargs) -> None:
+    def __init__(
+        self,
+        actor: torch.nn.Module,
+        actor_optim: torch.optim.Optimizer,
+        critic1: torch.nn.Module,
+        critic1_optim: torch.optim.Optimizer,
+        critic2: torch.nn.Module,
+        critic2_optim: torch.optim.Optimizer,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        alpha: Union[
+            float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
+        ] = 0.2,
+        action_range: Optional[Tuple[float, float]] = None,
+        reward_normalization: bool = False,
+        ignore_done: bool = False,
+        estimation_step: int = 1,
+        exploration_noise: Optional[BaseNoise] = None,
+        **kwargs
+    ) -> None:
         super().__init__(None, None, None, None, tau, gamma, exploration_noise,
                          action_range, reward_normalization, ignore_done,
                          estimation_step, **kwargs)
@@ -75,14 +78,12 @@ class SACPolicy(DDPGPolicy):
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim

-        self._automatic_alpha_tuning = not isinstance(alpha, float)
-        if self._automatic_alpha_tuning:
-            self._target_entropy = alpha[0]
-            assert(alpha[1].shape == torch.Size([1])
-                   and alpha[1].requires_grad)
-            self._log_alpha = alpha[1]
-            self._alpha_optim = alpha[2]
-            self._alpha = self._log_alpha.exp()
+        self._is_auto_alpha = False
+        if isinstance(alpha, tuple):
+            self._is_auto_alpha = True
+            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
+            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
+            self._alpha = self._log_alpha.detach().exp()
         else:
             self._alpha = alpha

@@ -111,12 +112,13 @@ class SACPolicy(DDPGPolicy):
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
-        dist = DiagGaussian(*logits)
+        dist = Independent(Normal(*logits), 1)
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
         y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        log_prob = dist.log_prob(x).unsqueeze(-1)
+        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
         if self._noise is not None and self.training and explorating:
             act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
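
For reference (not part of the patch), the two log_prob lines above implement the standard change-of-variables correction for a tanh-squashed Gaussian: with x ~ N(mu, sigma) and act = tanh(x) * scale + bias,

    \log \pi(a \mid s) = \log \mathcal{N}(x; \mu, \sigma) - \sum_i \log\big(\mathrm{scale} \cdot (1 - \tanh(x_i)^2) + \epsilon\big)

where epsilon only guards the logarithm against zero. The .unsqueeze(-1) restores the trailing axis that the old DiagGaussian kept via keepdim=True, so the log-probability stays shaped (batch, 1) downstream.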
@@ -167,13 +169,13 @@ class SACPolicy(DDPGPolicy):
         actor_loss.backward()
         self.actor_optim.step()

-        if self._automatic_alpha_tuning:
-            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+        if self._is_auto_alpha:
+            log_prob = obs_result.log_prob.detach() + self._target_entropy
             alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
-            self._alpha = self._log_alpha.exp()
+            self._alpha = self._log_alpha.detach().exp()

         self.sync_weight()

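A standalone toy example (not library code) of what "detach alpha from autograd" guards against: if alpha kept its graph, the actor loss backward pass would also push a gradient into log_alpha.

    import torch

    log_alpha = torch.zeros(1, requires_grad=True)
    log_prob = torch.randn(8, 1, requires_grad=True)  # stand-in for the policy's log-probs
    q_value = torch.randn(8, 1)                       # stand-in for the critic value

    alpha = log_alpha.exp()                           # still attached to log_alpha
    (alpha * log_prob - q_value).mean().backward()
    print(log_alpha.grad)                             # populated: the actor loss leaked into log_alpha

    log_alpha.grad = None
    alpha = log_alpha.detach().exp()                  # the pattern used in this commit
    (alpha * log_prob - q_value).mean().backward()
    print(log_alpha.grad)                             # None: alpha acts as a constant for this loss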
@@ -182,6 +184,7 @@ class SACPolicy(DDPGPolicy):
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
-        if self._automatic_alpha_tuning:
+        if self._is_auto_alpha:
             result['loss/alpha'] = alpha_loss.item()
+            result['v/alpha'] = self._alpha.item()
         return result


@@ -77,13 +77,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def offpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
+                        policy.train()
                         if train_fn:
                             train_fn(epoch)
-                # train
-                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step


@@ -77,13 +77,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def onpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
+                        policy.train()
                         if train_fn:
                             train_fn(epoch)
-                # train
-                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()