refactor ppo (#329)

ChenDRAG 2021-03-28 18:28:36 +08:00 committed by GitHub
parent 1730a9008a
commit 5d580c3662
2 changed files with 38 additions and 36 deletions

tianshou/policy/modelfree/a2c.py

@@ -14,8 +14,7 @@ class A2CPolicy(PGPolicy):
     :param torch.nn.Module actor: the actor network following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.nn.Module critic: the critic network. (s -> V(s))
-    :param torch.optim.Optimizer optim: the optimizer for actor and critic
-        network.
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
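As the docstring says, A2C (and PPO below) expect a single optimizer that spans both the actor and the critic. A minimal sketch of what that means, with stand-in nn.Linear modules in place of real networks (illustrative only, not part of this commit):

import itertools

import torch
from torch import nn

# Stand-in actor and critic; in practice these are the policy's actual modules.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# One optimizer over the parameters of both networks, as A2CPolicy/PPOPolicy expect.
optim = torch.optim.Adam(
    itertools.chain(actor.parameters(), critic.parameters()), lr=3e-4
)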
@@ -71,6 +70,13 @@ class A2CPolicy(PGPolicy):
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        return batch
+
+    def _compute_returns(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
@@ -96,7 +102,6 @@ class A2CPolicy(PGPolicy):
             self.ret_rms.update(unnormalized_returns)
         else:
             batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
         batch.returns = to_torch_as(batch.returns, batch.v_s)
         batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
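The `_compute_returns` helper introduced above wraps the GAE-lambda computation (`compute_episodic_return`), so PPO can call it again when it recomputes advantages during learning. A rough NumPy sketch of the underlying calculation; the function name `gae_advantages` and the flat-array interface are assumptions for illustration, not the Tianshou API:

import numpy as np

def gae_advantages(rew, v_s, v_s_, done, gamma=0.99, gae_lambda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    delta = rew + gamma * v_s_ * (1.0 - done) - v_s
    adv = np.zeros_like(rew)
    gae = 0.0
    for t in reversed(range(len(rew))):
        gae = delta[t] + gamma * gae_lambda * (1.0 - done[t]) * gae
        adv[t] = gae
    # Unnormalized returns are the critic targets; ret_rms (above) optionally rescales them.
    returns = adv + v_s
    return returns, adv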

tianshou/policy/modelfree/ppo.py

@@ -4,7 +4,7 @@ from torch import nn
 from typing import Any, Dict, List, Type, Optional

 from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class PPOPolicy(A2CPolicy):
@@ -24,6 +24,11 @@ class PPOPolicy(A2CPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param bool recompute_advantage: whether to recompute advantage every update
+        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+        Default to False.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation. Default to
@@ -59,7 +64,9 @@ class PPOPolicy(A2CPolicy):
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         dual_clip: Optional[float] = None,
-        value_clip: bool = True,
+        value_clip: bool = False,
+        advantage_normalization: bool = True,
+        recompute_advantage: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
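The two new flags map directly onto the `learn()` changes further down: `advantage_normalization` standardizes advantages within each mini-batch, and `recompute_advantage` re-runs `_compute_returns` at every update repeat, alongside the `value_clip` default flip to False. A minimal sketch of the normalization step on its own (standalone helper with assumed names, not the library code; note the in-place version in the diff divides by the raw std without an epsilon):

import torch

def normalize_advantage(adv: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Per mini-batch advantage normalization: zero mean, unit std within the mini-batch.
    return (adv - adv.mean()) / (adv.std() + eps)

adv = torch.randn(64)
print(normalize_advantage(adv).mean(), normalize_advantage(adv).std())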
@@ -68,51 +75,41 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should be greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
+        if not self._rew_norm:
+            assert not self._value_clip, \
+                "value clip is available only when `reward_normalization` is True"
+        self._norm_adv = advantage_normalization
+        self._recompute_adv = recompute_advantage

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s, v_s_, old_log_prob = [], [], []
+        if self._recompute_adv:
+            # cache `buffer` and `indice` for use in `learn()`.
+            self._buffer = buffer
+            self._indice = indice
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        old_log_prob = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
-        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
-        v_s = to_numpy(batch.v_s)
-        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
-        # when normalizing values, we do not subtract self.ret_rms.mean to stay numerically
-        # consistent with OpenAI Baselines' value normalization pipeline. Empirical
-        # study also shows that subtracting the mean harms performance a little bit,
-        # for unknown reasons (on MuJoCo envs; not confident, though).
-        if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
-        unnormalized_returns, advantages = self.compute_episodic_return(
-            batch, buffer, indice, v_s_, v_s,
-            gamma=self._gamma, gae_lambda=self._lambda)
-        if self._rew_norm:
-            batch.returns = unnormalized_returns / \
-                np.sqrt(self.ret_rms.var + self._eps)
-            self.ret_rms.update(unnormalized_returns)
-            mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
-        else:
-            batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
+                old_log_prob.append(self(b).dist.log_prob(b.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, batch.v_s)
-        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch

     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        for _ in range(repeat):
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indice)
             for b in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(b).dist
+                if self._norm_adv:
+                    mean, std = b.adv.mean(), b.adv.std()
+                    b.adv = (b.adv - mean) / std  # per-batch norm
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
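The ratio and surrogate lines above feed PPO's clipped objective. A compact, self-contained sketch of that loss is below; the function name, the example tensors, and the choice to apply the dual-clip lower bound only where advantages are negative (the paper's formulation) are illustrative assumptions, not this file's code:

import torch

def ppo_clip_loss(ratio, adv, eps_clip=0.2, dual_clip=None):
    # Standard clipped surrogate: -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    loss = -torch.min(surr1, surr2)
    if dual_clip is not None:
        # Dual-clip PPO: for negative advantages, cap the loss at -dual_clip * A.
        loss = torch.where(adv < 0, torch.min(loss, -dual_clip * adv), loss)
    return loss.mean()

ratio = torch.tensor([0.5, 1.0, 1.5])
adv = torch.tensor([1.0, -1.0, 2.0])
print(ppo_clip_loss(ratio, adv, dual_clip=5.0))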
@@ -130,9 +127,9 @@ class PPOPolicy(A2CPolicy):
                         -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
-                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
+                    vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
+                    vf_loss = (b.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \