refactor ppo (#329)
parent 1730a9008a
commit 5d580c3662

@@ -14,8 +14,7 @@ class A2CPolicy(PGPolicy):
     :param torch.nn.Module actor: the actor network following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.nn.Module critic: the critic network. (s -> V(s))
-    :param torch.optim.Optimizer optim: the optimizer for actor and critic
-        network.
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
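The docstring hunk above covers A2CPolicy's constructor arguments: an actor (s -> logits), a critic (s -> V(s)), a single optimizer over both networks, and a distribution class. For orientation, a minimal construction sketch; the Net/Actor/Critic helpers, shapes, and learning rate are assumptions for illustration and are not part of this commit:

    import torch
    from torch.distributions import Categorical

    from tianshou.policy import A2CPolicy
    from tianshou.utils.net.common import Net
    from tianshou.utils.net.discrete import Actor, Critic

    # assumed setting: 4-dim observations, 2 discrete actions
    net = Net(state_shape=4, hidden_sizes=[64, 64])
    actor = Actor(net, action_shape=2)   # s -> logits
    critic = Critic(net)                 # s -> V(s)
    # one optimizer covering both actor and critic, as the docstring describes
    optim = torch.optim.Adam(
        set(actor.parameters()).union(critic.parameters()), lr=3e-4)
    policy = A2CPolicy(actor, critic, optim, dist_fn=Categorical,
                       discount_factor=0.99)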
@@ -71,6 +70,13 @@ class A2CPolicy(PGPolicy):

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        return batch
+
+    def _compute_returns(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
@@ -96,7 +102,6 @@ class A2CPolicy(PGPolicy):
             self.ret_rms.update(unnormalized_returns)
         else:
             batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
         batch.returns = to_torch_as(batch.returns, batch.v_s)
         batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
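The two A2C hunks above split return/advantage computation into a `_compute_returns` helper (the moved body appears in the ppo.py hunks below); it ends by calling `compute_episodic_return(..., gamma=self._gamma, gae_lambda=self._lambda)`, i.e. GAE(lambda). For reference, a simplified standalone sketch of that recursion, assuming a single episode and no done/truncation handling; this is not Tianshou's actual implementation:

    import numpy as np

    def gae_sketch(rew, v_s, v_s_next, gamma=0.99, gae_lambda=0.95):
        # rew, v_s, v_s_next: 1-D float arrays of equal length
        # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rew + gamma * v_s_next - v_s
        adv = np.zeros_like(rew, dtype=np.float64)
        running = 0.0
        for t in reversed(range(len(rew))):
            # adv_t = delta_t + gamma * lambda * adv_{t+1}
            running = delta[t] + gamma * gae_lambda * running
            adv[t] = running
        returns = adv + v_s  # unnormalized value targets
        return returns, adv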
@@ -4,7 +4,7 @@ from torch import nn
 from typing import Any, Dict, List, Type, Optional

 from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class PPOPolicy(A2CPolicy):
@@ -24,6 +24,11 @@ class PPOPolicy(A2CPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param bool recompute_advantage: whether to recompute advantage every update
+        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+        Default to False.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation. Default to
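The two flags documented above are wired into the constructor in the next hunk. A hypothetical call site, reusing the actor/critic/optim/Categorical names from the A2C sketch earlier; the concrete values are illustrative only:

    from tianshou.policy import PPOPolicy

    policy = PPOPolicy(
        actor, critic, optim, dist_fn=Categorical,
        eps_clip=0.2,
        value_clip=False,              # note the changed default in the next hunk
        advantage_normalization=True,  # normalize adv per mini-batch in learn()
        recompute_advantage=False,     # if True, redo GAE on every update repeat
    )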
@@ -59,7 +64,9 @@ class PPOPolicy(A2CPolicy):
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         dual_clip: Optional[float] = None,
-        value_clip: bool = True,
+        value_clip: bool = False,
+        advantage_normalization: bool = True,
+        recompute_advantage: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
@@ -68,51 +75,41 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
+        if not self._rew_norm:
+            assert not self._value_clip, \
+                "value clip is available only when `reward_normalization` is True"
+        self._norm_adv = advantage_normalization
+        self._recompute_adv = recompute_advantage

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s, v_s_, old_log_prob = [], [], []
+        if self._recompute_adv:
+            # buffer input `buffer` and `indice` to be used in `learn()`.
+            self._buffer = buffer
+            self._indice = indice
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        old_log_prob = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
-        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
-        v_s = to_numpy(batch.v_s)
-        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
-        # when normalizing values, we do not minus self.ret_rms.mean to be numerically
-        # consistent with OPENAI baselines' value normalization pipeline. Emperical
-        # study also shows that "minus mean" will harm performances a tiny little bit
-        # due to unknown reasons (on Mujoco envs, not confident, though).
-        if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
-        unnormalized_returns, advantages = self.compute_episodic_return(
-            batch, buffer, indice, v_s_, v_s,
-            gamma=self._gamma, gae_lambda=self._lambda)
-        if self._rew_norm:
-            batch.returns = unnormalized_returns / \
-                np.sqrt(self.ret_rms.var + self._eps)
-            self.ret_rms.update(unnormalized_returns)
-            mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
-        else:
-            batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
+                old_log_prob.append(self(b).dist.log_prob(b.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, batch.v_s)
-        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch

     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        for _ in range(repeat):
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indice)
             for b in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(b).dist
+                if self._norm_adv:
+                    mean, std = b.adv.mean(), b.adv.std()
+                    b.adv = (b.adv - mean) / std  # per-batch norm
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
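The `ratio` and `surr1` lines above feed PPO's clipped surrogate objective. A self-contained sketch of that objective, including the optional dual clip described in the docstring; it mirrors the structure of `learn()` but is not the verbatim Tianshou code:

    import torch

    def ppo_clip_loss(ratio, adv, eps_clip=0.2, dual_clip=None):
        surr1 = ratio * adv
        surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
        clipped = torch.min(surr1, surr2)
        if dual_clip is not None:
            # for negative advantages, bound the surrogate from below by
            # dual_clip * adv so that very large ratios cannot dominate
            clipped = torch.where(adv < 0,
                                  torch.max(clipped, dual_clip * adv),
                                  clipped)
        return -clipped.mean()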
@@ -130,9 +127,9 @@ class PPOPolicy(A2CPolicy):
                                                -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
-                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
+                    vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
+                    vf_loss = (b.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
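The hunk above shows only the tail of the `value_clip` branch (the clamp arguments and the two squared errors). The whole idea as a standalone sketch with illustrative names, where `old_value` plays the role of the `b.v_s` cached in `process_fn`:

    import torch

    def clipped_value_loss(value, old_value, returns, eps_clip=0.2):
        # value clipping (arXiv:1811.02553 Sec. 4.1): keep the new prediction
        # within eps_clip of the old one, mirroring the policy-ratio clip
        v_clip = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
        vf1 = (returns - value).pow(2)
        vf2 = (returns - v_clip).pow(2)
        return torch.max(vf1, vf2).mean()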