refactor A2C/PPO, change behavior of value normalization (#321)
This commit is contained in:
parent
47c77899d5
commit
3ac67d9974
@ -43,7 +43,7 @@ This will start 10 experiments with different seeds.
|
||||
|
||||
#### Example benchmark
|
||||
|
||||
<img src="./benchmark/Ant-v3/figure.png" width="500" height="450">
|
||||
<img src="./benchmark/Ant-v3/offpolicy.png" width="500" height="450">
|
||||
|
||||
Other graphs can be found under `/examples/mujuco/benchmark/`
|
||||
|
||||
|
@ -20,22 +20,22 @@ def get_args():
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--il-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=50000)
|
||||
parser.add_argument('--il-step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--episode-per-collect', type=int, default=8)
|
||||
parser.add_argument('--step-per-collect', type=int, default=8)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.125)
|
||||
parser.add_argument('--episode-per-collect', type=int, default=16)
|
||||
parser.add_argument('--step-per-collect', type=int, default=16)
|
||||
parser.add_argument('--update-per-step', type=float, default=1 / 16)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int,
|
||||
nargs='*', default=[128, 128, 128])
|
||||
nargs='*', default=[64, 64])
|
||||
parser.add_argument('--imitation-hidden-sizes', type=int,
|
||||
nargs='*', default=[128])
|
||||
parser.add_argument('--training-num', type=int, default=8)
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||
from typing import Any, Dict, List, Type, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
|
||||
|
||||
|
||||
class A2CPolicy(PGPolicy):
|
||||
@ -21,12 +21,12 @@ class A2CPolicy(PGPolicy):
|
||||
:param float discount_factor: in [0, 1]. Default to 0.99.
|
||||
:param float vf_coef: weight for value loss. Default to 0.5.
|
||||
:param float ent_coef: weight for entropy loss. Default to 0.01.
|
||||
:param float max_grad_norm: clipping gradients in back propagation.
|
||||
Default to None.
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage
|
||||
Estimation. Default to 0.95.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param float max_grad_norm: clipping gradients in back propagation. Default to
|
||||
None.
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
|
||||
Default to 0.95.
|
||||
:param bool reward_normalization: normalize estimated values to have std close to
|
||||
1. Default to False.
|
||||
:param int max_batchsize: the maximum size of the batch when computing GAE,
|
||||
depends on the size of available memory and the memory cost of the
|
||||
model; should be as large as possible within the memory constraint.
|
||||
@ -72,22 +72,33 @@ class A2CPolicy(PGPolicy):
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
v_s_ = []
|
||||
v_s, v_s_ = [], []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False, merge_last=True):
|
||||
v_s_.append(to_numpy(self.critic(b.obs_next)))
|
||||
v_s_ = np.concatenate(v_s_, axis=0)
|
||||
if self._rew_norm: # unnormalize v_s_
|
||||
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
|
||||
unnormalized_returns, _ = self.compute_episodic_return(
|
||||
batch, buffer, indice, v_s_=v_s_,
|
||||
v_s.append(self.critic(b.obs))
|
||||
v_s_.append(self.critic(b.obs_next))
|
||||
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
|
||||
v_s = to_numpy(batch.v_s)
|
||||
v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
|
||||
# when normalizing values, we do not minus self.ret_rms.mean to be numerically
|
||||
# consistent with OPENAI baselines' value normalization pipeline. Emperical
|
||||
# study also shows that "minus mean" will harm performances a tiny little bit
|
||||
# due to unknown reasons (on Mujoco envs, not confident, though).
|
||||
if self._rew_norm: # unnormalize v_s & v_s_
|
||||
v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
|
||||
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
|
||||
unnormalized_returns, advantages = self.compute_episodic_return(
|
||||
batch, buffer, indice, v_s_, v_s,
|
||||
gamma=self._gamma, gae_lambda=self._lambda)
|
||||
if self._rew_norm:
|
||||
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
|
||||
batch.returns = unnormalized_returns / \
|
||||
np.sqrt(self.ret_rms.var + self._eps)
|
||||
self.ret_rms.update(unnormalized_returns)
|
||||
else:
|
||||
batch.returns = unnormalized_returns
|
||||
batch.act = to_torch_as(batch.act, batch.v_s)
|
||||
batch.returns = to_torch_as(batch.returns, batch.v_s)
|
||||
batch.adv = to_torch_as(advantages, batch.v_s)
|
||||
return batch
|
||||
|
||||
def learn( # type: ignore
|
||||
@ -96,24 +107,25 @@ class A2CPolicy(PGPolicy):
|
||||
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
self.optim.zero_grad()
|
||||
# calculate loss for actor
|
||||
dist = self(b).dist
|
||||
v = self.critic(b.obs).flatten()
|
||||
a = to_torch_as(b.act, v)
|
||||
r = to_torch_as(b.returns, v)
|
||||
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
||||
a_loss = -(log_prob * (r - v).detach()).mean()
|
||||
vf_loss = F.mse_loss(r, v) # type: ignore
|
||||
log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
|
||||
actor_loss = -(log_prob * b.adv).mean()
|
||||
# calculate loss for critic
|
||||
value = self.critic(b.obs).flatten()
|
||||
vf_loss = F.mse_loss(b.returns, value)
|
||||
# calculate regularization and overall loss
|
||||
ent_loss = dist.entropy().mean()
|
||||
loss = a_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss
|
||||
loss = actor_loss + self._weight_vf * vf_loss \
|
||||
- self._weight_ent * ent_loss
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
if self._grad_norm is not None:
|
||||
if self._grad_norm is not None: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
max_norm=self._grad_norm,
|
||||
)
|
||||
max_norm=self._grad_norm)
|
||||
self.optim.step()
|
||||
actor_losses.append(a_loss.item())
|
||||
actor_losses.append(actor_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
ent_losses.append(ent_loss.item())
|
||||
losses.append(loss.item())
|
||||
|
@ -116,9 +116,9 @@ class PGPolicy(BasePolicy):
|
||||
result = self(b)
|
||||
dist = result.dist
|
||||
a = to_torch_as(b.act, result.act)
|
||||
r = to_torch_as(b.returns, result.act)
|
||||
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
||||
loss = -(log_prob * r).mean()
|
||||
ret = to_torch_as(b.returns, result.act)
|
||||
log_prob = dist.log_prob(a).reshape(len(ret), -1).transpose(0, 1)
|
||||
loss = -(log_prob * ret).mean()
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
losses.append(loss.item())
|
||||
|
@ -17,25 +17,24 @@ class PPOPolicy(A2CPolicy):
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
:type dist_fn: Type[torch.distributions.Distribution]
|
||||
:param float discount_factor: in [0, 1]. Default to 0.99.
|
||||
:param float max_grad_norm: clipping gradients in back propagation.
|
||||
Default to None.
|
||||
:param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
|
||||
paper. Default to 0.2.
|
||||
:param float vf_coef: weight for value loss. Default to 0.5.
|
||||
:param float ent_coef: weight for entropy loss. Default to 0.01.
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage
|
||||
Estimation. Default to 0.95.
|
||||
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
|
||||
where c > 1 is a constant indicating the lower bound.
|
||||
Default to 5.0 (set None if you do not want to use it).
|
||||
:param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
|
||||
Default to True.
|
||||
:param bool reward_normalization: normalize the returns and advantage to
|
||||
Normal(0, 1). Default to False.
|
||||
:param float vf_coef: weight for value loss. Default to 0.5.
|
||||
:param float ent_coef: weight for entropy loss. Default to 0.01.
|
||||
:param float max_grad_norm: clipping gradients in back propagation. Default to
|
||||
None.
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
|
||||
Default to 0.95.
|
||||
:param bool reward_normalization: normalize estimated values to have std close
|
||||
to 1, also normalize the advantage to Normal(0, 1). Default to False.
|
||||
:param int max_batchsize: the maximum size of the batch when computing GAE,
|
||||
depends on the size of available memory and the memory cost of the
|
||||
model; should be as large as possible within the memory constraint.
|
||||
Default to 256.
|
||||
depends on the size of available memory and the memory cost of the model;
|
||||
should be as large as possible within the memory constraint. Default to 256.
|
||||
:param bool action_scaling: whether to map actions from range [-1, 1] to range
|
||||
[action_spaces.low, action_spaces.high]. Default to True.
|
||||
:param str action_bound_method: method to bound action to range [-1, 1], can be
|
||||
@ -58,20 +57,12 @@ class PPOPolicy(A2CPolicy):
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Type[torch.distributions.Distribution],
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_coef: float = 0.5,
|
||||
ent_coef: float = 0.01,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
max_batchsize: int = 256,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
|
||||
vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
|
||||
max_batchsize=max_batchsize, **kwargs)
|
||||
super().__init__(actor, critic, optim, dist_fn, **kwargs)
|
||||
self._eps_clip = eps_clip
|
||||
assert dual_clip is None or dual_clip > 1.0, \
|
||||
"Dual-clip PPO parameter should greater than 1.0."
|
||||
@ -90,14 +81,18 @@ class PPOPolicy(A2CPolicy):
|
||||
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
|
||||
v_s = to_numpy(batch.v_s)
|
||||
v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
|
||||
# when normalizing values, we do not minus self.ret_rms.mean to be numerically
|
||||
# consistent with OPENAI baselines' value normalization pipeline. Emperical
|
||||
# study also shows that "minus mean" will harm performances a tiny little bit
|
||||
# due to unknown reasons (on Mujoco envs, not confident, though).
|
||||
if self._rew_norm: # unnormalize v_s & v_s_
|
||||
v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
|
||||
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
|
||||
v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
|
||||
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
|
||||
unnormalized_returns, advantages = self.compute_episodic_return(
|
||||
batch, buffer, indice, v_s_, v_s,
|
||||
gamma=self._gamma, gae_lambda=self._lambda)
|
||||
if self._rew_norm:
|
||||
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
|
||||
batch.returns = unnormalized_returns / \
|
||||
np.sqrt(self.ret_rms.var + self._eps)
|
||||
self.ret_rms.update(unnormalized_returns)
|
||||
mean, std = np.mean(advantages), np.std(advantages)
|
||||
@ -116,8 +111,8 @@ class PPOPolicy(A2CPolicy):
|
||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
# calculate loss for actor
|
||||
dist = self(b).dist
|
||||
value = self.critic(b.obs).flatten()
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
|
||||
surr1 = ratio * b.adv
|
||||
@ -128,7 +123,8 @@ class PPOPolicy(A2CPolicy):
|
||||
).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
# calculate loss for critic
|
||||
value = self.critic(b.obs).flatten()
|
||||
if self._value_clip:
|
||||
v_clip = b.v_s + (value - b.v_s).clamp(
|
||||
-self._eps_clip, self._eps_clip)
|
||||
@ -137,19 +133,21 @@ class PPOPolicy(A2CPolicy):
|
||||
vf_loss = 0.5 * torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = 0.5 * (b.returns - value).pow(2).mean()
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
# calculate regularization and overall loss
|
||||
ent_loss = dist.entropy().mean()
|
||||
loss = clip_loss + self._weight_vf * vf_loss \
|
||||
- self._weight_ent * e_loss
|
||||
losses.append(loss.item())
|
||||
- self._weight_ent * ent_loss
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
if self._grad_norm is not None:
|
||||
if self._grad_norm is not None: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
self._grad_norm)
|
||||
max_norm=self._grad_norm)
|
||||
self.optim.step()
|
||||
clip_losses.append(clip_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
ent_losses.append(ent_loss.item())
|
||||
losses.append(loss.item())
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
Loading…
x
Reference in New Issue
Block a user