refactor ppo (#329)

parent 1730a9008a
commit 5d580c3662
@@ -14,8 +14,7 @@ class A2CPolicy(PGPolicy):
     :param torch.nn.Module actor: the actor network following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.nn.Module critic: the critic network. (s -> V(s))
-    :param torch.optim.Optimizer optim: the optimizer for actor and critic
-        network.
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
@@ -71,6 +70,13 @@ class A2CPolicy(PGPolicy):
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        return batch
+
+    def _compute_returns(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
@@ -96,7 +102,6 @@ class A2CPolicy(PGPolicy):
             self.ret_rms.update(unnormalized_returns)
         else:
             batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
         batch.returns = to_torch_as(batch.returns, batch.v_s)
         batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
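Note (illustration, not part of the commit): the two A2CPolicy hunks above split the old `process_fn` into a thin wrapper plus a reusable `_compute_returns` helper, so a subclass such as `PPOPolicy` (whose hunks follow) can recompute returns and advantages on demand instead of duplicating the value/GAE pipeline. A minimal sketch of the resulting call pattern, using hypothetical class names rather than the real Tianshou classes:

class GAEBase:
    """Hypothetical stand-in for the refactored A2CPolicy."""

    def _compute_returns(self, batch, buffer, indice):
        # Shared pipeline: fill batch.v_s, batch.returns and batch.adv (GAE),
        # optionally normalizing returns with a running statistic.
        ...
        return batch

    def process_fn(self, batch, buffer, indice):
        return self._compute_returns(batch, buffer, indice)


class PPOLike(GAEBase):
    """Hypothetical subclass mirroring how the PPO hunks below reuse the helper."""

    def process_fn(self, batch, buffer, indice):
        batch = self._compute_returns(batch, buffer, indice)  # reuse, don't copy
        # ...additionally cache old log-probabilities for the importance ratio...
        return batch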
@@ -4,7 +4,7 @@ from torch import nn
 from typing import Any, Dict, List, Type, Optional
 
 from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PPOPolicy(A2CPolicy):
@@ -24,6 +24,11 @@ class PPOPolicy(A2CPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param bool recompute_advantage: whether to recompute advantage every update
+        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+        Default to False.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation. Default to
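Note (illustration, not part of the commit): per mini-batch advantage normalization, documented above and implemented in the `learn()` hunk further down, standardizes the advantages of each mini-batch before the surrogate ratio is built. A minimal sketch on a dummy tensor; matching the code in this commit, no epsilon is added to the standard deviation:

import torch

# Dummy advantages for a single mini-batch (illustrative values only).
adv = torch.tensor([0.5, -1.0, 2.0, 0.1])

# Standardize within the mini-batch: zero mean, unit standard deviation.
mean, std = adv.mean(), adv.std()
adv = (adv - mean) / std
print(adv.mean().item(), adv.std().item())  # approximately 0.0 and 1.0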
@@ -59,7 +64,9 @@ class PPOPolicy(A2CPolicy):
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         dual_clip: Optional[float] = None,
-        value_clip: bool = True,
+        value_clip: bool = False,
+        advantage_normalization: bool = True,
+        recompute_advantage: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
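Note (usage sketch, not part of the commit): with the new signature, the two flags are passed at construction time. `ToyActor` and `ToyCritic` below are hypothetical stand-ins that merely follow the documented conventions (actor: s -> logits/probs returned as an (output, state) pair; critic: s -> V(s)); they are not Tianshou networks.

import torch
from torch import nn
from tianshou.policy import PPOPolicy


class ToyActor(nn.Module):
    """Hypothetical actor: s -> action probabilities, returned as (output, state)."""

    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, action_dim)
        )

    def forward(self, obs, state=None, info={}):
        logits = self.net(torch.as_tensor(obs, dtype=torch.float32))
        return torch.softmax(logits, dim=-1), state


class ToyCritic(nn.Module):
    """Hypothetical critic: s -> V(s)."""

    def __init__(self, state_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 1))

    def forward(self, obs, **kwargs):
        return self.net(torch.as_tensor(obs, dtype=torch.float32))


actor, critic = ToyActor(4, 2), ToyCritic(4)
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=3e-4
)
policy = PPOPolicy(
    actor, critic, optim, torch.distributions.Categorical,
    eps_clip=0.2,
    value_clip=False,              # note the new default in this commit
    advantage_normalization=True,  # new flag: per mini-batch advantage norm
    recompute_advantage=False,     # new flag: recompute GAE each update repeat
)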
@@ -68,51 +75,41 @@ class PPOPolicy(A2CPolicy):
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
+        if not self._rew_norm:
+            assert not self._value_clip, \
+                "value clip is available only when `reward_normalization` is True"
+        self._norm_adv = advantage_normalization
+        self._recompute_adv = recompute_advantage
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s, v_s_, old_log_prob = [], [], []
+        if self._recompute_adv:
+            # buffer input `buffer` and `indice` to be used in `learn()`.
+            self._buffer = buffer
+            self._indice = indice
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        old_log_prob = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
-        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
-        v_s = to_numpy(batch.v_s)
-        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
-        # when normalizing values, we do not minus self.ret_rms.mean to be numerically
-        # consistent with OPENAI baselines' value normalization pipeline. Emperical
-        # study also shows that "minus mean" will harm performances a tiny little bit
-        # due to unknown reasons (on Mujoco envs, not confident, though).
-        if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
-        unnormalized_returns, advantages = self.compute_episodic_return(
-            batch, buffer, indice, v_s_, v_s,
-            gamma=self._gamma, gae_lambda=self._lambda)
-        if self._rew_norm:
-            batch.returns = unnormalized_returns / \
-                np.sqrt(self.ret_rms.var + self._eps)
-            self.ret_rms.update(unnormalized_returns)
-            mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
-        else:
-            batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
+                old_log_prob.append(self(b).dist.log_prob(b.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, batch.v_s)
-        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        for _ in range(repeat):
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indice)
             for b in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(b).dist
+                if self._norm_adv:
+                    mean, std = b.adv.mean(), b.adv.std()
+                    b.adv = (b.adv - mean) / std  # per-batch norm
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
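Note (reminder sketch, not part of the commit): the lines that turn `surr1`/`surr2` into `clip_loss` sit between this hunk and the next one and are not shown. For orientation, a generic PPO clipped surrogate with the optional dual clip (the `eps_clip`/`dual_clip` parameters documented above) can be written as below; `clipped_surrogate` is an illustrative helper, not a verbatim copy of the file:

from typing import Optional

import torch


def clipped_surrogate(
    ratio: torch.Tensor,
    adv: torch.Tensor,
    eps_clip: float = 0.2,
    dual_clip: Optional[float] = None,
) -> torch.Tensor:
    """Negated PPO surrogate: -E[min(r*A, clip(r, 1-eps, 1+eps)*A)]."""
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    surr = torch.min(surr1, surr2)
    if dual_clip is not None:
        # Dual-clip PPO: for negative advantages, lower-bound the objective
        # by dual_clip * A so the ratio cannot push the loss arbitrarily far.
        surr = torch.where(adv < 0, torch.max(surr, dual_clip * adv), surr)
    return -surr.mean()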
@@ -130,9 +127,9 @@ class PPOPolicy(A2CPolicy):
                         -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
-                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
+                    vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
+                    vf_loss = (b.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
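Note (illustration, not part of the commit): the hunk above drops the constant 0.5 factor from the value loss, leaving its relative weight entirely to `vf_coef` (`self._weight_vf`). A minimal sketch of the clipped value loss on dummy tensors, assuming `v_s` stands for the old value predictions cached during return computation:

import torch

# Dummy data standing in for one mini-batch (illustrative values only).
returns = torch.tensor([1.0, 0.5, -0.2])
v_s = torch.tensor([0.8, 0.7, 0.0])      # old value predictions
value = torch.tensor([1.2, 0.4, -0.5])   # current critic output
eps_clip = 0.2

# Clipped value loss (value_clip=True branch): keep the new value within
# eps_clip of the old one and take the worse of the two squared errors.
v_clip = v_s + (value - v_s).clamp(-eps_clip, eps_clip)
vf1 = (returns - value).pow(2)
vf2 = (returns - v_clip).pow(2)
vf_loss = torch.max(vf1, vf2).mean()     # no 0.5 factor after this commit
print(vf_loss.item())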