SAC implementation update (#212)
- replace DiagGaussian with Independent(Normal), which PyTorch already supports
- detach alpha from autograd
- add value/alpha to the result dict (more informative)
- revert #204 to fix #211

Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent b86d78766b
commit 16d8e9b051
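The first change above swaps the custom `DiagGaussian` distribution for PyTorch's built-in `Independent(Normal(...), 1)`. A minimal sketch checking that the two are interchangeable (illustrative names only; the one behavioural difference is that the old `DiagGaussian.log_prob` kept a trailing dimension via `keepdim=True`):

```python
import torch
from torch.distributions import Independent, Normal

loc, scale = torch.zeros(4, 3), torch.ones(4, 3)   # batch of 4, 3-dim actions
x = torch.randn(4, 3)

# What the removed DiagGaussian computed: per-dimension Normal log-probs
# summed over the last dimension (it additionally used keepdim=True).
diag_like = Normal(loc, scale).log_prob(x).sum(-1)

# Independent reinterprets the last dimension as an event dimension,
# so log_prob and entropy are summed over it automatically.
dist = Independent(Normal(loc, scale), 1)
print(torch.allclose(diag_like, dist.log_prob(x)))  # True
print(dist.entropy().shape)                         # torch.Size([4])
```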
@@ -1,7 +1,6 @@
 # Bipedal-Hardcore-SAC
 
-- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
+- Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
 - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
-- Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed).
 
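The "remove the done flag penalty" trick mentioned above is not shown in the hunks below; a minimal sketch of how such a wrapper could look, assuming BipedalWalkerHardcore's usual -100 penalty on falling (illustrative only, not the example's actual `EnvWrapper`):

```python
import gym

class RemoveDonePenalty(gym.Wrapper):
    """Illustrative wrapper: drop the -100 terminal penalty but keep the
    episode termination, so critic targets are not dominated by the penalty."""

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        if done and rew == -100:
            rew = 0.0
        return obs, rew, done, info

# env = RemoveDonePenalty(gym.make("BipedalWalkerHardcore-v3"))
```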
@@ -24,6 +24,8 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.1)
+    parser.add_argument('--auto_alpha', type=int, default=1)
+    parser.add_argument('--alpha_lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -46,7 +48,7 @@ def get_args():
 class EnvWrapper(object):
     """Env wrapper for reward scale, action repeat and action noise"""
 
-    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
+    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
         self._env = gym.make(task)
         self.action_repeat = action_repeat
         self.reward_scale = reward_scale
@@ -109,6 +111,12 @@ def test_sac_bipedal(args=get_args()):
     critic2 = Critic(net_c2, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
 
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
+
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
Binary image file not shown (size: 40 KiB before, 46 KiB after).
@@ -5,11 +5,11 @@ import pprint
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal
 
 from tianshou.policy import PPOPolicy
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
-from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -84,7 +84,11 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
-    dist = DiagGaussian
+
+    # replace DiagGaussian with Independent(Normal), which is equivalent
+    # pass *logits to be consistent with policy.forward
+    def dist(*logits):
+        return Independent(Normal(*logits), 1)
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
         max_grad_norm=args.max_grad_norm,
@@ -1,11 +0,0 @@
-import torch
-
-
-class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian distribution."""
-
-    def log_prob(self, actions):
-        return super().log_prob(actions).sum(-1, keepdim=True)
-
-    def entropy(self):
-        return super().entropy().sum(-1)
@@ -2,9 +2,9 @@ import torch
 import numpy as np
 from copy import deepcopy
 from typing import Dict, Tuple, Union, Optional
+from torch.distributions import Normal, Independent
 
 from tianshou.policy import DDPGPolicy
-from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
 from tianshou.exploration import BaseNoise
 
@@ -47,7 +47,8 @@ class SACPolicy(DDPGPolicy):
         explanation.
     """
 
-    def __init__(self,
+    def __init__(
+        self,
         actor: torch.nn.Module,
         actor_optim: torch.optim.Optimizer,
         critic1: torch.nn.Module,
@@ -56,14 +57,16 @@ class SACPolicy(DDPGPolicy):
         critic2_optim: torch.optim.Optimizer,
         tau: float = 0.005,
         gamma: float = 0.99,
-        alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
-        or float = 0.2,
+        alpha: Union[
+            float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
+        ] = 0.2,
         action_range: Optional[Tuple[float, float]] = None,
         reward_normalization: bool = False,
         ignore_done: bool = False,
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
-        **kwargs) -> None:
+        **kwargs
+    ) -> None:
         super().__init__(None, None, None, None, tau, gamma, exploration_noise,
                          action_range, reward_normalization, ignore_done,
                          estimation_step, **kwargs)
@@ -75,14 +78,12 @@ class SACPolicy(DDPGPolicy):
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim
 
-        self._automatic_alpha_tuning = not isinstance(alpha, float)
-        if self._automatic_alpha_tuning:
-            self._target_entropy = alpha[0]
-            assert(alpha[1].shape == torch.Size([1])
-                   and alpha[1].requires_grad)
-            self._log_alpha = alpha[1]
-            self._alpha_optim = alpha[2]
-            self._alpha = self._log_alpha.exp()
+        self._is_auto_alpha = False
+        if isinstance(alpha, tuple):
+            self._is_auto_alpha = True
+            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
+            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
+            self._alpha = self._log_alpha.detach().exp()
         else:
             self._alpha = alpha
 
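With the new signature, `alpha` may stay a plain float or arrive as the `(target_entropy, log_alpha, alpha_optim)` tuple built in the example script above. A small self-contained sketch that mirrors the `isinstance(alpha, tuple)` branch (the helper name `resolve_alpha` is illustrative, not part of tianshou):

```python
import torch

def resolve_alpha(alpha):
    """Mirror of the branch above: tuple -> auto-tuning, float -> fixed."""
    if isinstance(alpha, tuple):
        target_entropy, log_alpha, alpha_optim = alpha
        assert log_alpha.shape == torch.Size([1]) and log_alpha.requires_grad
        return log_alpha.detach().exp()   # detached, exactly as in __init__
    return alpha

fixed = resolve_alpha(0.2)                               # fixed temperature
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
auto = resolve_alpha((-4.0, log_alpha, alpha_optim))     # e.g. -dim(action)
print(fixed, auto)                                       # 0.2 tensor([1.])
```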
@@ -111,12 +112,13 @@ class SACPolicy(DDPGPolicy):
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
-        dist = DiagGaussian(*logits)
+        dist = Independent(Normal(*logits), 1)
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
         y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        log_prob = dist.log_prob(x).unsqueeze(-1)
+        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
         if self._noise is not None and self.training and explorating:
             act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
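Since `Independent(Normal(*logits), 1)` already sums the per-dimension log densities, `dist.log_prob(x)` no longer carries the trailing dimension that the old `DiagGaussian` kept with `keepdim=True`; the added `.unsqueeze(-1)` restores it. The subtraction is the usual change-of-variables correction for the tanh squashing, written here to match the code above (s and b are the action scale and bias, ε is `__eps`):

$$a = s\,\tanh(x) + b, \qquad \log \pi(a \mid o) = \sum_i \log \mathcal{N}(x_i \mid \mu_i, \sigma_i) - \sum_i \log\bigl(s\,(1 - \tanh^2(x_i)) + \epsilon\bigr).$$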
@@ -167,13 +169,13 @@ class SACPolicy(DDPGPolicy):
         actor_loss.backward()
         self.actor_optim.step()
 
-        if self._automatic_alpha_tuning:
-            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+        if self._is_auto_alpha:
+            log_prob = obs_result.log_prob.detach() + self._target_entropy
             alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
-            self._alpha = self._log_alpha.exp()
+            self._alpha = self._log_alpha.detach().exp()
 
         self.sync_weight()
 
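For reference, the temperature update in this branch is the standard SAC objective, written in terms of log α as in the code:

$$\mathcal{L}(\alpha) = -\,\mathbb{E}_{a \sim \pi}\Bigl[\log\alpha \cdot \bigl(\log\pi(a \mid o) + \mathcal{H}_{\text{target}}\bigr)\Bigr].$$

Here `log_prob` is detached so only `log_alpha` receives a gradient, and the refreshed `self._alpha` is detached before it re-enters the actor and critic losses on the next update, which is what the added `.detach()` calls are for.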
@@ -182,6 +184,7 @@ class SACPolicy(DDPGPolicy):
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
-        if self._automatic_alpha_tuning:
+        if self._is_auto_alpha:
             result['loss/alpha'] = alpha_loss.item()
+            result['v/alpha'] = self._alpha.item()
         return result
@@ -77,13 +77,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def offpolicy_trainer(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                     else:
+                        policy.train()
                         if train_fn:
                             train_fn(epoch)
-                # train
-                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
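The trainer hunks above and below (the revert of #204 that fixes #211) move `policy.train()` and `train_fn` to the start of each epoch and restore train mode right after an in-training test, instead of toggling eval/train around every collect step. A self-contained sketch of the restored ordering (all names here are stand-ins, not the actual tianshou trainer API):

```python
import torch

class DummyPolicy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)

def collect(policy):           # stand-in for train_collector.collect(...)
    return {"rew": 0.0}

def update(policy):            # stand-in for policy.update(...) / learn(...)
    pass

def evaluate(policy):          # stand-in for the test collector (eval mode)
    policy.eval()
    return {"rew": 0.0}

policy, max_epoch, steps_per_epoch = DummyPolicy(), 2, 3
for epoch in range(1, 1 + max_epoch):
    policy.train()                     # train mode once per epoch
    for _ in range(steps_per_epoch):
        result = collect(policy)       # data collection in train mode
        if result["rew"] > 300:        # optional in-training test
            evaluate(policy)
            policy.train()             # back to train mode before updating
        update(policy)
```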
@@ -77,13 +77,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def onpolicy_trainer(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                     else:
+                        policy.train()
                         if train_fn:
                             train_fn(epoch)
-                # train
-                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()