refactor A2C/PPO, change behavior of value normalization (#321)

Author: ChenDRAG, committed by GitHub
Date: 2021-03-25 10:12:39 +08:00
Parent: 47c77899d5
Commit: 3ac67d9974
5 changed files with 79 additions and 69 deletions
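For orientation, the headline change is how returns and value targets are normalized when `reward_normalization` is enabled: the running mean is no longer subtracted; only the running standard deviation is divided out. A minimal before/after sketch (the `ret_rms` object and `eps` stand in for the policy's running-statistics tracker and its `_eps` attribute, mirroring the names in the diff below):

```python
import numpy as np

def normalize_returns_old(returns, ret_rms, eps=1e-8):
    # pre-#321 behavior: full standardization of the return targets
    return (returns - ret_rms.mean) / np.sqrt(ret_rms.var + eps)

def normalize_returns_new(returns, ret_rms, eps=1e-8):
    # post-#321 behavior: scale by the running std only and keep the mean,
    # matching OpenAI Baselines' value-normalization pipeline
    return returns / np.sqrt(ret_rms.var + eps)
```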


@@ -43,7 +43,7 @@ This will start 10 experiments with different seeds.
 #### Example benchmark
-<img src="./benchmark/Ant-v3/figure.png" width="500" height="450">
+<img src="./benchmark/Ant-v3/offpolicy.png" width="500" height="450">
 Other graphs can be found under `/examples/mujuco/benchmark/`


@@ -20,22 +20,22 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=50000)
     parser.add_argument('--il-step-per-epoch', type=int, default=1000)
-    parser.add_argument('--episode-per-collect', type=int, default=8)
-    parser.add_argument('--step-per-collect', type=int, default=8)
-    parser.add_argument('--update-per-step', type=float, default=0.125)
+    parser.add_argument('--episode-per-collect', type=int, default=16)
+    parser.add_argument('--step-per-collect', type=int, default=16)
+    parser.add_argument('--update-per-step', type=float, default=1 / 16)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[128, 128, 128])
+                        nargs='*', default=[64, 64])
     parser.add_argument('--imitation-hidden-sizes', type=int,
                         nargs='*', default=[128])
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)


@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Any, Dict, List, Type, Optional
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
 class A2CPolicy(PGPolicy):
@@ -21,12 +21,12 @@ class A2CPolicy(PGPolicy):
     :param float discount_factor: in [0, 1]. Default to 0.99.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
-    :param float max_grad_norm: clipping gradients in back propagation.
-        Default to None.
-    :param float gae_lambda: in [0, 1], param for Generalized Advantage
-        Estimation. Default to 0.95.
-    :param bool reward_normalization: normalize the reward to Normal(0, 1).
-        Default to False.
+    :param float max_grad_norm: clipping gradients in back propagation. Default to
+        None.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+        Default to 0.95.
+    :param bool reward_normalization: normalize estimated values to have std close to
+        1. Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
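The `gae_lambda` parameter above controls Generalized Advantage Estimation. As a reference while reading the `process_fn` hunks that follow, here is a self-contained sketch of the textbook GAE recursion that `compute_episodic_return` is built around (the standard formulation, not a copy of the library code):

```python
import numpy as np

def gae(rewards, values, next_values, dones, gamma=0.99, gae_lambda=0.95):
    """Textbook GAE: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t),
    A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}."""
    advantages = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_values[t] * (1.0 - dones[t]) - values[t]
        running = delta + gamma * gae_lambda * (1.0 - dones[t]) * running
        advantages[t] = running
    returns = advantages + values  # lambda-returns, used as critic targets
    return returns, advantages
```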
@@ -72,22 +72,33 @@ class A2CPolicy(PGPolicy):
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s_ = []
+        v_s, v_s_ = [], []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s_.append(to_numpy(self.critic(b.obs_next)))
-        v_s_ = np.concatenate(v_s_, axis=0)
-        if self._rew_norm:  # unnormalize v_s_
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
-        unnormalized_returns, _ = self.compute_episodic_return(
-            batch, buffer, indice, v_s_=v_s_,
+                v_s.append(self.critic(b.obs))
+                v_s_.append(self.critic(b.obs_next))
+        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
+        v_s = to_numpy(batch.v_s)
+        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
+        # When normalizing values, we do not subtract self.ret_rms.mean, to stay
+        # numerically consistent with OpenAI Baselines' value-normalization pipeline.
+        # Empirically, subtracting the mean also hurts performance slightly for
+        # unknown reasons (observed on MuJoCo envs; not conclusive).
+        if self._rew_norm:  # unnormalize v_s & v_s_
+            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
+            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
+        unnormalized_returns, advantages = self.compute_episodic_return(
+            batch, buffer, indice, v_s_, v_s,
             gamma=self._gamma, gae_lambda=self._lambda)
         if self._rew_norm:
-            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+            batch.returns = unnormalized_returns / \
                 np.sqrt(self.ret_rms.var + self._eps)
             self.ret_rms.update(unnormalized_returns)
         else:
             batch.returns = unnormalized_returns
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        batch.returns = to_torch_as(batch.returns, batch.v_s)
+        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
 
     def learn(  # type: ignore
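Both the A2C and PPO `process_fn` changes follow the same pattern: map the critic's (normalized-scale) predictions back to the raw return scale, run GAE there, then re-normalize the critic targets. A condensed sketch of that round trip, where `ret_rms` again stands in for the running-statistics tracker and `compute_returns` is a placeholder for the GAE call:

```python
import numpy as np

def build_targets(v_s, v_s_, compute_returns, ret_rms, eps=1e-8, rew_norm=True):
    scale = np.sqrt(ret_rms.var + eps)
    if rew_norm:
        # the critic was trained on normalized targets, so undo the scaling
        # before computing GAE on the raw return scale
        v_s, v_s_ = v_s * scale, v_s_ * scale
    unnormalized_returns, advantages = compute_returns(v_s, v_s_)
    if rew_norm:
        targets = unnormalized_returns / scale  # note: no mean subtraction
        ret_rms.update(unnormalized_returns)
    else:
        targets = unnormalized_returns
    return targets, advantages
```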
@@ -96,24 +107,25 @@ class A2CPolicy(PGPolicy):
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size, merge_last=True):
-                self.optim.zero_grad()
+                # calculate loss for actor
                 dist = self(b).dist
-                v = self.critic(b.obs).flatten()
-                a = to_torch_as(b.act, v)
-                r = to_torch_as(b.returns, v)
-                log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
-                a_loss = -(log_prob * (r - v).detach()).mean()
-                vf_loss = F.mse_loss(r, v)  # type: ignore
+                log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
+                actor_loss = -(log_prob * b.adv).mean()
+                # calculate loss for critic
+                value = self.critic(b.obs).flatten()
+                vf_loss = F.mse_loss(b.returns, value)
+                # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
-                loss = a_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss
+                loss = actor_loss + self._weight_vf * vf_loss \
+                    - self._weight_ent * ent_loss
+                self.optim.zero_grad()
                 loss.backward()
-                if self._grad_norm is not None:
+                if self._grad_norm is not None:  # clip large gradient
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
-                        max_norm=self._grad_norm,
-                    )
+                        max_norm=self._grad_norm)
                 self.optim.step()
-                actor_losses.append(a_loss.item())
+                actor_losses.append(actor_loss.item())
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())


@@ -116,9 +116,9 @@ class PGPolicy(BasePolicy):
                 result = self(b)
                 dist = result.dist
                 a = to_torch_as(b.act, result.act)
-                r = to_torch_as(b.returns, result.act)
-                log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
-                loss = -(log_prob * r).mean()
+                ret = to_torch_as(b.returns, result.act)
+                log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1)
+                loss = -(log_prob * ret).mean()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())


@@ -17,25 +17,24 @@ class PPOPolicy(A2CPolicy):
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
-    :param float max_grad_norm: clipping gradients in back propagation.
-        Default to None.
     :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
         paper. Default to 0.2.
-    :param float vf_coef: weight for value loss. Default to 0.5.
-    :param float ent_coef: weight for entropy loss. Default to 0.01.
-    :param float gae_lambda: in [0, 1], param for Generalized Advantage
-        Estimation. Default to 0.95.
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
         where c > 1 is a constant indicating the lower bound.
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
-    :param bool reward_normalization: normalize the returns and advantage to
-        Normal(0, 1). Default to False.
+    :param float vf_coef: weight for value loss. Default to 0.5.
+    :param float ent_coef: weight for entropy loss. Default to 0.01.
+    :param float max_grad_norm: clipping gradients in back propagation. Default to
+        None.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+        Default to 0.95.
+    :param bool reward_normalization: normalize estimated values to have std close
+        to 1, also normalize the advantage to Normal(0, 1). Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
-        depends on the size of available memory and the memory cost of the
-        model; should be as large as possible within the memory constraint.
-        Default to 256.
+        depends on the size of available memory and the memory cost of the model;
+        should be as large as possible within the memory constraint. Default to 256.
     :param bool action_scaling: whether to map actions from range [-1, 1] to range
         [action_spaces.low, action_spaces.high]. Default to True.
     :param str action_bound_method: method to bound action to range [-1, 1], can be
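The `eps_clip`, `dual_clip`, and `value_clip` parameters documented above correspond to the following standard objectives. This is a generic sketch for readers of the hunks below (the dual clip follows Eq. 5 of arXiv:1912.09729 and the value clip follows arXiv:1811.02553 Sec. 4.1), not a copy of the library's methods:

```python
import torch

def clipped_actor_loss(ratio, adv, eps_clip=0.2, dual_clip=None):
    # PPO clipped surrogate: E[min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)]
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    surr = torch.min(surr1, surr2)
    if dual_clip is not None:
        # dual clip (c > 1): lower-bound the objective by c * A_t when A_t < 0
        surr = torch.where(adv < 0, torch.max(surr, dual_clip * adv), surr)
    return -surr.mean()

def clipped_critic_loss(value, old_value, returns, eps_clip=0.2, value_clip=True):
    if value_clip:
        # keep the new value prediction within eps_clip of the old one
        v_clipped = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
        return 0.5 * torch.max(
            (returns - value).pow(2), (returns - v_clipped).pow(2)).mean()
    return 0.5 * (returns - value).pow(2).mean()
```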
@@ -58,20 +57,12 @@ class PPOPolicy(A2CPolicy):
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        max_grad_norm: Optional[float] = None,
         eps_clip: float = 0.2,
-        vf_coef: float = 0.5,
-        ent_coef: float = 0.01,
-        gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
-        max_batchsize: int = 256,
         **kwargs: Any,
     ) -> None:
-        super().__init__(
-            actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
-            vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
-            max_batchsize=max_batchsize, **kwargs)
+        super().__init__(actor, critic, optim, dist_fn, **kwargs)
         self._eps_clip = eps_clip
         assert dual_clip is None or dual_clip > 1.0, \
             "Dual-clip PPO parameter should greater than 1.0."
@@ -90,14 +81,18 @@ class PPOPolicy(A2CPolicy):
         batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
         v_s = to_numpy(batch.v_s)
         v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
+        # When normalizing values, we do not subtract self.ret_rms.mean, to stay
+        # numerically consistent with OpenAI Baselines' value-normalization pipeline.
+        # Empirically, subtracting the mean also hurts performance slightly for
+        # unknown reasons (observed on MuJoCo envs; not conclusive).
         if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
+            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
         unnormalized_returns, advantages = self.compute_episodic_return(
             batch, buffer, indice, v_s_, v_s,
             gamma=self._gamma, gae_lambda=self._lambda)
         if self._rew_norm:
-            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+            batch.returns = unnormalized_returns / \
                 np.sqrt(self.ret_rms.var + self._eps)
             self.ret_rms.update(unnormalized_returns)
             mean, std = np.mean(advantages), np.std(advantages)
@@ -116,8 +111,8 @@ class PPOPolicy(A2CPolicy):
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size, merge_last=True):
+                # calculate loss for actor
                 dist = self(b).dist
-                value = self.critic(b.obs).flatten()
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
@@ -128,7 +123,8 @@
                     ).mean()
                 else:
                     clip_loss = -torch.min(surr1, surr2).mean()
-                clip_losses.append(clip_loss.item())
+                # calculate loss for critic
+                value = self.critic(b.obs).flatten()
                 if self._value_clip:
                     v_clip = b.v_s + (value - b.v_s).clamp(
                         -self._eps_clip, self._eps_clip)
@@ -137,19 +133,21 @@
                     vf_loss = 0.5 * torch.max(vf1, vf2).mean()
                 else:
                     vf_loss = 0.5 * (b.returns - value).pow(2).mean()
-                vf_losses.append(vf_loss.item())
-                e_loss = dist.entropy().mean()
-                ent_losses.append(e_loss.item())
+                # calculate regularization and overall loss
+                ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
-                    - self._weight_ent * e_loss
-                losses.append(loss.item())
+                    - self._weight_ent * ent_loss
                 self.optim.zero_grad()
                 loss.backward()
-                if self._grad_norm is not None:
+                if self._grad_norm is not None:  # clip large gradient
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
-                        self._grad_norm)
+                        max_norm=self._grad_norm)
                 self.optim.step()
+                clip_losses.append(clip_loss.item())
+                vf_losses.append(vf_loss.item())
+                ent_losses.append(ent_loss.item())
+                losses.append(loss.item())
         # update learning rate if lr_scheduler is given
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()