gae
This commit is contained in:
parent
7b65d43394
commit
680fc0ffbe
@ -26,6 +26,7 @@
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||
- Vanilla Imitation Learning
|
||||
- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
|
||||
|
||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
|
||||
|
||||
|
@ -16,7 +16,8 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||
* :class:`~tianshou.policy.ImitationPolicy`
|
||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
|
||||
|
||||
|
||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
|
||||
|
@ -12,8 +12,6 @@ else: # pytest
|
||||
|
||||
|
||||
class MyPolicy(BasePolicy):
|
||||
"""docstring for MyPolicy"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -6,8 +6,8 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
@ -22,15 +22,15 @@ def get_args():
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--epoch', type=int, default=20)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=10)
|
||||
parser.add_argument('--collect-per-step', type=int, default=1)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--layer-num', type=int, default=1)
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--layer-num', type=int, default=2)
|
||||
parser.add_argument('--training-num', type=int, default=8)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
@ -42,6 +42,7 @@ def get_args():
|
||||
parser.add_argument('--ent-coef', type=float, default=0.0)
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
@ -84,12 +85,12 @@ def _test_ppo(args=get_args()):
|
||||
eps_clip=args.eps_clip,
|
||||
vf_coef=args.vf_coef,
|
||||
ent_coef=args.ent_coef,
|
||||
action_range=[env.action_space.low[0], env.action_space.high[0]])
|
||||
action_range=[env.action_space.low[0], env.action_space.high[0]],
|
||||
gae_lambda=args.gae_lambda)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.step_per_epoch)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -24,7 +24,7 @@ def get_args():
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=1)
|
||||
@ -41,6 +41,7 @@ def get_args():
|
||||
parser.add_argument('--vf-coef', type=float, default=0.5)
|
||||
parser.add_argument('--ent-coef', type=float, default=0.001)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=None)
|
||||
parser.add_argument('--gae-lambda', type=float, default=1.)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
@ -70,8 +71,9 @@ def test_a2c(args=get_args()):
|
||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = A2CPolicy(
|
||||
actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
|
||||
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
|
||||
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
|
||||
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
|
||||
max_grad_norm=args.max_grad_norm)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
|
@ -28,7 +28,7 @@ def get_args():
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
|
@ -29,7 +29,7 @@ def get_args():
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--n-step', type=int, default=4)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
|
@ -80,7 +80,7 @@ def get_args():
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
||||
|
@ -23,7 +23,7 @@ def get_args():
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
@ -42,6 +42,7 @@ def get_args():
|
||||
parser.add_argument('--ent-coef', type=float, default=0.0)
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
@ -76,7 +77,8 @@ def test_ppo(args=get_args()):
|
||||
eps_clip=args.eps_clip,
|
||||
vf_coef=args.vf_coef,
|
||||
ent_coef=args.ent_coef,
|
||||
action_range=None)
|
||||
action_range=None,
|
||||
gae_lambda=args.gae_lambda)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
|
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -74,3 +75,34 @@ class BasePolicy(ABC, nn.Module):
|
||||
:return: A dict which includes loss and its corresponding label.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_episodic_return(self, batch, v_s_=None,
|
||||
gamma=0.99, gae_lambda=0.95):
|
||||
"""Compute returns over given full-length episodes, including the
|
||||
implementation of Generalized Advantage Estimation (arXiv:1506.02438).
|
||||
|
||||
:param batch: a data batch which contains several full-episode data
|
||||
chronologically.
|
||||
:type batch: :class:`~tianshou.data.Batch`
|
||||
:param v_s_: the value function of all next states :math:`V(s')`.
|
||||
:type v_s_: numpy.ndarray
|
||||
:param float gamma: the discount factor, should be in [0, 1], defaults
|
||||
to 0.99.
|
||||
:param float gae_lambda: the parameter for Generalized Advantage
|
||||
Estimation, should be in [0, 1], defaults to 0.95.
|
||||
"""
|
||||
if v_s_ is None:
|
||||
v_s_ = np.zeros_like(batch.rew)
|
||||
if not isinstance(v_s_, np.ndarray):
|
||||
v_s_ = np.array(v_s_, np.float)
|
||||
else:
|
||||
v_s_ = v_s_.flatten()
|
||||
batch.returns = np.roll(v_s_, 1)
|
||||
m = (1. - batch.done) * gamma
|
||||
delta = batch.rew + v_s_ * m - batch.returns
|
||||
m *= gae_lambda
|
||||
gae = 0.
|
||||
for i in range(len(batch.rew) - 1, -1, -1):
|
||||
gae = delta[i] + m[i] * gae
|
||||
batch.returns[i] += gae
|
||||
return batch
|
||||
|
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -21,6 +22,8 @@ class A2CPolicy(PGPolicy):
|
||||
:param float ent_coef: weight for entropy loss, defaults to 0.01.
|
||||
:param float max_grad_norm: clipping gradients in back propagation,
|
||||
defaults to ``None``.
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage
|
||||
Estimation, defaults to 0.95.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -31,13 +34,28 @@ class A2CPolicy(PGPolicy):
|
||||
def __init__(self, actor, critic, optim,
|
||||
dist_fn=torch.distributions.Categorical,
|
||||
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
|
||||
max_grad_norm=None, **kwargs):
|
||||
max_grad_norm=None, gae_lambda=0.95, **kwargs):
|
||||
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
|
||||
self._lambda = gae_lambda
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._grad_norm = max_grad_norm
|
||||
self._batch = 64
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(
|
||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch * 4, permute=False):
|
||||
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
|
||||
v_ = np.concatenate(v_, axis=0)
|
||||
return self.compute_episodic_return(
|
||||
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
@ -63,6 +81,7 @@ class A2CPolicy(PGPolicy):
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
|
||||
self._batch = batch_size
|
||||
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
@ -70,12 +89,11 @@ class A2CPolicy(PGPolicy):
|
||||
result = self(b)
|
||||
dist = result.dist
|
||||
v = self.critic(b.obs)
|
||||
a = torch.tensor(b.act, device=dist.logits.device)
|
||||
r = torch.tensor(b.returns, device=dist.logits.device)
|
||||
a = torch.tensor(b.act, device=v.device)
|
||||
r = torch.tensor(b.returns, device=v.device)
|
||||
a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
|
||||
vf_loss = F.mse_loss(r[:, None], v)
|
||||
ent_loss = dist.entropy().mean()
|
||||
|
||||
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
|
||||
loss.backward()
|
||||
if self._grad_norm:
|
||||
|
@ -39,9 +39,11 @@ class PGPolicy(BasePolicy):
|
||||
, where :math:`T` is the terminal time step, :math:`\gamma` is the
|
||||
discount factor, :math:`\gamma \in [0, 1]`.
|
||||
"""
|
||||
batch.returns = self._vanilla_returns(batch)
|
||||
# batch.returns = self._vanilla_returns(batch)
|
||||
# batch.returns = self._vectorized_returns(batch)
|
||||
return batch
|
||||
# return batch
|
||||
return self.compute_episodic_return(
|
||||
batch, gamma=self._gamma, gae_lambda=1.)
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
@ -82,26 +84,26 @@ class PGPolicy(BasePolicy):
|
||||
losses.append(loss.item())
|
||||
return {'loss': losses}
|
||||
|
||||
def _vanilla_returns(self, batch):
|
||||
returns = batch.rew[:]
|
||||
last = 0
|
||||
for i in range(len(returns) - 1, -1, -1):
|
||||
if not batch.done[i]:
|
||||
returns[i] += self._gamma * last
|
||||
last = returns[i]
|
||||
return returns
|
||||
# def _vanilla_returns(self, batch):
|
||||
# returns = batch.rew[:]
|
||||
# last = 0
|
||||
# for i in range(len(returns) - 1, -1, -1):
|
||||
# if not batch.done[i]:
|
||||
# returns[i] += self._gamma * last
|
||||
# last = returns[i]
|
||||
# return returns
|
||||
|
||||
def _vectorized_returns(self, batch):
|
||||
# according to my tests, it is slower than _vanilla_returns
|
||||
# import scipy.signal
|
||||
convolve = np.convolve
|
||||
# convolve = scipy.signal.convolve
|
||||
rew = batch.rew[::-1]
|
||||
batch_size = len(rew)
|
||||
gammas = self._gamma ** np.arange(batch_size)
|
||||
c = convolve(rew, gammas)[:batch_size]
|
||||
T = np.where(batch.done[::-1])[0]
|
||||
d = np.zeros_like(rew)
|
||||
d[T] += c[T] - rew[T]
|
||||
d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
|
||||
return (c - convolve(d, gammas)[:batch_size])[::-1]
|
||||
# def _vectorized_returns(self, batch):
|
||||
# # according to my tests, it is slower than _vanilla_returns
|
||||
# # import scipy.signal
|
||||
# convolve = np.convolve
|
||||
# # convolve = scipy.signal.convolve
|
||||
# rew = batch.rew[::-1]
|
||||
# batch_size = len(rew)
|
||||
# gammas = self._gamma ** np.arange(batch_size)
|
||||
# c = convolve(rew, gammas)[:batch_size]
|
||||
# T = np.where(batch.done[::-1])[0]
|
||||
# d = np.zeros_like(rew)
|
||||
# d[T] += c[T] - rew[T]
|
||||
# d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
|
||||
# return (c - convolve(d, gammas)[:batch_size])[::-1]
|
||||
|
@ -26,6 +26,8 @@ class PPOPolicy(PGPolicy):
|
||||
:param float ent_coef: weight for entropy loss, defaults to 0.01.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: [float, float]
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage
|
||||
Estimation, defaults to 0.95.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -40,6 +42,7 @@ class PPOPolicy(PGPolicy):
|
||||
vf_coef=.5,
|
||||
ent_coef=.0,
|
||||
action_range=None,
|
||||
gae_lambda=0.95,
|
||||
**kwargs):
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
@ -52,6 +55,9 @@ class PPOPolicy(PGPolicy):
|
||||
self.critic, self.critic_old = critic, deepcopy(critic)
|
||||
self.critic_old.eval()
|
||||
self.optim = optim
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
|
||||
self._lambda = gae_lambda
|
||||
|
||||
def train(self):
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
@ -65,6 +71,19 @@ class PPOPolicy(PGPolicy):
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(
|
||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch * 4, permute=False):
|
||||
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
|
||||
v_ = np.concatenate(v_, axis=0)
|
||||
batch.v_ = v_
|
||||
return self.compute_episodic_return(
|
||||
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch, state=None, model='actor', **kwargs):
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
@ -97,18 +116,20 @@ class PPOPolicy(PGPolicy):
|
||||
self.critic_old.load_state_dict(self.critic.state_dict())
|
||||
|
||||
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||
r = batch.returns
|
||||
batch.returns = (r - r.mean()) / (r.std() + self._eps)
|
||||
batch.act = torch.tensor(batch.act)
|
||||
batch.returns = torch.tensor(batch.returns)[:, None]
|
||||
batch.v_ = torch.tensor(batch.v_)
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
vs_old, vs__old = self.critic_old(np.concatenate([
|
||||
b.obs, b.obs_next])).split(b.obs.shape[0])
|
||||
vs_old = self.critic_old(b.obs)
|
||||
vs__old = b.v_.to(vs_old.device)
|
||||
dist = self(b).dist
|
||||
dist_old = self(b, model='actor_old').dist
|
||||
target_v = b.returns.to(vs__old.device) + self._gamma * vs__old
|
||||
target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
|
||||
adv = (target_v - vs_old).detach()
|
||||
a = b.act.to(adv.device)
|
||||
ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
|
||||
|
Loading…
x
Reference in New Issue
Block a user