refactor ppo (#329)

ChenDRAG, 2021-03-28 18:28:36 +08:00 (committed by GitHub)
parent 1730a9008a
commit 5d580c3662
2 changed files with 38 additions and 36 deletions

View File

@@ -14,8 +14,7 @@ class A2CPolicy(PGPolicy):
     :param torch.nn.Module actor: the actor network following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.nn.Module critic: the critic network. (s -> V(s))
-    :param torch.optim.Optimizer optim: the optimizer for actor and critic
-        network.
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
@@ -71,6 +70,13 @@ class A2CPolicy(PGPolicy):
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        return batch
+
+    def _compute_returns(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
@@ -96,7 +102,6 @@ class A2CPolicy(PGPolicy):
             self.ret_rms.update(unnormalized_returns)
         else:
             batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
         batch.returns = to_torch_as(batch.returns, batch.v_s)
         batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
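Note (editor's sketch, not part of the commit): the refactor above splits A2CPolicy.process_fn into a thin wrapper plus _compute_returns, so a subclass can re-estimate returns and advantages on demand; PPOPolicy below relies on this for its new recompute_advantage option. For reference, _compute_returns boils down to GAE(lambda) advantages with bootstrapped return targets via compute_episodic_return. A minimal standalone sketch of that computation (plain NumPy, single rollout, no value/reward normalization; the helper name is hypothetical):

import numpy as np

def gae_sketch(rew, v_s, v_s_, done, gamma=0.99, gae_lambda=0.95):
    """Hypothetical helper: GAE(lambda) advantages and bootstrapped returns.

    rew, v_s, v_s_, done are 1-D arrays of equal length; v_s_[t] holds V(s_{t+1}).
    """
    adv = np.zeros_like(rew, dtype=float)
    gae = 0.0
    for t in reversed(range(len(rew))):
        nonterminal = 1.0 - float(done[t])
        # one-step TD error, then the exponentially weighted sum over future steps
        delta = rew[t] + gamma * v_s_[t] * nonterminal - v_s[t]
        gae = delta + gamma * gae_lambda * nonterminal * gae
        adv[t] = gae
    returns = adv + v_s  # unnormalized returns used as critic targets
    return returns, adv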

View File

@@ -4,7 +4,7 @@ from torch import nn
 from typing import Any, Dict, List, Type, Optional
 
 from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PPOPolicy(A2CPolicy):
@@ -24,6 +24,11 @@ class PPOPolicy(A2CPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param bool recompute_advantage: whether to recompute advantage every update
+        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+        Default to False.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation. Default to
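The two new docstring entries map to concrete changes further down: advantage normalization moves from the whole-batch, reward_normalization-gated path removed from process_fn to a per mini-batch step inside learn, and recompute_advantage re-runs return estimation at every update repeat (arXiv:2006.05990, Sec. 3.5). A minimal sketch of the per mini-batch normalization, assuming adv is one mini-batch of advantages; the eps guard is an extra safety not present in the commit's code, which divides by the std directly:

import torch

def normalize_advantage(adv: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Standardize advantages within a single mini-batch.
    return (adv - adv.mean()) / (adv.std() + eps)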
@@ -59,7 +64,9 @@ class PPOPolicy(A2CPolicy):
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         dual_clip: Optional[float] = None,
-        value_clip: bool = True,
+        value_clip: bool = False,
+        advantage_normalization: bool = True,
+        recompute_advantage: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
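With the new signature the added behaviors are chosen at construction time; note that value_clip now defaults to False. A hedged usage sketch: only the PPOPolicy keyword arguments come from this diff, while the actor/critic modules, their sizes, and the learning rate are placeholders.

import torch
import torch.nn as nn
from tianshou.policy import PPOPolicy

class Actor(nn.Module):  # placeholder network, (s -> logits) per the docstring convention
    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, action_dim))

    def forward(self, obs, state=None, info={}):
        return self.net(torch.as_tensor(obs, dtype=torch.float32)), state

class Critic(nn.Module):  # placeholder network, (s -> V(s))
    def __init__(self, state_dim=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 1))

    def forward(self, obs, **kwargs):
        return self.net(torch.as_tensor(obs, dtype=torch.float32))

actor, critic = Actor(), Critic()
optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
policy = PPOPolicy(
    actor, critic, optim, dist_fn=torch.distributions.Categorical,
    eps_clip=0.2,
    value_clip=False,              # new default in this commit
    advantage_normalization=True,  # new flag: per mini-batch normalization in learn()
    recompute_advantage=False,     # new flag: refresh returns/advantages each repeat
)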
@@ -68,51 +75,41 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
+        if not self._rew_norm:
+            assert not self._value_clip, \
+                "value clip is available only when `reward_normalization` is True"
+        self._norm_adv = advantage_normalization
+        self._recompute_adv = recompute_advantage
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s, v_s_, old_log_prob = [], [], []
+        if self._recompute_adv:
+            # buffer input `buffer` and `indice` to be used in `learn()`.
+            self._buffer = buffer
+            self._indice = indice
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        old_log_prob = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
-        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
-        v_s = to_numpy(batch.v_s)
-        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
-        # when normalizing values, we do not minus self.ret_rms.mean to be numerically
-        # consistent with OPENAI baselines' value normalization pipeline. Emperical
-        # study also shows that "minus mean" will harm performances a tiny little bit
-        # due to unknown reasons (on Mujoco envs, not confident, though).
-        if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
-        unnormalized_returns, advantages = self.compute_episodic_return(
-            batch, buffer, indice, v_s_, v_s,
-            gamma=self._gamma, gae_lambda=self._lambda)
-        if self._rew_norm:
-            batch.returns = unnormalized_returns / \
-                np.sqrt(self.ret_rms.var + self._eps)
-            self.ret_rms.update(unnormalized_returns)
-            mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
-        else:
-            batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
+                old_log_prob.append(self(b).dist.log_prob(b.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, batch.v_s)
-        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        for _ in range(repeat):
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indice)
             for b in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(b).dist
+                if self._norm_adv:
+                    mean, std = b.adv.mean(), b.adv.std()
+                    b.adv = (b.adv - mean) / std  # per-batch norm
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
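surr1 above is the unclipped surrogate; the next lines (outside this hunk) clamp the ratio for surr2 and build clip_loss, optionally with the dual-clip lower bound that the constructor assertion refers to. A sketch of that objective in its usual formulation (arXiv:1912.09729), not a verbatim quote of the library code; ratio and adv stand for the per-sample tensors in the loop:

import torch

def clipped_surrogate_loss(ratio, adv, eps_clip=0.2, dual_clip=None):
    # Standard PPO clipped surrogate; dual_clip (> 1) additionally lower-bounds
    # the objective by dual_clip * adv where the advantage is negative.
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    clipped = torch.min(surr1, surr2)
    if dual_clip is not None:
        clipped = torch.where(adv < 0, torch.max(clipped, dual_clip * adv), clipped)
    return -clipped.mean()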
@@ -130,9 +127,9 @@ class PPOPolicy(A2CPolicy):
                         -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
-                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
+                    vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
+                    vf_loss = (b.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
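The last context line is cut off at the hunk boundary; in Tianshou's PPO the objective continues as loss = clip_loss + vf_coef * vf_loss - ent_coef * ent_loss. A sketch of the value-loss branch and that combination as the post-commit code structures it; the function name and signature are illustrative only, not the library API, and note that the 0.5 factor in front of the squared error is gone after this commit:

import torch

def ppo_value_and_total_loss(returns, value, v_s_old, clip_loss, ent_loss,
                             eps_clip=0.2, vf_coef=0.5, ent_coef=0.01,
                             value_clip=False):
    # Value loss, optionally clipped around the old value estimate v_s_old
    # (the extra 0.5 scaling present before this commit is no longer applied).
    if value_clip:
        v_clip = v_s_old + (value - v_s_old).clamp(-eps_clip, eps_clip)
        vf_loss = torch.max((returns - value).pow(2), (returns - v_clip).pow(2)).mean()
    else:
        vf_loss = (returns - value).pow(2).mean()
    # Combined objective: clipped surrogate + weighted value loss - entropy bonus.
    return clip_loss + vf_coef * vf_loss - ent_coef * ent_loss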