2020-03-19 17:23:46 +08:00
|
|
|
import torch
|
2020-03-20 19:52:29 +08:00
|
|
|
import numpy as np
|
2020-03-19 17:23:46 +08:00
|
|
|
from torch import nn
|
2021-03-23 22:05:48 +08:00
|
|
|
from typing import Any, Dict, List, Type, Optional
|
2020-03-19 17:23:46 +08:00
|
|
|
|
2021-03-23 22:05:48 +08:00
|
|
|
from tianshou.policy import A2CPolicy
|
2020-06-03 13:59:47 +08:00
|
|
|
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
|
2020-03-19 17:23:46 +08:00
|
|
|
|
|
|
|
|
2021-03-23 22:05:48 +08:00
|
|
|
class PPOPolicy(A2CPolicy):
    r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param float discount_factor: in [0, 1]. Default to 0.99.
    :param float max_grad_norm: clipping gradients in back propagation.
        Default to None.
    :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
        paper. Default to 0.2.
    :param float vf_coef: weight for value loss. Default to 0.5.
    :param float ent_coef: weight for entropy loss. Default to 0.01.
    :param float gae_lambda: in [0, 1], param for Generalized Advantage
        Estimation. Default to 0.95.
    :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
        where c > 1 is a constant indicating the lower bound. The bound is only
        applied to samples whose advantage is negative. Default to None
        (dual clip disabled).
    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
        Default to True.
    :param bool reward_normalization: normalize the returns and advantage to
        Normal(0, 1). Default to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE,
        depends on the size of available memory and the memory cost of the
        model; should be as large as possible within the memory constraint.
        Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action), "tanh" (for applying tanh
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
        optimizer in each policy.update(). Default to None (no lr_scheduler).

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        max_grad_norm: Optional[float] = None,
        eps_clip: float = 0.2,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        gae_lambda: float = 0.95,
        dual_clip: Optional[float] = None,
        value_clip: bool = True,
        max_batchsize: int = 256,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
            vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
            max_batchsize=max_batchsize, **kwargs)
        self._eps_clip = eps_clip
        assert dual_clip is None or dual_clip > 1.0, \
            "Dual-clip PPO parameter should greater than 1.0."
        self._dual_clip = dual_clip
        self._value_clip = value_clip

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Attach old values, old log-probs, returns and advantages to ``batch``.

        Without gradient tracking, the critic is evaluated on ``obs`` and
        ``obs_next``, and the current policy's log-probability of ``act`` is
        recorded. This is done in chunks of ``self._batch`` samples
        (presumably derived from ``max_batchsize`` by the parent class).
        Returns and advantages then come from ``compute_episodic_return``
        using ``self._gamma`` and ``self._lambda``.

        Sets ``batch.v_s``, ``batch.returns``, ``batch.act``,
        ``batch.logp_old`` and ``batch.adv`` (all as torch tensors), and
        returns the same ``batch`` object.
        """
        v_s, v_s_, old_log_prob = [], [], []
        with torch.no_grad():
            for b in batch.split(self._batch, shuffle=False, merge_last=True):
                v_s.append(self.critic(b.obs))
                v_s_.append(self.critic(b.obs_next))
                old_log_prob.append(
                    self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
        v_s = to_numpy(batch.v_s)
        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
        if self._rew_norm:
            # The critic predicts normalized returns; map the values back to
            # the raw return scale before computing GAE.
            ret_std = np.sqrt(self.ret_rms.var + self._eps)
            v_s = v_s * ret_std + self.ret_rms.mean
            v_s_ = v_s_ * ret_std + self.ret_rms.mean
        unnormalized_returns, advantages = self.compute_episodic_return(
            batch, buffer, indice, v_s_, v_s,
            gamma=self._gamma, gae_lambda=self._lambda)
        if self._rew_norm:
            # Normalize with the statistics seen so far, then update them.
            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
                np.sqrt(self.ret_rms.var + self._eps)
            self.ret_rms.update(unnormalized_returns)
            mean, std = np.mean(advantages), np.std(advantages)
            # Per-batch advantage normalization. The epsilon guards against a
            # zero std (single sample or constant advantages) producing NaN/inf.
            advantages = (advantages - mean) / (std + self._eps)
        else:
            batch.returns = unnormalized_returns
        batch.act = to_torch_as(batch.act, batch.v_s)
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        batch.returns = to_torch_as(batch.returns, batch.v_s)
        batch.adv = to_torch_as(advantages, batch.v_s)
        return batch

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        """Run ``repeat`` epochs of PPO updates over minibatches of ``batch``.

        For every minibatch, the loss is the clipped surrogate loss plus the
        weighted value loss minus the weighted entropy bonus. The gradients
        are optionally clipped by ``max_grad_norm`` before the optimizer step.
        The lr scheduler, if given, steps once per call.

        :return: a dict with per-minibatch values for "loss", "loss/clip",
            "loss/vf" and "loss/ent".
        """
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                dist = self(b).dist
                value = self.critic(b.obs).flatten()
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                # Flatten any trailing action dims and put the batch dim last,
                # so the ratio broadcasts against b.adv (assumed 1-D, one
                # entry per sample).
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * b.adv
                surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
                if self._dual_clip is not None:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * b.adv)
                    # Dual clip only lower-bounds samples with negative
                    # advantage. For A >= 0, c * A (c > 1) would otherwise
                    # override the ordinary PPO clipping.
                    clip_loss = -torch.where(b.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
                clip_losses.append(clip_loss.item())
                if self._value_clip:
                    # Keep the new value within eps_clip of the old value.
                    v_clip = b.v_s + (value - b.v_s).clamp(
                        -self._eps_clip, self._eps_clip)
                    vf1 = (b.returns - value).pow(2)
                    vf2 = (b.returns - v_clip).pow(2)
                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
                else:
                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
                vf_losses.append(vf_loss.item())
                e_loss = dist.entropy().mean()
                ent_losses.append(e_loss.item())
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * e_loss
                losses.append(loss.item())
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        list(self.actor.parameters()) + list(self.critic.parameters()),
                        self._grad_norm)
                self.optim.step()
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
|