From 680fc0ffbe3848a6ff42b1950351dcb6aaa37ac8 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Tue, 14 Apr 2020 21:11:06 +0800
Subject: [PATCH] gae

---
 README.md                        |  1 +
 docs/index.rst                   |  3 +-
 test/base/test_collector.py      |  2 --
 test/continuous/test_ppo.py      | 19 ++++++------
 test/discrete/test_a2c.py        |  8 +++--
 test/discrete/test_dqn.py        |  2 +-
 test/discrete/test_drqn.py       |  2 +-
 test/discrete/test_pg.py         |  2 +-
 test/discrete/test_ppo.py        |  6 ++--
 tianshou/policy/base.py          | 32 ++++++++++++++++++++
 tianshou/policy/modelfree/a2c.py | 26 ++++++++++++++---
 tianshou/policy/modelfree/pg.py  | 50 +++++++++++++++++---------------
 tianshou/policy/modelfree/ppo.py | 27 +++++++++++++++--
 13 files changed, 129 insertions(+), 51 deletions(-)

diff --git a/README.md b/README.md
index 75c7cb1..fbe62c9 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - Vanilla Imitation Learning
+- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
 
diff --git a/docs/index.rst b/docs/index.rst
index 155661e..0981f84 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -16,7 +16,8 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
-* :class:`~tianshou.policy.ImitationPolicy`
+* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
+* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
 
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index e97b84c..2b20756 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -12,8 +12,6 @@ else:  # pytest
 
 
 class MyPolicy(BasePolicy):
-    """docstring for MyPolicy"""
-
     def __init__(self):
         super().__init__()
 
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index e1877fb..dac27d6 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -6,8 +6,8 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.policy import PPOPolicy
 from tianshou.env import VectorEnv
+from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 
@@ -22,15 +22,15 @@ def get_args():
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
-    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=10)
+    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=16)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
     args = parser.parse_known_args()[0]
     return args
 
@@ -84,12 +85,12 @@ def _test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=[env.action_space.low[0], env.action_space.high[0]])
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
-    train_collector.collect(n_step=args.step_per_epoch)
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
diff --git a/test/discrete/test_a2c.py b/test/discrete/test_a2c.py
index 93cfe11..0c167fa 100644
--- a/test/discrete/test_a2c.py
+++ b/test/discrete/test_a2c.py
@@ -24,7 +24,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
@@ -41,6 +41,7 @@ def get_args():
     parser.add_argument('--vf-coef', type=float, default=0.5)
     parser.add_argument('--ent-coef', type=float, default=0.001)
     parser.add_argument('--max-grad-norm', type=float, default=None)
+    parser.add_argument('--gae-lambda', type=float, default=1.)
     args = parser.parse_known_args()[0]
     return args
 
@@ -70,8 +71,9 @@ def test_a2c(args=get_args()):
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
-        ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
+        actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
+        vf_coef=args.vf_coef, ent_coef=args.ent_coef,
+        max_grad_norm=args.max_grad_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 2816613..b378194 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -28,7 +28,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 2be8a05..8c15e3b 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -29,7 +29,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 0293b07..40be1a9 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -80,7 +80,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 28ca0e9..9f84d40 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
     args = parser.parse_known_args()[0]
     return args
 
@@ -76,7 +77,8 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None)
+        action_range=None,
+        gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 9b67281..445dc29 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -1,3 +1,4 @@
+import numpy as np
 from torch import nn
 from abc import ABC, abstractmethod
 
@@ -74,3 +75,34 @@ class BasePolicy(ABC, nn.Module):
         :return: A dict which includes loss and its corresponding label.
         """
         pass
+
+    def compute_episodic_return(self, batch, v_s_=None,
+                                gamma=0.99, gae_lambda=0.95):
+        """Compute returns over given full-length episodes, including the
+        implementation of Generalized Advantage Estimation (arXiv:1506.02438).
+
+        :param batch: a data batch which contains several full-episode data
+            chronologically.
+        :type batch: :class:`~tianshou.data.Batch`
+        :param v_s_: the value function of all next states :math:`V(s')`.
+        :type v_s_: numpy.ndarray
+        :param float gamma: the discount factor, should be in [0, 1], defaults
+            to 0.99.
+        :param float gae_lambda: the parameter for Generalized Advantage
+            Estimation, should be in [0, 1], defaults to 0.95.
+        """
+        if v_s_ is None:
+            v_s_ = np.zeros_like(batch.rew)
+        if not isinstance(v_s_, np.ndarray):
+            v_s_ = np.array(v_s_, np.float)
+        else:
+            v_s_ = v_s_.flatten()
+        batch.returns = np.roll(v_s_, 1)
+        m = (1. - batch.done) * gamma
+        delta = batch.rew + v_s_ * m - batch.returns
+        m *= gae_lambda
+        gae = 0.
+        for i in range(len(batch.rew) - 1, -1, -1):
+            gae = delta[i] + m[i] * gae
+            batch.returns[i] += gae
+        return batch
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 09937b4..df35cf9 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 import torch.nn.functional as F
 
@@ -21,6 +22,8 @@ class A2CPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation,
         defaults to ``None``.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation, defaults to 0.95.
 
     .. seealso::
 
@@ -31,13 +34,28 @@ class A2CPolicy(PGPolicy):
     def __init__(self, actor, critic, optim,
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
-                 max_grad_norm=None, **kwargs):
+                 max_grad_norm=None, gae_lambda=0.95, **kwargs):
         super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
+        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
+        self._lambda = gae_lambda
         self._w_vf = vf_coef
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm
+        self._batch = 64
+
+    def process_fn(self, batch, buffer, indice):
+        if self._lambda in [0, 1]:
+            return self.compute_episodic_return(
+                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
+        v_ = []
+        with torch.no_grad():
+            for b in batch.split(self._batch * 4, permute=False):
+                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+        v_ = np.concatenate(v_, axis=0)
+        return self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
 
     def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
@@ -63,6 +81,7 @@ class A2CPolicy(PGPolicy):
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
+        self._batch = batch_size
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -70,12 +89,11 @@ class A2CPolicy(PGPolicy):
                 result = self(b)
                 dist = result.dist
                 v = self.critic(b.obs)
-                a = torch.tensor(b.act, device=dist.logits.device)
-                r = torch.tensor(b.returns, device=dist.logits.device)
+                a = torch.tensor(b.act, device=v.device)
+                r = torch.tensor(b.returns, device=v.device)
                 a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r[:, None], v)
                 ent_loss = dist.entropy().mean()
-
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
                 if self._grad_norm:
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 1966bcc..a79a122 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -39,9 +39,11 @@ class PGPolicy(BasePolicy):
         , where :math:`T` is the terminal time step, :math:`\gamma` is the
         discount factor, :math:`\gamma \in [0, 1]`.
         """
-        batch.returns = self._vanilla_returns(batch)
+        # batch.returns = self._vanilla_returns(batch)
         # batch.returns = self._vectorized_returns(batch)
-        return batch
+        # return batch
+        return self.compute_episodic_return(
+            batch, gamma=self._gamma, gae_lambda=1.)
 
     def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
@@ -82,26 +84,26 @@ class PGPolicy(BasePolicy):
             losses.append(loss.item())
         return {'loss': losses}
 
-    def _vanilla_returns(self, batch):
-        returns = batch.rew[:]
-        last = 0
-        for i in range(len(returns) - 1, -1, -1):
-            if not batch.done[i]:
-                returns[i] += self._gamma * last
-            last = returns[i]
-        return returns
+    # def _vanilla_returns(self, batch):
+    #     returns = batch.rew[:]
+    #     last = 0
+    #     for i in range(len(returns) - 1, -1, -1):
+    #         if not batch.done[i]:
+    #             returns[i] += self._gamma * last
+    #         last = returns[i]
+    #     return returns
 
-    def _vectorized_returns(self, batch):
-        # according to my tests, it is slower than _vanilla_returns
-        # import scipy.signal
-        convolve = np.convolve
-        # convolve = scipy.signal.convolve
-        rew = batch.rew[::-1]
-        batch_size = len(rew)
-        gammas = self._gamma ** np.arange(batch_size)
-        c = convolve(rew, gammas)[:batch_size]
-        T = np.where(batch.done[::-1])[0]
-        d = np.zeros_like(rew)
-        d[T] += c[T] - rew[T]
-        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
-        return (c - convolve(d, gammas)[:batch_size])[::-1]
+    # def _vectorized_returns(self, batch):
+    #     # according to my tests, it is slower than _vanilla_returns
+    #     # import scipy.signal
+    #     convolve = np.convolve
+    #     # convolve = scipy.signal.convolve
+    #     rew = batch.rew[::-1]
+    #     batch_size = len(rew)
+    #     gammas = self._gamma ** np.arange(batch_size)
+    #     c = convolve(rew, gammas)[:batch_size]
+    #     T = np.where(batch.done[::-1])[0]
+    #     d = np.zeros_like(rew)
+    #     d[T] += c[T] - rew[T]
+    #     d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
+    #     return (c - convolve(d, gammas)[:batch_size])[::-1]
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index ba64a6b..e230c5a 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -26,6 +26,8 @@ class PPOPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param action_range: the action range (minimum, maximum).
     :type action_range: [float, float]
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation, defaults to 0.95.
 
     .. seealso::
 
@@ -40,6 +42,7 @@ class PPOPolicy(PGPolicy):
                 vf_coef=.5,
                 ent_coef=.0,
                 action_range=None,
+                 gae_lambda=0.95,
                 **kwargs):
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
@@ -52,6 +55,9 @@ class PPOPolicy(PGPolicy):
         self.critic, self.critic_old = critic, deepcopy(critic)
         self.critic_old.eval()
         self.optim = optim
+        self._batch = 64
+        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
+        self._lambda = gae_lambda
 
     def train(self):
         """Set the module in training mode, except for the target network."""
@@ -65,6 +71,19 @@ class PPOPolicy(PGPolicy):
         self.actor.eval()
         self.critic.eval()
 
+    def process_fn(self, batch, buffer, indice):
+        if self._lambda in [0, 1]:
+            return self.compute_episodic_return(
+                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
+        v_ = []
+        with torch.no_grad():
+            for b in batch.split(self._batch * 4, permute=False):
+                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+        v_ = np.concatenate(v_, axis=0)
+        batch.v_ = v_
+        return self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
+
     def forward(self, batch, state=None, model='actor', **kwargs):
         """Compute action over the given batch data.
 
@@ -97,18 +116,20 @@ class PPOPolicy(PGPolicy):
         self.critic_old.load_state_dict(self.critic.state_dict())
 
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
+        self._batch = batch_size
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
         batch.act = torch.tensor(batch.act)
         batch.returns = torch.tensor(batch.returns)[:, None]
+        batch.v_ = torch.tensor(batch.v_)
         for _ in range(repeat):
             for b in batch.split(batch_size):
-                vs_old, vs__old = self.critic_old(np.concatenate([
-                    b.obs, b.obs_next])).split(b.obs.shape[0])
+                vs_old = self.critic_old(b.obs)
+                vs__old = b.v_.to(vs_old.device)
                 dist = self(b).dist
                 dist_old = self(b, model='actor_old').dist
-                target_v = b.returns.to(vs__old.device) + self._gamma * vs__old
+                target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
                 adv = (target_v - vs_old).detach()
                 a = b.act.to(adv.device)
                 ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
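
Note (editor): the core of this patch is the backward recursion added in compute_episodic_return. The sketch below restates it outside of Tianshou as plain NumPy on flat arrays; the function name episodic_return and the toy inputs are made up for illustration and are not part of the patch:

    import numpy as np

    def episodic_return(rew, done, v_s_, gamma=0.99, gae_lambda=0.95):
        """Standalone restatement of the recursion in compute_episodic_return."""
        rew = np.asarray(rew, dtype=float)
        done = np.asarray(done, dtype=float)
        v_s_ = np.asarray(v_s_, dtype=float).flatten()   # V(s') per transition
        returns = np.roll(v_s_, 1)                       # the patch reuses V(s') shifted by one as V(s)
        m = (1. - done) * gamma                          # discount, zeroed at episode ends
        delta = rew + v_s_ * m - returns                 # TD residual delta_t
        m = m * gae_lambda
        gae = 0.
        for i in range(len(rew) - 1, -1, -1):            # A_t = delta_t + gamma * lambda * A_{t+1}
            gae = delta[i] + m[i] * gae
            returns[i] += gae                            # return = advantage + value baseline
        return returns

    # With zero value estimates and gae_lambda=1 this collapses to plain discounted returns:
    print(episodic_return([1., 1., 1.], [0., 0., 1.], np.zeros(3), gamma=0.99, gae_lambda=1.))
    # -> [2.9701, 1.99, 1.]

That gae_lambda=1, zero-value case is exactly what PGPolicy.process_fn now requests, matching what the removed _vanilla_returns helper used to compute.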
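
The new process_fn in A2CPolicy and PPOPolicy supplies the v_s_ argument by running the critic over obs_next in chunks of self._batch * 4 under torch.no_grad(), and it skips that pass entirely when gae_lambda is 0 or 1. A rough standalone approximation of that loop (toy nn.Linear critic and plain arrays instead of a tianshou Batch; every name here is illustrative):

    import numpy as np
    import torch
    from torch import nn

    critic = nn.Linear(4, 1)                      # hypothetical V(s) network
    obs_next = np.random.randn(1000, 4).astype(np.float32)
    minibatch = 64                                # plays the role of self._batch

    v_ = []
    with torch.no_grad():                         # bootstrap values need no gradients
        for start in range(0, len(obs_next), minibatch * 4):
            chunk = obs_next[start:start + minibatch * 4]
            v_.append(critic(torch.as_tensor(chunk)).cpu().numpy())
    v_ = np.concatenate(v_, axis=0).flatten()     # V(s') for every transition, in order
    # v_ can now be handed to the GAE recursion sketched above as v_s_.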
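
On the calling side the change is a single extra keyword. A hypothetical helper mirroring the updated test/discrete/test_a2c.py wiring (it assumes tianshou is installed and that actor, critic and optim are built the same way the test script builds them):

    import torch
    from tianshou.policy import A2CPolicy  # PPOPolicy accepts the same keyword

    def build_a2c(actor, critic, optim, gamma=0.9, gae_lambda=1.0):
        # gae_lambda=1.0 reproduces the previous Monte-Carlo returns;
        # anything below 1 routes process_fn through the critic-bootstrapped GAE path.
        return A2CPolicy(
            actor, critic, optim, torch.distributions.Categorical, gamma,
            gae_lambda=gae_lambda, vf_coef=0.5, ent_coef=0.001,
            max_grad_norm=None)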