This commit is contained in:
Trinkle23897 2020-04-14 21:11:06 +08:00
parent 7b65d43394
commit 680fc0ffbe
13 changed files with 129 additions and 51 deletions

View File

@ -26,6 +26,7 @@
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- Vanilla Imitation Learning
- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.

View File

@ -16,7 +16,8 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy`
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.

View File

@ -12,8 +12,6 @@ else: # pytest
class MyPolicy(BasePolicy):
"""docstring for MyPolicy"""
def __init__(self):
super().__init__()

View File

@ -6,8 +6,8 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.env import VectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@ -22,15 +22,15 @@ def get_args():
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--collect-per-step', type=int, default=1)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
@ -42,6 +42,7 @@ def get_args():
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
args = parser.parse_known_args()[0]
return args
@ -84,12 +85,12 @@ def _test_ppo(args=get_args()):
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
action_range=[env.action_space.low[0], env.action_space.high[0]])
action_range=[env.action_space.low[0], env.action_space.high[0]],
gae_lambda=args.gae_lambda)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.step_per_epoch)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)

View File

@ -24,7 +24,7 @@ def get_args():
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=1)
@ -41,6 +41,7 @@ def get_args():
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--gae-lambda', type=float, default=1.)
args = parser.parse_known_args()[0]
return args
@ -70,8 +71,9 @@ def test_a2c(args=get_args()):
actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
max_grad_norm=args.max_grad_norm)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@ -28,7 +28,7 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)

View File

@ -29,7 +29,7 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)

View File

@ -80,7 +80,7 @@ def get_args():
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=2)

View File

@ -23,7 +23,7 @@ def get_args():
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
@ -42,6 +42,7 @@ def get_args():
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
args = parser.parse_known_args()[0]
return args
@ -76,7 +77,8 @@ def test_ppo(args=get_args()):
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
action_range=None)
action_range=None,
gae_lambda=args.gae_lambda)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))

View File

@ -1,3 +1,4 @@
import numpy as np
from torch import nn
from abc import ABC, abstractmethod
@ -74,3 +75,34 @@ class BasePolicy(ABC, nn.Module):
:return: A dict which includes loss and its corresponding label.
"""
pass
def compute_episodic_return(self, batch, v_s_=None,
gamma=0.99, gae_lambda=0.95):
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimation (arXiv:1506.02438).
:param batch: a data batch which contains several full-episode data
chronologically.
:type batch: :class:`~tianshou.data.Batch`
:param v_s_: the value function of all next states :math:`V(s')`.
:type v_s_: numpy.ndarray
:param float gamma: the discount factor, should be in [0, 1], defaults
to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.
"""
if v_s_ is None:
v_s_ = np.zeros_like(batch.rew)
if not isinstance(v_s_, np.ndarray):
v_s_ = np.array(v_s_, np.float)
else:
v_s_ = v_s_.flatten()
batch.returns = np.roll(v_s_, 1)
m = (1. - batch.done) * gamma
delta = batch.rew + v_s_ * m - batch.returns
m *= gae_lambda
gae = 0.
for i in range(len(batch.rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
batch.returns[i] += gae
return batch

View File

@ -1,4 +1,5 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
@ -21,6 +22,8 @@ class A2CPolicy(PGPolicy):
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation, defaults to 0.95.
.. seealso::
@ -31,13 +34,28 @@ class A2CPolicy(PGPolicy):
def __init__(self, actor, critic, optim,
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
max_grad_norm=None, **kwargs):
max_grad_norm=None, gae_lambda=0.95, **kwargs):
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
self.critic = critic
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
self._lambda = gae_lambda
self._w_vf = vf_coef
self._w_ent = ent_coef
self._grad_norm = max_grad_norm
self._batch = 64
def process_fn(self, batch, buffer, indice):
if self._lambda in [0, 1]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
with torch.no_grad():
for b in batch.split(self._batch * 4, permute=False):
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
def forward(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
@ -63,6 +81,7 @@ class A2CPolicy(PGPolicy):
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
self._batch = batch_size
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size):
@ -70,12 +89,11 @@ class A2CPolicy(PGPolicy):
result = self(b)
dist = result.dist
v = self.critic(b.obs)
a = torch.tensor(b.act, device=dist.logits.device)
r = torch.tensor(b.returns, device=dist.logits.device)
a = torch.tensor(b.act, device=v.device)
r = torch.tensor(b.returns, device=v.device)
a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
vf_loss = F.mse_loss(r[:, None], v)
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()
if self._grad_norm:

View File

@ -39,9 +39,11 @@ class PGPolicy(BasePolicy):
, where :math:`T` is the terminal time step, :math:`\gamma` is the
discount factor, :math:`\gamma \in [0, 1]`.
"""
batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vectorized_returns(batch)
return batch
# return batch
return self.compute_episodic_return(
batch, gamma=self._gamma, gae_lambda=1.)
def forward(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
@ -82,26 +84,26 @@ class PGPolicy(BasePolicy):
losses.append(loss.item())
return {'loss': losses}
def _vanilla_returns(self, batch):
returns = batch.rew[:]
last = 0
for i in range(len(returns) - 1, -1, -1):
if not batch.done[i]:
returns[i] += self._gamma * last
last = returns[i]
return returns
# def _vanilla_returns(self, batch):
# returns = batch.rew[:]
# last = 0
# for i in range(len(returns) - 1, -1, -1):
# if not batch.done[i]:
# returns[i] += self._gamma * last
# last = returns[i]
# return returns
def _vectorized_returns(self, batch):
# according to my tests, it is slower than _vanilla_returns
# import scipy.signal
convolve = np.convolve
# convolve = scipy.signal.convolve
rew = batch.rew[::-1]
batch_size = len(rew)
gammas = self._gamma ** np.arange(batch_size)
c = convolve(rew, gammas)[:batch_size]
T = np.where(batch.done[::-1])[0]
d = np.zeros_like(rew)
d[T] += c[T] - rew[T]
d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
return (c - convolve(d, gammas)[:batch_size])[::-1]
# def _vectorized_returns(self, batch):
# # according to my tests, it is slower than _vanilla_returns
# # import scipy.signal
# convolve = np.convolve
# # convolve = scipy.signal.convolve
# rew = batch.rew[::-1]
# batch_size = len(rew)
# gammas = self._gamma ** np.arange(batch_size)
# c = convolve(rew, gammas)[:batch_size]
# T = np.where(batch.done[::-1])[0]
# d = np.zeros_like(rew)
# d[T] += c[T] - rew[T]
# d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
# return (c - convolve(d, gammas)[:batch_size])[::-1]

View File

@ -26,6 +26,8 @@ class PPOPolicy(PGPolicy):
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation, defaults to 0.95.
.. seealso::
@ -40,6 +42,7 @@ class PPOPolicy(PGPolicy):
vf_coef=.5,
ent_coef=.0,
action_range=None,
gae_lambda=0.95,
**kwargs):
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
@ -52,6 +55,9 @@ class PPOPolicy(PGPolicy):
self.critic, self.critic_old = critic, deepcopy(critic)
self.critic_old.eval()
self.optim = optim
self._batch = 64
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
self._lambda = gae_lambda
def train(self):
"""Set the module in training mode, except for the target network."""
@ -65,6 +71,19 @@ class PPOPolicy(PGPolicy):
self.actor.eval()
self.critic.eval()
def process_fn(self, batch, buffer, indice):
if self._lambda in [0, 1]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
with torch.no_grad():
for b in batch.split(self._batch * 4, permute=False):
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
v_ = np.concatenate(v_, axis=0)
batch.v_ = v_
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
def forward(self, batch, state=None, model='actor', **kwargs):
"""Compute action over the given batch data.
@ -97,18 +116,20 @@ class PPOPolicy(PGPolicy):
self.critic_old.load_state_dict(self.critic.state_dict())
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
self._batch = batch_size
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
batch.act = torch.tensor(batch.act)
batch.returns = torch.tensor(batch.returns)[:, None]
batch.v_ = torch.tensor(batch.v_)
for _ in range(repeat):
for b in batch.split(batch_size):
vs_old, vs__old = self.critic_old(np.concatenate([
b.obs, b.obs_next])).split(b.obs.shape[0])
vs_old = self.critic_old(b.obs)
vs__old = b.v_.to(vs_old.device)
dist = self(b).dist
dist_old = self(b, model='actor_old').dist
target_v = b.returns.to(vs__old.device) + self._gamma * vs__old
target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
adv = (target_v - vs_old).detach()
a = b.act.to(adv.device)
ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))