gae: add Generalized Advantage Estimation (GAE)
parent 7b65d43394
commit 680fc0ffbe
@@ -26,6 +26,7 @@
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - Vanilla Imitation Learning
+- [Generalized Advantage Estimation (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
 
@@ -16,7 +16,8 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
-* :class:`~tianshou.policy.ImitationPolicy`
+* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
+* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimation <https://arxiv.org/pdf/1506.02438.pdf>`_
 
 
 Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
@@ -12,8 +12,6 @@ else: # pytest
 
 
 class MyPolicy(BasePolicy):
-    """docstring for MyPolicy"""
-
     def __init__(self):
         super().__init__()
 
@@ -6,8 +6,8 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.policy import PPOPolicy
 from tianshou.env import VectorEnv
+from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 
@@ -22,15 +22,15 @@ def get_args():
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
-    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=10)
+    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=16)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
     args = parser.parse_known_args()[0]
     return args
 
@@ -84,12 +85,12 @@ def _test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=[env.action_space.low[0], env.action_space.high[0]])
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
-    train_collector.collect(n_step=args.step_per_epoch)
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
@@ -24,7 +24,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
@@ -41,6 +41,7 @@ def get_args():
     parser.add_argument('--vf-coef', type=float, default=0.5)
     parser.add_argument('--ent-coef', type=float, default=0.001)
     parser.add_argument('--max-grad-norm', type=float, default=None)
+    parser.add_argument('--gae-lambda', type=float, default=1.)
     args = parser.parse_known_args()[0]
     return args
 
@@ -70,8 +71,9 @@ def test_a2c(args=get_args()):
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
-        ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
+        actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
+        vf_coef=args.vf_coef, ent_coef=args.ent_coef,
+        max_grad_norm=args.max_grad_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -28,7 +28,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
@@ -29,7 +29,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
@@ -80,7 +80,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -42,6 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
     args = parser.parse_known_args()[0]
     return args
 
@@ -76,7 +77,8 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None)
+        action_range=None,
+        gae_lambda=args.gae_lambda)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -1,3 +1,4 @@
+import numpy as np
 from torch import nn
 from abc import ABC, abstractmethod
 
@@ -74,3 +75,34 @@ class BasePolicy(ABC, nn.Module):
         :return: A dict which includes loss and its corresponding label.
         """
         pass
+
+    def compute_episodic_return(self, batch, v_s_=None,
+                                gamma=0.99, gae_lambda=0.95):
+        """Compute returns over given full-length episodes, including the
+        implementation of Generalized Advantage Estimation (arXiv:1506.02438).
+
+        :param batch: a data batch which contains several full-episode data
+            chronologically.
+        :type batch: :class:`~tianshou.data.Batch`
+        :param v_s_: the value function of all next states :math:`V(s')`.
+        :type v_s_: numpy.ndarray
+        :param float gamma: the discount factor, should be in [0, 1], defaults
+            to 0.99.
+        :param float gae_lambda: the parameter for Generalized Advantage
+            Estimation, should be in [0, 1], defaults to 0.95.
+        """
+        if v_s_ is None:
+            v_s_ = np.zeros_like(batch.rew)
+        if not isinstance(v_s_, np.ndarray):
+            v_s_ = np.array(v_s_, np.float)
+        else:
+            v_s_ = v_s_.flatten()
+        batch.returns = np.roll(v_s_, 1)
+        m = (1. - batch.done) * gamma
+        delta = batch.rew + v_s_ * m - batch.returns
+        m *= gae_lambda
+        gae = 0.
+        for i in range(len(batch.rew) - 1, -1, -1):
+            gae = delta[i] + m[i] * gae
+            batch.returns[i] += gae
+        return batch
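For reference, here is a minimal standalone sketch of the recursion that the added compute_episodic_return implements, written against plain NumPy arrays so it can be run in isolation; the names gae_returns, rew, done and v_next are illustrative only and not part of the Tianshou API.

import numpy as np

def gae_returns(rew, done, v_next, gamma=0.99, gae_lambda=0.95):
    # Mirrors the recursion in BasePolicy.compute_episodic_return:
    # v_next[t] estimates V(s_{t+1}); the result is returns[t] = V(s_t) + A_t,
    # where A_t is the GAE advantage.
    v_s = np.roll(v_next, 1)                # stand-in for V(s_t), as in the patch
    m = (1. - done) * gamma                 # discount, zeroed at episode ends
    delta = rew + v_next * m - v_s          # one-step TD residual
    m *= gae_lambda                         # factor for the backward accumulation
    gae, returns = 0., v_s.copy()
    for i in range(len(rew) - 1, -1, -1):   # accumulate advantages backwards in time
        gae = delta[i] + m[i] * gae
        returns[i] += gae
    return returns

# Hand-checkable case: with gae_lambda=1 and v_next=0 this is the plain
# discounted return, i.e. [1 + 0.99 * (1 + 0.99), 1 + 0.99, 1].
print(gae_returns(np.array([1., 1., 1.]), np.array([0., 0., 1.]),
                  np.zeros(3), gamma=0.99, gae_lambda=1.))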
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 import torch.nn.functional as F
 
@@ -21,6 +22,8 @@ class A2CPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation,
         defaults to ``None``.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation, defaults to 0.95.
 
     .. seealso::
 
@@ -31,13 +34,28 @@ class A2CPolicy(PGPolicy):
     def __init__(self, actor, critic, optim,
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
-                 max_grad_norm=None, **kwargs):
+                 max_grad_norm=None, gae_lambda=0.95, **kwargs):
         super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
+        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
+        self._lambda = gae_lambda
         self._w_vf = vf_coef
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm
+        self._batch = 64
+
+    def process_fn(self, batch, buffer, indice):
+        if self._lambda in [0, 1]:
+            return self.compute_episodic_return(
+                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
+        v_ = []
+        with torch.no_grad():
+            for b in batch.split(self._batch * 4, permute=False):
+                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+        v_ = np.concatenate(v_, axis=0)
+        return self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
 
     def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
@@ -63,6 +81,7 @@ class A2CPolicy(PGPolicy):
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
+        self._batch = batch_size
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -70,12 +89,11 @@ class A2CPolicy(PGPolicy):
                 result = self(b)
                 dist = result.dist
                 v = self.critic(b.obs)
-                a = torch.tensor(b.act, device=dist.logits.device)
-                r = torch.tensor(b.returns, device=dist.logits.device)
+                a = torch.tensor(b.act, device=v.device)
+                r = torch.tensor(b.returns, device=v.device)
                 a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r[:, None], v)
                 ent_loss = dist.entropy().mean()
-
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
                 if self._grad_norm:
@@ -39,9 +39,11 @@ class PGPolicy(BasePolicy):
         , where :math:`T` is the terminal time step, :math:`\gamma` is the
         discount factor, :math:`\gamma \in [0, 1]`.
         """
-        batch.returns = self._vanilla_returns(batch)
+        # batch.returns = self._vanilla_returns(batch)
         # batch.returns = self._vectorized_returns(batch)
-        return batch
+        # return batch
+        return self.compute_episodic_return(
+            batch, gamma=self._gamma, gae_lambda=1.)
 
     def forward(self, batch, state=None, **kwargs):
         """Compute action over the given batch data.
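For context, the quantity being computed is the standard estimator from the GAE paper (arXiv:1506.02438); the formulas below are restated as a reference, not quoted from the repository docs, and assume V(s_{t+1}) = 0 at terminal steps:

    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
    \hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}
    \lambda = 0: \quad \hat{A}_t = \delta_t  \quad \text{(one-step TD residual: lower variance, more bias)}
    \lambda = 1: \quad \hat{A}_t = \sum_{l \ge 0} \gamma^l r_{t+l} - V(s_t)  \quad \text{(Monte Carlo return minus baseline)}

Since PGPolicy passes no value function (V \equiv 0) and uses gae_lambda=1., the call above reduces to the plain discounted return \sum_{l \ge 0} \gamma^l r_{t+l}, the same quantity the now-commented _vanilla_returns used to compute.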
@@ -82,26 +84,26 @@ class PGPolicy(BasePolicy):
             losses.append(loss.item())
         return {'loss': losses}
 
-    def _vanilla_returns(self, batch):
-        returns = batch.rew[:]
-        last = 0
-        for i in range(len(returns) - 1, -1, -1):
-            if not batch.done[i]:
-                returns[i] += self._gamma * last
-            last = returns[i]
-        return returns
+    # def _vanilla_returns(self, batch):
+    #     returns = batch.rew[:]
+    #     last = 0
+    #     for i in range(len(returns) - 1, -1, -1):
+    #         if not batch.done[i]:
+    #             returns[i] += self._gamma * last
+    #         last = returns[i]
+    #     return returns
 
-    def _vectorized_returns(self, batch):
-        # according to my tests, it is slower than _vanilla_returns
-        # import scipy.signal
-        convolve = np.convolve
-        # convolve = scipy.signal.convolve
-        rew = batch.rew[::-1]
-        batch_size = len(rew)
-        gammas = self._gamma ** np.arange(batch_size)
-        c = convolve(rew, gammas)[:batch_size]
-        T = np.where(batch.done[::-1])[0]
-        d = np.zeros_like(rew)
-        d[T] += c[T] - rew[T]
-        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
-        return (c - convolve(d, gammas)[:batch_size])[::-1]
+    # def _vectorized_returns(self, batch):
+    #     # according to my tests, it is slower than _vanilla_returns
+    #     # import scipy.signal
+    #     convolve = np.convolve
+    #     # convolve = scipy.signal.convolve
+    #     rew = batch.rew[::-1]
+    #     batch_size = len(rew)
+    #     gammas = self._gamma ** np.arange(batch_size)
+    #     c = convolve(rew, gammas)[:batch_size]
+    #     T = np.where(batch.done[::-1])[0]
+    #     d = np.zeros_like(rew)
+    #     d[T] += c[T] - rew[T]
+    #     d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
+    #     return (c - convolve(d, gammas)[:batch_size])[::-1]
@@ -26,6 +26,8 @@ class PPOPolicy(PGPolicy):
     :param float ent_coef: weight for entropy loss, defaults to 0.01.
     :param action_range: the action range (minimum, maximum).
     :type action_range: [float, float]
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage
+        Estimation, defaults to 0.95.
 
     .. seealso::
 
@@ -40,6 +42,7 @@ class PPOPolicy(PGPolicy):
                  vf_coef=.5,
                  ent_coef=.0,
                  action_range=None,
+                 gae_lambda=0.95,
                  **kwargs):
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
@@ -52,6 +55,9 @@ class PPOPolicy(PGPolicy):
         self.critic, self.critic_old = critic, deepcopy(critic)
         self.critic_old.eval()
         self.optim = optim
+        self._batch = 64
+        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
+        self._lambda = gae_lambda
 
     def train(self):
         """Set the module in training mode, except for the target network."""
@@ -65,6 +71,19 @@ class PPOPolicy(PGPolicy):
         self.actor.eval()
         self.critic.eval()
 
+    def process_fn(self, batch, buffer, indice):
+        if self._lambda in [0, 1]:
+            return self.compute_episodic_return(
+                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
+        v_ = []
+        with torch.no_grad():
+            for b in batch.split(self._batch * 4, permute=False):
+                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+        v_ = np.concatenate(v_, axis=0)
+        batch.v_ = v_
+        return self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
+
     def forward(self, batch, state=None, model='actor', **kwargs):
         """Compute action over the given batch data.
 
@@ -97,18 +116,20 @@ class PPOPolicy(PGPolicy):
         self.critic_old.load_state_dict(self.critic.state_dict())
 
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
+        self._batch = batch_size
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
         batch.act = torch.tensor(batch.act)
         batch.returns = torch.tensor(batch.returns)[:, None]
+        batch.v_ = torch.tensor(batch.v_)
         for _ in range(repeat):
             for b in batch.split(batch_size):
-                vs_old, vs__old = self.critic_old(np.concatenate([
-                    b.obs, b.obs_next])).split(b.obs.shape[0])
+                vs_old = self.critic_old(b.obs)
+                vs__old = b.v_.to(vs_old.device)
                 dist = self(b).dist
                 dist_old = self(b, model='actor_old').dist
-                target_v = b.returns.to(vs__old.device) + self._gamma * vs__old
+                target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
                 adv = (target_v - vs_old).detach()
                 a = b.act.to(adv.device)
                 ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
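The learn() change above pairs with the new process_fn: V(s') is now computed once per collected batch, without gradients and in chunks, cached on the batch as v_, and each minibatch only indexes into that cache, while the current critic still evaluates V(s) so the value loss keeps its gradient. A minimal sketch of that caching pattern in isolation follows; the helper name cache_next_values and the stand-in network and tensors are hypothetical, not Tianshou API.

import torch

def cache_next_values(critic, obs_next, chunk=256):
    # Evaluate bootstrap values V(s') once, without tracking gradients,
    # so the training loop can reuse them instead of re-running the critic.
    out = []
    with torch.no_grad():
        for start in range(0, len(obs_next), chunk):
            out.append(critic(obs_next[start:start + chunk]))
    return torch.cat(out, dim=0)

critic = torch.nn.Linear(4, 1)                   # stand-in critic network
obs_next = torch.randn(1000, 4)                  # stand-in batch of next observations
v_next = cache_next_values(critic, obs_next)     # computed once per collected batch
for idx in torch.split(torch.arange(1000), 64):  # minibatch loop
    v_next_b = v_next[idx]                       # cheap cached lookup, no forward pass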