from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch
from tianshou.policy import PGPolicy


class PPOPolicy(PGPolicy):
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
:param torch.distributions.Distribution dist_fn: for computing the action.
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.
:param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
paper, defaults to 0.2.
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation, defaults to 0.95.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

    def __init__(self, actor, critic, optim, dist_fn,
                 discount_factor=0.99,
                 max_grad_norm=.5,
                 eps_clip=.2,
                 vf_coef=.5,
                 ent_coef=.0,
                 action_range=None,
                 gae_lambda=0.95,
                 **kwargs):
        super().__init__(None, None, dist_fn, discount_factor, **kwargs)
        self._max_grad_norm = max_grad_norm
        self._eps_clip = eps_clip
        self._w_vf = vf_coef
        self._w_ent = ent_coef
        self._range = action_range
        self.actor, self.actor_old = actor, deepcopy(actor)
        self.actor_old.eval()
        self.critic, self.critic_old = critic, deepcopy(critic)
        self.critic_old.eval()
        self.optim = optim
        self._batch = 64
        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
        self._lambda = gae_lambda

    def train(self):
        """Set the module in training mode, except for the target network."""
        self.training = True
        self.actor.train()
        self.critic.train()

    def eval(self):
        """Set the module in evaluation mode, except for the target network."""
        self.training = False
        self.actor.eval()
        self.critic.eval()
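
    # Note on process_fn below: compute_episodic_return is asked for a
    # Generalized Advantage Estimation (GAE) target, i.e. (a sketch of the
    # standard GAE recursion, not a claim about the exact implementation in
    # BasePolicy):
    #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #     A_t     = delta_t + gamma * lambda * A_{t+1}
    # with gamma = discount_factor and lambda = gae_lambda. The V(s_{t+1})
    # values are pre-computed in chunks of 4 * self._batch transitions to
    # bound memory use, and cached on the batch as batch.v_ for learn().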

    def process_fn(self, batch, buffer, indice):
        # compute V(s_{t+1}) up front: learn() relies on batch.v_ regardless
        # of the value of gae_lambda
        v_ = []
        with torch.no_grad():
            for b in batch.split(self._batch * 4, permute=False):
                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
        v_ = np.concatenate(v_, axis=0)
        batch.v_ = v_
        if self._lambda in [0, 1]:
            return self.compute_episodic_return(
                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
        return self.compute_episodic_return(
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
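
    # The actor is expected to follow the convention used in forward() below:
    # actor(obs, state=state, info=info) -> (logits, hidden_state), where
    # ``logits`` may also be a tuple (e.g. (mu, sigma) for a Gaussian policy),
    # in which case it is unpacked into dist_fn. The critic maps observations
    # to state values: critic(obs) -> V(s).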

    def forward(self, batch, state=None, model='actor', **kwargs):
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 4 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        logits, h = model(batch.obs, state=state, info=batch.info)
        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)
        act = dist.sample()
        if self._range:
            act = act.clamp(self._range[0], self._range[1])
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def sync_weight(self):
        """Synchronize the weight for the target network."""
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.critic_old.load_state_dict(self.critic.state_dict())
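
    # learn() below minimizes the standard PPO loss. With the probability
    # ratio r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t), where
    # pi_theta_old comes from actor_old (re-synced by sync_weight() at the
    # end of each learn() call), the per-minibatch loss is
    #     -min(r_t * A_t, clip(r_t, 1 - eps_clip, 1 + eps_clip) * A_t)
    #     + vf_coef * smooth_l1(V(s_t), target_v) - ent_coef * H[pi_theta]
    # averaged over the minibatch.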

    def learn(self, batch, batch_size=None, repeat=1, **kwargs):
        if batch_size is not None:
            self._batch = batch_size
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        r = batch.returns
        batch.returns = (r - r.mean()) / (r.std() + self._eps)
        batch.act = torch.tensor(batch.act)
        batch.returns = torch.tensor(batch.returns)[:, None]
        batch.v_ = torch.tensor(batch.v_)
        for _ in range(repeat):
            for b in batch.split(batch_size):
                vs_old = self.critic_old(b.obs)
                vs__old = b.v_.to(vs_old.device)
                dist = self(b).dist
                dist_old = self(b, model='actor_old').dist
                target_v = b.returns.to(vs_old.device) + self._gamma * vs__old
                adv = (target_v - vs_old).detach()
                a = b.act.to(adv.device)
                ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
                surr1 = ratio * adv
                surr2 = ratio.clamp(
                    1. - self._eps_clip, 1. + self._eps_clip) * adv
                clip_loss = -torch.min(surr1, surr2).mean()
                clip_losses.append(clip_loss.item())
                vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
                vf_losses.append(vf_loss.item())
                e_loss = dist.entropy().mean()
                ent_losses.append(e_loss.item())
                loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
                losses.append(loss.item())
                self.optim.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    self._max_grad_norm)
                self.optim.step()
        self.sync_weight()
        return {
            'loss': losses,
            'loss/clip': clip_losses,
            'loss/vf': vf_losses,
            'loss/ent': ent_losses,
        }
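

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the library): it only illustrates how the
# pieces above fit together for a discrete-action task. The tiny Actor and
# Critic modules are illustrative stand-ins, not tianshou's built-in networks;
# any modules with the same call signatures would do.
if __name__ == '__main__':
    class Actor(nn.Module):
        def __init__(self, obs_dim, act_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim))

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            # return action probabilities so that torch.distributions.
            # Categorical can be used directly as dist_fn
            return F.softmax(self.net(obs), dim=-1), state

    class Critic(nn.Module):
        def __init__(self, obs_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 1))

        def forward(self, obs, **kwargs):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.net(obs)

    actor, critic = Actor(4, 2), Critic(4)
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
    policy = PPOPolicy(actor, critic, optim,
                       dist_fn=torch.distributions.Categorical)
    # a single forward pass over a fake batch of 5 observations
    # (info is included because forward() reads batch.info)
    result = policy(Batch(obs=np.random.rand(5, 4), info={}))
    print(result.act)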