2020-03-17 20:22:37 +08:00
|
|
|
import torch
|
2020-04-14 21:11:06 +08:00
|
|
|
import numpy as np
|
2020-03-18 21:45:41 +08:00
|
|
|
from torch import nn
|
2020-03-17 20:22:37 +08:00
|
|
|
import torch.nn.functional as F
|
2020-05-12 11:31:47 +08:00
|
|
|
from typing import Dict, List, Union, Optional
|
2020-03-17 20:22:37 +08:00
|
|
|
|
|
|
|
from tianshou.policy import PGPolicy
|
2020-05-12 11:31:47 +08:00
|
|
|
from tianshou.data import Batch, ReplayBuffer
|
2020-03-17 20:22:37 +08:00
|
|
|
|
|
|
|
|
|
|
|
class A2CPolicy(PGPolicy):
    """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic
        network.
    :param torch.distributions.Distribution dist_fn: for computing the action,
        defaults to ``torch.distributions.Categorical``.
    :param float discount_factor: in [0, 1], defaults to 0.99.
    :param float vf_coef: weight for value loss, defaults to 0.5.
    :param float ent_coef: weight for entropy loss, defaults to 0.01.
    :param float max_grad_norm: clipping gradients in back propagation,
        defaults to ``None``.
    :param float gae_lambda: in [0, 1], param for Generalized Advantage
        Estimation, defaults to 0.95.
    :param bool reward_normalization: if ``True``, the returns are
        standardized (zero mean, unit standard deviation) at the start of
        :meth:`learn`, unless their standard deviation is close to zero,
        defaults to ``False``.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(self,
                 actor: torch.nn.Module,
                 critic: torch.nn.Module,
                 optim: torch.optim.Optimizer,
                 dist_fn: torch.distributions.Distribution
                 = torch.distributions.Categorical,
                 discount_factor: float = 0.99,
                 vf_coef: float = .5,
                 ent_coef: float = .01,
                 max_grad_norm: Optional[float] = None,
                 gae_lambda: float = 0.95,
                 reward_normalization: bool = False,
                 **kwargs) -> None:
        # The parent is given ``None`` as its model: this policy keeps its
        # own separate actor and critic networks instead.
        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
        self.actor = actor
        self.critic = critic
        assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
        self._lambda = gae_lambda
        self._w_vf = vf_coef
        self._w_ent = ent_coef
        self._grad_norm = max_grad_norm
        # Chunk size used by process_fn when evaluating the critic; it is
        # overwritten with the real minibatch size on every call to learn().
        self._batch = 64
        self._rew_norm = reward_normalization

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Compute the returns for ``batch`` before it is passed to learn().

        When ``gae_lambda`` is exactly 0 or 1 no value estimate is passed to
        ``compute_episodic_return``. Otherwise the critic is evaluated on
        ``obs_next`` (in order, in chunks of ``self._batch``, without
        gradient tracking) and the resulting values are passed along.

        :return: whatever ``compute_episodic_return`` returns for the batch.
        """
        if self._lambda in (0, 1):
            return self.compute_episodic_return(
                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
        with torch.no_grad():
            # shuffle=False keeps the values aligned with the batch order.
            value_chunks = [
                self.critic(b.obs_next).detach().cpu().numpy()
                for b in batch.split(self._batch, shuffle=False)
            ]
        v_ = np.concatenate(value_chunks, axis=0)
        return self.compute_episodic_return(
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 4 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        logits, h = self.actor(batch.obs, state=state, info=batch.info)
        # A tuple of outputs (several distribution parameters) is unpacked
        # into positional arguments of the distribution constructor.
        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        """Update the actor and the critic on ``batch``.

        The batch is split into minibatches of ``batch_size`` and swept
        ``repeat`` times; every minibatch triggers one optimizer step on
        ``actor_loss + vf_coef * value_loss - ent_coef * entropy``.

        :param batch: data carrying ``obs``, ``act`` and ``returns``.
        :param int batch_size: minibatch size; also remembered as the chunk
            size for later calls to :meth:`process_fn`.
        :param int repeat: number of passes over the whole batch.

        :return: A dict with keys ``loss``, ``loss/actor``, ``loss/vf`` and
            ``loss/ent``, each a list with one float per optimizer step.
        """
        self._batch = batch_size
        returns = batch.returns
        if self._rew_norm:
            std = returns.std()
            # Skip standardization for (near-)constant returns to avoid
            # dividing by ~0.
            if not np.isclose(std, 0):
                batch.returns = (returns - returns.mean()) / std
        # The parameter set does not change during learning, so collect it
        # once. A max_grad_norm of None (or 0) disables clipping.
        clip_params = None
        if self._grad_norm:
            clip_params = (list(self.actor.parameters())
                           + list(self.critic.parameters()))
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size):
                self.optim.zero_grad()
                dist = self(b).dist
                # Flatten V(s) to shape (B,) so it lines up element-wise
                # with the returns; subtracting a (B, 1) tensor from a (B,)
                # tensor would silently broadcast to (B, B).
                v = self.critic(b.obs).flatten()
                a = torch.as_tensor(b.act, device=v.device)
                # Match the critic's dtype (returns coming from numpy are
                # typically float64 while the network output is float32).
                r = torch.as_tensor(
                    b.returns, device=v.device, dtype=v.dtype)
                # Shape (-1, B): also supports distributions whose log_prob
                # has extra trailing dimensions per sample.
                log_prob = dist.log_prob(a).reshape(
                    len(r), -1).transpose(0, 1)
                # The advantage is detached so the policy-gradient term
                # does not back-propagate into the critic.
                a_loss = -(log_prob * (r - v).detach()).mean()
                vf_loss = F.mse_loss(r, v)
                ent_loss = dist.entropy().mean()
                loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                loss.backward()
                if clip_params is not None:
                    nn.utils.clip_grad_norm_(
                        clip_params, max_norm=self._grad_norm)
                self.optim.step()
                actor_losses.append(a_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        return {
            'loss': losses,
            'loss/actor': actor_losses,
            'loss/vf': vf_losses,
            'loss/ent': ent_losses,
        }
|