import torch
import numpy as np
from torch import nn
from typing import Dict, List, Tuple, Union, Optional, Type

from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as


class PPOPolicy(PGPolicy):
    r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for the actor and
        critic networks.
    :param dist_fn: the distribution class used to construct the action
        distribution from the actor's output.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param float discount_factor: in [0, 1], defaults to 0.99.
    :param float max_grad_norm: the norm threshold for gradient clipping in
        back-propagation, defaults to None (no clipping).
    :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the
        original paper, defaults to 0.2.
    :param float vf_coef: weight for value loss, defaults to 0.5.
    :param float ent_coef: weight for entropy loss, defaults to 0.01.
    :param action_range: the action range (minimum, maximum).
    :type action_range: (float, float)
    :param float gae_lambda: in [0, 1], param for Generalized Advantage
        Estimation, defaults to 0.95.
    :param float dual_clip: the parameter c mentioned in arXiv:1912.09729
        Equ. 5, where c > 1 is a constant for the lower bound; defaults to
        None (dual clipping disabled).
    :param bool value_clip: a parameter mentioned in arXiv:1811.02553
        Sec. 4.1, defaults to True.
    :param bool reward_normalization: normalize the returns to Normal(0, 1),
        defaults to True.
    :param int max_batchsize: the maximum size of the batch when computing
        GAE; depends on the size of available memory and the memory cost of
        the model; should be as large as possible within the memory
        constraint; defaults to 256.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more
        detailed explanation.
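
    A minimal construction sketch (``actor`` and ``critic`` stand for any
    networks following the shape conventions above; the optimizer settings
    are example values, not recommendations)::

        optim = torch.optim.Adam(
            list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
        policy = PPOPolicy(actor, critic, optim,
                           dist_fn=torch.distributions.Categorical,
                           eps_clip=0.2, vf_coef=0.5, ent_coef=0.01)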
"""

    def __init__(self,
                 actor: torch.nn.Module,
                 critic: torch.nn.Module,
                 optim: torch.optim.Optimizer,
                 dist_fn: Type[torch.distributions.Distribution],
                 discount_factor: float = 0.99,
                 max_grad_norm: Optional[float] = None,
                 eps_clip: float = 0.2,
                 vf_coef: float = 0.5,
                 ent_coef: float = 0.01,
                 action_range: Optional[Tuple[float, float]] = None,
                 gae_lambda: float = 0.95,
                 dual_clip: Optional[float] = None,
                 value_clip: bool = True,
                 reward_normalization: bool = True,
                 max_batchsize: int = 256,
                 **kwargs) -> None:
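        # actor and critic are stored separately below, so PGPolicy's
        # (model, optim) slots are deliberately filled with None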
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
self._w_vf = vf_coef
self._w_ent = ent_coef
self._range = action_range
self.actor = actor
self.critic = critic
self.optim = optim
self._batch = max_batchsize
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
self._lambda = gae_lambda
        assert dual_clip is None or dual_clip > 1, \
            'Dual-clip PPO parameter should be greater than 1.'
self._dual_clip = dual_clip
self._value_clip = value_clip
        self._rew_norm = reward_normalization

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
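        """Compute GAE returns, advantages, and the old policy's
        log-probabilities, storing them on ``batch`` for :meth:`learn`."""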
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
            if not np.isclose(std, 0, atol=1e-2):
batch.rew = (batch.rew - mean) / std
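        # evaluate V(s), V(s') and the old log-probabilities in minibatches
        # under no_grad, so the targets stay fixed during the update epochs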
v, v_, old_log_prob = [], [], []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_.append(self.critic(b.obs_next))
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
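        # GAE (arXiv:1506.02438): batch.returns become the lambda-returns;
        # the advantage is recovered below as returns - V(s)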
v_ = to_numpy(torch.cat(v_, dim=0))
batch = self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
rew_norm=self._rew_norm)
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
batch.adv = batch.returns - batch.v
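        # normalize advantages to zero mean and unit std,
        # skipping the division when std is near zero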
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
            if not np.isclose(std.item(), 0, atol=1e-2):
batch.adv = (batch.adv - mean) / std
        return batch

    def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 4 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
logits, h = self.actor(batch.obs, state=state, info=batch.info)
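        # the actor may output either a tuple (e.g. (mu, sigma) for a
        # Gaussian policy) or a single tensor (e.g. Categorical logits)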
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
act = dist.sample()
if self._range:
act = act.clamp(self._range[0], self._range[1])
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
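        """Perform ``repeat`` epochs of minibatch PPO updates on the actor
        and critic, returning the recorded losses."""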
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
dist = self(b).dist
value = self.critic(b.obs).flatten()
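                # importance ratio r(theta) = pi_theta(a|s) / pi_old(a|s),
                # computed in log space for numerical stability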
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
surr1 = ratio * b.adv
surr2 = ratio.clamp(1. - self._eps_clip,
1. + self._eps_clip) * b.adv
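                # dual-clip PPO (arXiv:1912.09729) bounds the surrogate from
                # below by c * adv, so a large ratio cannot push a
                # negative-advantage sample arbitrarily far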
                if self._dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * b.adv)
                    # the lower bound only applies where adv < 0 (Equ. 5)
                    clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
clip_losses.append(clip_loss.item())
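                # clipped value loss (arXiv:1811.02553 Sec. 4.1): keep the
                # new value estimate within an eps_clip band around the old
                # one and take the pessimistic (larger) squared error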
if self._value_clip:
v_clip = b.v + (value - b.v).clamp(
-self._eps_clip, self._eps_clip)
vf1 = (b.returns - value).pow(2)
vf2 = (b.returns - v_clip).pow(2)
                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
else:
                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
vf_losses.append(vf_loss.item())
e_loss = dist.entropy().mean()
ent_losses.append(e_loss.item())
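                # total loss: clipped surrogate + weighted value loss,
                # minus an entropy bonus that encourages exploration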
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
losses.append(loss.item())
self.optim.zero_grad()
loss.backward()
                # clip_grad_norm_ requires a numeric max_norm, so clip only
                # when a gradient threshold was actually configured
                if self._max_grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        list(self.actor.parameters())
                        + list(self.critic.parameters()),
                        self._max_grad_norm)
self.optim.step()
return {
'loss': losses,
'loss/clip': clip_losses,
'loss/vf': vf_losses,
'loss/ent': ent_losses,
}