from typing import Any, Dict, List, Optional, Type

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.policy import PGPolicy
from tianshou.utils.net.common import ActorCritic


class A2CPolicy(PGPolicy):
    """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for the actor and critic
        networks.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param float discount_factor: in [0, 1]. Default to 0.99.
    :param float vf_coef: weight for value loss. Default to 0.5.
    :param float ent_coef: weight for entropy loss. Default to 0.01.
    :param float max_grad_norm: the maximum norm used to clip gradients in
        back-propagation. Default to None (no clipping).
    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
        Default to 0.95.
    :param bool reward_normalization: normalize estimated values to have std close
        to 1. Default to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE;
        depends on the size of available memory and the memory cost of the model,
        so it should be as large as possible within the memory constraint.
        Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (simply clip the action) or "tanh" (apply tanh squashing), or
        an empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space; mandatory if you
        want to use the "action_scaling" or "action_bound_method" option. Default
        to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate
        of the optimizer in each policy.update(). Default to None (no lr_scheduler).
    :param bool deterministic_eval: whether to use deterministic actions instead
        of stochastic actions sampled by the policy during evaluation. Default to
        False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for a more detailed
        explanation.
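
    A minimal usage sketch for a discrete-action task (``Net``, ``Actor`` and
    ``Critic`` below are the standard tianshou network helpers; their exact
    constructor arguments may differ across versions)::

        import torch
        from tianshou.policy import A2CPolicy
        from tianshou.utils.net.common import ActorCritic, Net
        from tianshou.utils.net.discrete import Actor, Critic

        state_shape, action_shape = (4,), 2  # e.g. CartPole-v0
        net = Net(state_shape, hidden_sizes=[64, 64])
        actor = Actor(net, action_shape)
        critic = Critic(net)
        # one optimizer over both networks, as expected by this policy
        optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=7e-4)
        policy = A2CPolicy(
            actor, critic, optim, dist_fn=torch.distributions.Categorical
        )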
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: Optional[float] = None,
        gae_lambda: float = 0.95,
        max_batchsize: int = 256,
        **kwargs: Any
    ) -> None:
        super().__init__(actor, optim, dist_fn, **kwargs)
        self.critic = critic
        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
        self._lambda = gae_lambda
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._grad_norm = max_grad_norm
        self._batch = max_batchsize
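        # combined module over actor and critic parameters; used in learn() to
        # clip the gradient norm over both networks at once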
        self._actor_critic = ActorCritic(self.actor, self.critic)

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        batch = self._compute_returns(batch, buffer, indices)
        batch.act = to_torch_as(batch.act, batch.v_s)
        return batch

    def _compute_returns(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        v_s, v_s_ = [], []
        with torch.no_grad():
            for b in batch.split(self._batch, shuffle=False, merge_last=True):
                v_s.append(self.critic(b.obs))
                v_s_.append(self.critic(b.obs_next))
        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
        v_s = batch.v_s.cpu().numpy()
        v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
        # When normalizing values, we do not subtract self.ret_rms.mean, to stay
        # numerically consistent with OpenAI baselines' value normalization
        # pipeline. Empirical study also shows that subtracting the mean slightly
        # hurts performance, for reasons that are unclear (observed on MuJoCo
        # envs; not conclusive).
        if self._rew_norm:  # unnormalize v_s & v_s_
            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
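        # GAE (arXiv:1506.02438): with the TD residual
        #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
        # the advantage estimate is A_t = sum_l (gamma * lambda)^l * delta_{t+l}
        # and the return target is A_t + V(s_t); compute_episodic_return applies
        # this recursion across the episode boundaries stored in the buffer.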
        unnormalized_returns, advantages = self.compute_episodic_return(
            batch,
            buffer,
            indices,
            v_s_,
            v_s,
            gamma=self._gamma,
            gae_lambda=self._lambda
        )
        if self._rew_norm:
            batch.returns = unnormalized_returns / \
                np.sqrt(self.ret_rms.var + self._eps)
            self.ret_rms.update(unnormalized_returns)
        else:
            batch.returns = unnormalized_returns
        batch.returns = to_torch_as(batch.returns, batch.v_s)
        batch.adv = to_torch_as(advantages, batch.v_s)
        return batch

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(b).dist
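                # reshape log_prob to (n_action_dims, batch_size) so that the
                # product with the 1-D advantage vector b.adv broadcasts cleanly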
                log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
                actor_loss = -(log_prob * b.adv).mean()
                # calculate loss for critic
                value = self.critic(b.obs).flatten()
                vf_loss = F.mse_loss(b.returns, value)
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
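                # overall A2C objective (minimized):
                #     loss = actor_loss + vf_coef * vf_loss - ent_coef * ent_loss
                # the entropy term is subtracted as a bonus that encourages
                # exploration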
                loss = actor_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradients
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()
                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }