Add auto alpha tuning and exploration noise for SAC (#80)

Add BaseNoise and GaussianNoise classes to represent exploration noise.
Add a new SAC test on MountainCarContinuous-v0, which should benefit from
the two new features above.
danagi 2020-06-16 22:17:28 +08:00 committed by GitHub
parent 263e490b76
commit c59ad40aef
8 changed files with 255 additions and 49 deletions
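
The heart of the auto alpha tuning interface is that SACPolicy's alpha argument may now be a (target_entropy, log_alpha, alpha_optim) tuple instead of a plain float. A minimal sketch of building that tuple, mirroring the new test below (the environment name and learning rate are simply the test's defaults):

import gym
import torch
import numpy as np

env = gym.make('MountainCarContinuous-v0')
target_entropy = -np.prod(env.action_space.shape)   # heuristic: -dim(action space)
log_alpha = torch.zeros(1, requires_grad=True)      # alpha is optimized in log space
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = (target_entropy, log_alpha, alpha_optim)    # pass as SACPolicy(..., alpha=alpha)
# a fixed coefficient, e.g. alpha = 0.2, still works as before; exploration noise
# is passed separately, e.g. SACPolicy(..., exploration_noise=OUNoise(0.0, 0.5))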


@@ -10,6 +10,7 @@ from tianshou.env import VectorEnv
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.exploration import GaussianNoise
if __name__ == '__main__':
from net import Actor, Critic
@@ -78,7 +79,7 @@ def test_ddpg(args=get_args()):
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
args.tau, args.gamma, args.exploration_noise,
args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,


@@ -0,0 +1,130 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.exploration import OUNoise
if __name__ == '__main__':
from net import ActorProb, Critic
else: # pytest
from test.continuous.net import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='MountainCarContinuous-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=50000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=3e-4)
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--noise_std', type=float, default=0.5)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--auto_alpha', type=bool, default=True)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=80)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1.0/35.0)
parser.add_argument('--rew-norm', type=bool, default=False)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def test_sac(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
actor = ActorProb(
args.layer_num, args.state_shape, args.action_shape,
args.max_action, args.device
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
alpha = (target_entropy, log_alpha, alpha_optim)
else:
alpha = args.alpha
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, alpha,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=args.rew_norm, ignore_done=True,
exploration_noise=OUNoise(0.0, args.noise_std))
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
if __name__ == '__main__':
test_sac()


@@ -10,6 +10,7 @@ from tianshou.env import VectorEnv
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.exploration import GaussianNoise
if __name__ == '__main__':
from net import Actor, Critic
@@ -85,8 +86,8 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.exploration_noise, args.policy_noise,
args.update_actor_freq, args.noise_clip,
args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
args.policy_noise, args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,


@@ -1,5 +1,7 @@
from tianshou.exploration.random import OUNoise
from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise
__all__ = [
'BaseNoise',
'GaussianNoise',
'OUNoise',
]


@@ -1,8 +1,42 @@
import numpy as np
from typing import Union, Optional
from abc import ABC, abstractmethod
class OUNoise(object):
class BaseNoise(ABC, object):
"""The action noise base class."""
def __init__(self, **kwargs) -> None:
super(BaseNoise, self).__init__()
@abstractmethod
def __call__(self, **kwargs) -> np.ndarray:
"""Generate new noise."""
raise NotImplementedError
def reset(self, **kwargs) -> None:
"""Reset to the initial state."""
pass
class GaussianNoise(BaseNoise):
"""Class for vanilla gaussian process,
used for exploration in DDPG by default.
"""
def __init__(self,
mu: float = 0.0,
sigma: float = 1.0):
super().__init__()
self._mu = mu
assert 0 <= sigma, 'noise std should not be negative'
self._sigma = sigma
def __call__(self, size: tuple) -> np.ndarray:
return np.random.normal(self._mu, self._sigma, size)
class OUNoise(BaseNoise):
"""Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
Usage:
::
@@ -19,26 +53,31 @@ class OUNoise(object):
"""
def __init__(self,
mu: float = 0.0,
sigma: float = 0.3,
theta: float = 0.15,
dt: float = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None
) -> None:
self.alpha = theta * dt
self.beta = sigma * np.sqrt(dt)
self.x0 = x0
super(BaseNoise, self).__init__()
self._mu = mu
self._alpha = theta * dt
self._beta = sigma * np.sqrt(dt)
self._x0 = x0
self.reset()
def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
"""Generate new noise. Return a ``numpy.ndarray`` which size is equal
to ``size``.
"""
if self.x is None or self.x.shape != size:
self.x = 0
r = self.beta * np.random.normal(size=size)
self.x = self.x + self.alpha * (mu - self.x) + r
return self.x
if self._x is None or self._x.shape != size:
self._x = 0
if mu is None:
mu = self._mu
r = self._beta * np.random.normal(size=size)
self._x = self._x + self._alpha * (mu - self._x) + r
return self._x
def reset(self) -> None:
"""Reset to the initial state."""
self.x = None
self._x = self._x0
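
Both noise classes are plain callables that return a numpy array of the requested shape; a short usage sketch (shapes and parameters chosen arbitrarily):

from tianshou.exploration import GaussianNoise, OUNoise

gaussian = GaussianNoise(mu=0.0, sigma=0.1)
print(gaussian((3,)).shape)   # (3,) -- i.i.d. samples, no internal state

ou = OUNoise(mu=0.0, sigma=0.3)
first = ou((3,))              # starts from x0 (None by default, treated as 0)
second = ou((3,))             # correlated with the previous sample
ou.reset()                    # drop the internal state back to x0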


@@ -5,7 +5,7 @@ import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.data import Batch, ReplayBuffer, to_torch_as
@@ -21,8 +21,8 @@ class DDPGPolicy(BasePolicy):
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param BaseNoise exploration_noise: the exploration noise added to the
action, defaults to ``GaussianNoise(sigma=0.1)``.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
@@ -45,7 +45,8 @@ class DDPGPolicy(BasePolicy):
critic_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: float = 0.1,
exploration_noise: Optional[BaseNoise]
= GaussianNoise(sigma=0.1),
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
@@ -64,8 +65,7 @@ class DDPGPolicy(BasePolicy):
self._tau = tau
assert 0 <= gamma <= 1, 'gamma should be in [0, 1]'
self._gamma = gamma
assert 0 <= exploration_noise, 'noise should not be negative'
self._eps = exploration_noise
self._noise = exploration_noise
assert action_range is not None
self._range = action_range
self._action_bias = (action_range[0] + action_range[1]) / 2
@@ -77,9 +77,9 @@ class DDPGPolicy(BasePolicy):
assert estimation_step > 0, 'estimation_step should be greater than 0'
self._n_step = estimation_step
def set_eps(self, eps: float) -> None:
"""Set the eps for exploration."""
self._eps = eps
def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
"""Set the exploration noise."""
self._noise = noise
def train(self) -> None:
"""Set the module in training mode, except for the target network."""
@@ -106,7 +106,8 @@ class DDPGPolicy(BasePolicy):
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)
batch, model='actor_old', input='obs_next',
explorating=False).act)
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
@@ -122,7 +123,7 @@ class DDPGPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = 'actor',
input: str = 'obs',
eps: Optional[float] = None,
explorating: bool = True,
**kwargs) -> Batch:
"""Compute action over the given batch data.
@@ -142,14 +143,8 @@ class DDPGPolicy(BasePolicy):
obs = getattr(batch, input)
logits, h = model(obs, state=state, info=batch.info)
logits += self._action_bias
if eps is None:
eps = self._eps
if eps > 0:
# noise = np.random.normal(0, eps, size=logits.shape)
# logits += to_torch(noise, device=logits.device)
# noise = self.noise(logits.shape, eps)
logits += torch.randn(
size=logits.shape, device=logits.device) * eps
if self.training and explorating:
logits += to_torch_as(self._noise(logits.shape), logits)
logits = logits.clamp(self._range[0], self._range[1])
return Batch(act=logits, state=h)
@@ -161,7 +156,8 @@ class DDPGPolicy(BasePolicy):
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
action = self(batch, explorating=False).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
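
At the call sites (see the test_ddpg and test_td3 hunks above) the change is mechanical: wrap the old float noise level in a noise object. A hedged sketch, assuming actor, critic, their optimizers and env are built as in those tests:

from tianshou.policy import DDPGPolicy
from tianshou.exploration import GaussianNoise

# before this commit: DDPGPolicy(actor, actor_optim, critic, critic_optim, tau, gamma, 0.1, ...)
policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,
    tau, gamma, GaussianNoise(sigma=0.1),
    [env.action_space.low[0], env.action_space.high[0]])
policy.set_exp_noise(GaussianNoise(sigma=0.05))   # replaces set_eps for swapping noise at runtime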


@@ -7,6 +7,7 @@ from typing import Dict, Tuple, Union, Optional
from tianshou.policy import DDPGPolicy
from tianshou.policy.dist import DiagGaussian
from tianshou.data import Batch, to_torch_as, ReplayBuffer
from tianshou.exploration import BaseNoise
class SACPolicy(DDPGPolicy):
@@ -28,13 +29,18 @@ class SACPolicy(DDPGPolicy):
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param float alpha: entropy regularization coefficient, default to 0.2.
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
regularization coefficient, defaults to 0.2.
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
alpha is automatically tuned.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
:param BaseNoise exploration_noise: noise added to the action for exploration.
This is useful when solving hard-exploration problems.
.. seealso::
@@ -51,13 +57,15 @@ class SACPolicy(DDPGPolicy):
critic2_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
alpha: float = 0.2,
alpha: Union[Tuple[float, torch.Tensor, torch.optim.Optimizer],
float] = 0.2,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
**kwargs) -> None:
super().__init__(None, None, None, None, tau, gamma, 0,
super().__init__(None, None, None, None, tau, gamma, exploration_noise,
action_range, reward_normalization, ignore_done,
estimation_step, **kwargs)
self.actor, self.actor_optim = actor, actor_optim
@@ -67,7 +75,18 @@ class SACPolicy(DDPGPolicy):
self.critic2, self.critic2_old = critic2, deepcopy(critic2)
self.critic2_old.eval()
self.critic2_optim = critic2_optim
self._alpha = alpha
self._automatic_alpha_tuning = not isinstance(alpha, float)
if self._automatic_alpha_tuning:
self._target_entropy = alpha[0]
assert(alpha[1].shape == torch.Size([1])
and alpha[1].requires_grad)
self._log_alpha = alpha[1]
self._alpha_optim = alpha[2]
self._alpha = self._log_alpha.exp()
else:
self._alpha = alpha
self.__eps = np.finfo(np.float32).eps.item()
def train(self) -> None:
@@ -92,7 +111,9 @@ class SACPolicy(DDPGPolicy):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = 'obs', **kwargs) -> Batch:
input: str = 'obs',
explorating: bool = True,
**kwargs) -> Batch:
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
@@ -100,9 +121,10 @@ class SACPolicy(DDPGPolicy):
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
log_prob = dist.log_prob(x) - torch.log(
self._action_scale * (1 - y.pow(2)) + self.__eps
).sum(-1, keepdim=True)
y = self._action_scale * (1 - y.pow(2)) + self.__eps
log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
if self._noise is not None and self.training and explorating:
act += to_torch_as(self._noise(act.shape), act)
act = act.clamp(self._range[0], self._range[1])
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
@@ -111,7 +133,7 @@ class SACPolicy(DDPGPolicy):
indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs: s_{t+n}
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
obs_next_result = self(batch, input='obs_next', explorating=False)
a_ = obs_next_result.act
batch.act = to_torch_as(batch.act, a_)
target_q = torch.min(
@@ -135,7 +157,7 @@ class SACPolicy(DDPGPolicy):
critic2_loss.backward()
self.critic2_optim.step()
# actor
obs_result = self(batch)
obs_result = self(batch, explorating=False)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a)
current_q2a = self.critic2(batch.obs, a)
@@ -144,9 +166,22 @@ class SACPolicy(DDPGPolicy):
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
if self._automatic_alpha_tuning:
log_prob = (obs_result.log_prob + self._target_entropy).detach()
alpha_loss = -(self._log_alpha * log_prob).mean()
self._alpha_optim.zero_grad()
alpha_loss.backward()
self._alpha_optim.step()
self._alpha = self._log_alpha.exp()
self.sync_weight()
return {
result = {
'loss/actor': actor_loss.item(),
'loss/critic1': critic1_loss.item(),
'loss/critic2': critic2_loss.item(),
}
if self._automatic_alpha_tuning:
result['loss/alpha'] = alpha_loss.item()
return result
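
The alpha update added to learn() is small enough to exercise in isolation; a standalone numerical sketch of the rule (random numbers stand in for the batch's log_prob):

import torch

target_entropy = -1.0
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

log_prob = torch.randn(64, 1)   # placeholder for obs_result.log_prob
alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.exp()         # the entropy coefficient used in the actor and critic targets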


@@ -6,6 +6,7 @@ from typing import Dict, Tuple, Optional
from tianshou.policy import DDPGPolicy
from tianshou.data import Batch, ReplayBuffer
from tianshou.exploration import BaseNoise, GaussianNoise
class TD3Policy(DDPGPolicy):
@@ -26,8 +27,8 @@ class TD3Policy(DDPGPolicy):
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param BaseNoise exploration_noise: the exploration noise added to the
action, defaults to ``GaussianNoise(sigma=0.1)``.
:param float policy_noise: the noise used in updating policy network,
default to 0.2.
:param int update_actor_freq: the update frequency of actor network,
@@ -56,7 +57,8 @@ class TD3Policy(DDPGPolicy):
critic2_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: float = 0.1,
exploration_noise: Optional[BaseNoise]
= GaussianNoise(sigma=0.1),
policy_noise: float = 0.2,
update_actor_freq: int = 2,
noise_clip: float = 0.5,