SAC implementation update (#212)
- Replace DiagGaussian with Independent(Normal) (PyTorch already supports this)
- Detach alpha from autograd
- Add value/alpha to the result dict (more informative)
- Revert #204 to fix #211

Co-authored-by: Trinkle23897 <463003665@qq.com>
parent b86d78766b
commit 16d8e9b051
@@ -1,7 +1,6 @@
 # Bipedal-Hardcore-SAC
 
-- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
+- Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
 - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
-- Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed).
 
 
@@ -24,6 +24,8 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.1)
+    parser.add_argument('--auto_alpha', type=int, default=1)
+    parser.add_argument('--alpha_lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -46,7 +48,7 @@ def get_args():
 class EnvWrapper(object):
     """Env wrapper for reward scale, action repeat and action noise"""
 
-    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
+    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
         self._env = gym.make(task)
         self.action_repeat = action_repeat
         self.reward_scale = reward_scale
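Note: the hunk above only touches the wrapper's constructor. For context, below is a minimal sketch of how such a wrapper typically applies action repeat, reward scaling and the (now disabled by default) action noise. The `step` body and the class name are assumptions for illustration, not part of this diff.

```python
import gym
import numpy as np


class EnvWrapperSketch(object):
    """Hypothetical wrapper illustrating action repeat / reward scale / action noise."""

    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
        self._env = gym.make(task)
        self.action_repeat = action_repeat
        self.reward_scale = reward_scale
        self.act_noise = act_noise

    def step(self, action):
        # optional exploration noise on the raw action (disabled with act_noise=0.0)
        action = action + self.act_noise * np.random.randn(*np.shape(action))
        reward = 0.0
        # repeat the same action several times and accumulate the reward
        for _ in range(self.action_repeat):
            obs, r, done, info = self._env.step(action)
            reward += r
            if done:
                break
        # scale the reward seen by the agent
        return obs, self.reward_scale * reward, done, info
```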
@@ -109,6 +111,12 @@ def test_sac_bipedal(args=get_args()):
     critic2 = Critic(net_c2, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
 
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
+
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
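Note: `target_entropy = -np.prod(env.action_space.shape)` is the common SAC heuristic of minus the action dimensionality; for the 4-dimensional BipedalWalkerHardcore action space this evaluates to -4. A minimal standalone sketch of the auto-alpha tuple built above (action shape and learning rate assumed):

```python
import numpy as np
import torch

# assumed: a 4-dimensional continuous action space, e.g. BipedalWalkerHardcore
action_shape = (4,)

# SAC heuristic: target entropy = -dim(A)
target_entropy = -np.prod(action_shape)          # -> -4.0
log_alpha = torch.zeros(1, requires_grad=True)   # alpha is optimized in log space
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

# this is the tuple the script passes to SACPolicy as `alpha`
auto_alpha = (target_entropy, log_alpha, alpha_optim)
```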
Binary file not shown. Before: 40 KiB → After: 46 KiB.
@@ -5,11 +5,11 @@ import pprint
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal
 
 from tianshou.policy import PPOPolicy
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
-from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -84,7 +84,11 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
-    dist = DiagGaussian
+
+    # replace DiagGaussian with Independent(Normal) which is equivalent
+    # pass *logits to be consistent with policy.forward
+    def dist(*logits):
+        return Independent(Normal(*logits), 1)
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
         max_grad_norm=args.max_grad_norm,
@@ -1,11 +0,0 @@
-import torch
-
-
-class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian distribution."""
-
-    def log_prob(self, actions):
-        return super().log_prob(actions).sum(-1, keepdim=True)
-
-    def entropy(self):
-        return super().entropy().sum(-1)
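Note: the removed class and `Independent(Normal(...), 1)` give the same per-sample log-probability up to shape: `Independent` sums over the last (event) dimension, while the old class summed with `keepdim=True`. A quick equivalence check (shapes chosen arbitrarily):

```python
import torch
from torch.distributions import Independent, Normal

mu, sigma = torch.zeros(5, 3), torch.ones(5, 3)
x = torch.randn(5, 3)

# old behaviour: per-dimension Normal log-probs summed over the action dimension
old_lp = Normal(mu, sigma).log_prob(x).sum(-1, keepdim=True)   # shape (5, 1)

# new behaviour: Independent treats the last dim as the event dim and sums it
new_lp = Independent(Normal(mu, sigma), 1).log_prob(x)         # shape (5,)

assert torch.allclose(old_lp.squeeze(-1), new_lp)
```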
@@ -2,9 +2,9 @@ import torch
 import numpy as np
 from copy import deepcopy
 from typing import Dict, Tuple, Union, Optional
+from torch.distributions import Normal, Independent
 
 from tianshou.policy import DDPGPolicy
-from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
 from tianshou.exploration import BaseNoise
 
@@ -47,23 +47,26 @@ class SACPolicy(DDPGPolicy):
     explanation.
     """
 
-    def __init__(self,
-                 actor: torch.nn.Module,
-                 actor_optim: torch.optim.Optimizer,
-                 critic1: torch.nn.Module,
-                 critic1_optim: torch.optim.Optimizer,
-                 critic2: torch.nn.Module,
-                 critic2_optim: torch.optim.Optimizer,
-                 tau: float = 0.005,
-                 gamma: float = 0.99,
-                 alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
-                 or float = 0.2,
-                 action_range: Optional[Tuple[float, float]] = None,
-                 reward_normalization: bool = False,
-                 ignore_done: bool = False,
-                 estimation_step: int = 1,
-                 exploration_noise: Optional[BaseNoise] = None,
-                 **kwargs) -> None:
+    def __init__(
+        self,
+        actor: torch.nn.Module,
+        actor_optim: torch.optim.Optimizer,
+        critic1: torch.nn.Module,
+        critic1_optim: torch.optim.Optimizer,
+        critic2: torch.nn.Module,
+        critic2_optim: torch.optim.Optimizer,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        alpha: Union[
+            float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
+        ] = 0.2,
+        action_range: Optional[Tuple[float, float]] = None,
+        reward_normalization: bool = False,
+        ignore_done: bool = False,
+        estimation_step: int = 1,
+        exploration_noise: Optional[BaseNoise] = None,
+        **kwargs
+    ) -> None:
         super().__init__(None, None, None, None, tau, gamma, exploration_noise,
                          action_range, reward_normalization, ignore_done,
                          estimation_step, **kwargs)
@@ -75,14 +78,12 @@ class SACPolicy(DDPGPolicy):
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim
 
-        self._automatic_alpha_tuning = not isinstance(alpha, float)
-        if self._automatic_alpha_tuning:
-            self._target_entropy = alpha[0]
-            assert(alpha[1].shape == torch.Size([1])
-                   and alpha[1].requires_grad)
-            self._log_alpha = alpha[1]
-            self._alpha_optim = alpha[2]
-            self._alpha = self._log_alpha.exp()
+        self._is_auto_alpha = False
+        if isinstance(alpha, tuple):
+            self._is_auto_alpha = True
+            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
+            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
+            self._alpha = self._log_alpha.detach().exp()
         else:
             self._alpha = alpha
 
@@ -111,12 +112,13 @@ class SACPolicy(DDPGPolicy):
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
-        dist = DiagGaussian(*logits)
+        dist = Independent(Normal(*logits), 1)
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
         y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        log_prob = dist.log_prob(x).unsqueeze(-1)
+        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
         if self._noise is not None and self.training and explorating:
             act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
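Note: the two-step log-prob above is the usual tanh change-of-variables correction: for x ~ N(mu, sigma) and act = scale * tanh(x) + bias, log pi(act) = log p(x) - sum_i log(scale * (1 - tanh(x_i)^2) + eps). A standalone sketch of the same computation (function and argument names are illustrative, not the policy's attributes):

```python
import numpy as np
import torch
from torch.distributions import Independent, Normal


def squashed_log_prob(mu, sigma, action_scale=1.0, action_bias=0.0,
                      eps=np.finfo(np.float32).eps.item()):
    """Sample a tanh-squashed Gaussian action and return (action, log_prob)."""
    dist = Independent(Normal(mu, sigma), 1)
    x = dist.rsample()                     # reparameterized pre-squash sample
    y = torch.tanh(x)
    act = y * action_scale + action_bias
    # change of variables: subtract log|d act / d x| summed over action dims
    log_det = torch.log(action_scale * (1 - y.pow(2)) + eps).sum(-1, keepdim=True)
    log_prob = dist.log_prob(x).unsqueeze(-1) - log_det
    return act, log_prob
```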
@@ -167,13 +169,13 @@ class SACPolicy(DDPGPolicy):
         actor_loss.backward()
         self.actor_optim.step()
 
-        if self._automatic_alpha_tuning:
-            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+        if self._is_auto_alpha:
+            log_prob = obs_result.log_prob.detach() + self._target_entropy
             alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
-            self._alpha = self._log_alpha.exp()
+            self._alpha = self._log_alpha.detach().exp()
 
         self.sync_weight()
 
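Note the two detaches: the sampled log-probability is detached so the alpha loss only drives `log_alpha`, and the exponentiated alpha is detached so that, when it later scales the entropy term in the actor/critic losses, no gradient leaks back into `log_alpha`. A minimal sketch of this temperature update (names and values are illustrative):

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -4.0                       # e.g. -dim(A)


def update_alpha(log_prob):
    """One temperature update given the log-probs of freshly sampled actions."""
    # detach: the alpha loss must not backpropagate into the actor
    gap = log_prob.detach() + target_entropy
    alpha_loss = -(log_alpha * gap).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    # detach again so the returned alpha acts as a constant in the other losses
    return log_alpha.detach().exp()
```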
@@ -182,6 +184,7 @@ class SACPolicy(DDPGPolicy):
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
-        if self._automatic_alpha_tuning:
+        if self._is_auto_alpha:
             result['loss/alpha'] = alpha_loss.item()
+            result['v/alpha'] = self._alpha.item()
         return result
@@ -77,13 +77,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def offpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
                         policy.train()
-                        if train_fn:
-                            train_fn(epoch)
                 # train
+                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
@@ -77,13 +77,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def onpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
                         policy.train()
-                        if train_fn:
-                            train_fn(epoch)
                 # train
+                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
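Note: the trainer revert matters for off-policy exploration. In `SACPolicy.forward` above, exploration noise is gated on `self.training`, so collecting with `policy.eval()` (the behaviour being reverted) silently disabled exploration. A toy illustration of that gate (not the trainer itself):

```python
import torch


class NoisyActor(torch.nn.Module):
    """Toy module: adds exploration noise only in train mode, like the gate in forward()."""

    def forward(self, act):
        if self.training:                   # True after .train(), False after .eval()
            act = act + 0.1 * torch.randn_like(act)
        return act


actor = NoisyActor()
a = torch.zeros(3)
actor.train()    # collecting in train mode -> noisy actions (exploration on)
noisy = actor(a)
actor.eval()     # collecting in eval mode -> deterministic actions (exploration off)
clean = actor(a)
assert torch.equal(clean, a)
```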