2020-03-17 20:22:37 +08:00
|
|
|
import torch
|
2020-04-14 21:11:06 +08:00
|
|
|
import numpy as np
|
2020-03-18 21:45:41 +08:00
|
|
|
from torch import nn
|
2020-03-17 20:22:37 +08:00
|
|
|
import torch.nn.functional as F
|
2021-02-27 11:20:43 +08:00
|
|
|
from typing import Any, Dict, List, Type, Union, Optional
|
2020-03-17 20:22:37 +08:00
|
|
|
|
|
|
|
from tianshou.policy import PGPolicy
|
2020-06-03 13:59:47 +08:00
|
|
|
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
2020-03-17 20:22:37 +08:00
|
|
|
|
|
|
|
|
|
|
|
class A2CPolicy(PGPolicy):
    """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic
        network.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param float discount_factor: in [0, 1]. Defaults to 0.99.
    :param float vf_coef: weight for value loss. Defaults to 0.5.
    :param float ent_coef: weight for entropy loss. Defaults to 0.01.
    :param float max_grad_norm: if not None, the gradients of the actor and
        critic parameters are clipped (jointly) to this norm during back
        propagation. Defaults to None (no clipping).
    :param float gae_lambda: in [0, 1], param for Generalized Advantage
        Estimation. With 1.0 the critic is not queried when computing
        returns; any other value uses the critic's estimate of V(s').
        Defaults to 0.95.
    :param bool reward_normalization: normalize the reward to Normal(0, 1);
        forwarded as ``rew_norm`` when computing returns, for every value of
        ``gae_lambda``. Defaults to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE,
        depends on the size of available memory and the memory cost of the
        model; should be as large as possible within the memory constraint.
        Defaults to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to
        range [action_spaces.low, action_spaces.high]. Defaults to True.
    :param str action_bound_method: method to bound action to range [-1, 1],
        can be either "clip" (for simply clipping the action), "tanh" (for
        applying tanh squashing) for now, or empty string for no bounding.
        Defaults to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if
        you want to use option "action_scaling" or "action_bound_method".
        Defaults to None.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        discount_factor: float = 0.99,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: Optional[float] = None,
        gae_lambda: float = 0.95,
        reward_normalization: bool = False,
        max_batchsize: int = 256,
        **kwargs: Any
    ) -> None:
        # The parent's model slot is unused: this policy calls self.actor and
        # self.critic directly.
        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
        self.actor = actor
        self.critic = critic
        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
        self._lambda = gae_lambda
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._grad_norm = max_grad_norm
        # Chunk size used when evaluating the critic over a whole batch.
        self._batch = max_batchsize
        self._rew_norm = reward_normalization

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Attach ``returns`` to ``batch`` via ``compute_episodic_return``.

        When ``gae_lambda == 1`` no next-state values are passed (``None``),
        so the critic is not evaluated. For any other lambda, including 0
        (one-step TD, which requires V(s')), the critic is evaluated on
        ``batch.obs_next`` in chunks of ``max_batchsize`` under
        ``torch.no_grad()``. ``reward_normalization`` is forwarded in both
        cases.
        """
        if self._lambda == 1.0:
            # Critic-free shortcut. Previously this also fired for lambda=0,
            # dropping the V(s') bootstrap, and it never forwarded rew_norm.
            return self.compute_episodic_return(
                batch, buffer, indice, None,
                gamma=self._gamma, gae_lambda=self._lambda,
                rew_norm=self._rew_norm)
        v_s_chunks = []
        with torch.no_grad():
            # Evaluate the critic in bounded chunks to cap peak memory.
            for b in batch.split(self._batch, shuffle=False, merge_last=True):
                v_s_chunks.append(to_numpy(self.critic(b.obs_next)))
        v_s_ = np.concatenate(v_s_chunks, axis=0)
        return self.compute_episodic_return(
            batch, buffer, indice, v_s_,
            gamma=self._gamma, gae_lambda=self._lambda,
            rew_norm=self._rew_norm)

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 4 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        logits, h = self.actor(batch.obs, state=state, info=batch.info)
        if isinstance(logits, tuple):
            # A tuple output (e.g. distribution parameters such as mean and
            # std) is unpacked into the distribution constructor.
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        """Run ``repeat`` passes of minibatch gradient updates over ``batch``.

        Per minibatch the optimized loss is
        ``actor_loss + vf_coef * value_loss - ent_coef * entropy``, where the
        actor loss is the negative log-probability of the taken actions
        weighted by the detached advantage ``returns - V(s)``, and the value
        loss is the MSE between ``returns`` and ``V(s)``.

        :return: a dict of per-minibatch values for ``loss``, ``loss/actor``,
            ``loss/vf`` and ``loss/ent``.
        """
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                dist = self(b).dist
                v = self.critic(b.obs).flatten()
                a = to_torch_as(b.act, v)
                r = to_torch_as(b.returns, v)
                # Reshape so multi-dimensional actions broadcast against the
                # per-sample advantage along the batch axis.
                log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
                # Advantage is detached so the actor loss does not train the
                # critic.
                a_loss = -(log_prob * (r - v).detach()).mean()
                vf_loss = F.mse_loss(r, v)  # type: ignore
                ent_loss = dist.entropy().mean()
                loss = (
                    a_loss
                    + self._weight_vf * vf_loss
                    - self._weight_ent * ent_loss
                )
                loss.backward()
                if self._grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        list(self.actor.parameters())
                        + list(self.critic.parameters()),
                        max_norm=self._grad_norm,
                    )
                self.optim.step()
                actor_losses.append(a_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        return {
            "loss": losses,
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
|