Fix shape inconsistency in A2CPolicy and PPOPolicy (#155)

- The shape of `r - v` in the original A2C implementation was wrong.

- The shape of `log_prob` differs between distributions: `[bsz]` for `Categorical` but `[bsz, 1]` for `Normal`, so it has to be reshaped manually to stay consistent with the other tensors (see the sketch below).
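
The mismatch is easy to reproduce. A minimal sketch (batch size and distributions are made up for illustration) showing the two `log_prob` shapes and how broadcasting silently amplifies the error when a `[bsz, 1]` tensor meets a `[bsz]` advantage:

```python
import torch
from torch.distributions import Categorical, Normal

bsz = 4
cat = Categorical(logits=torch.zeros(bsz, 3))
norm = Normal(torch.zeros(bsz, 1), torch.ones(bsz, 1))

print(cat.log_prob(torch.zeros(bsz, dtype=torch.long)).shape)  # torch.Size([4])
print(norm.log_prob(torch.zeros(bsz, 1)).shape)                # torch.Size([4, 1])

# If log_prob is [bsz, 1] while the advantage (r - v) is [bsz], the product
# broadcasts to a [bsz, bsz] matrix instead of an element-wise vector,
# which silently corrupts the policy-gradient loss.
adv = torch.randn(bsz)
bad = norm.log_prob(torch.zeros(bsz, 1)) * adv
print(bad.shape)  # torch.Size([4, 4])
```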
Author: n+e, 2020-07-21 22:24:06 +08:00 (committed by GitHub)
Parent: 865ef6c693
Commit: 089b85b6a2
6 changed files with 39 additions and 29 deletions


@@ -105,15 +105,24 @@ class BasePolicy(ABC, nn.Module):
"""Update policy with a given batch of data.
:return: A dict which includes loss and its corresponding label.
+.. warning::
+    If you use ``torch.distributions.Normal`` and
+    ``torch.distributions.Categorical`` to calculate the log_prob,
+    please be careful about the shape: Categorical distribution gives
+    "[batch_size]" shape while Normal distribution gives "[batch_size,
+    1]" shape. The auto-broadcasting of numerical operation with torch
+    tensors will amplify this error.
"""
pass
@staticmethod
def compute_episodic_return(
-batch: Batch,
-v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
-gamma: float = 0.99,
-gae_lambda: float = 0.95,
+batch: Batch,
+v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
+gamma: float = 0.99,
+gae_lambda: float = 0.95,
) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@@ -128,7 +137,8 @@ class BasePolicy(ABC, nn.Module):
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.
-:return: a Batch. The result will be stored in batch.returns.
+:return: a Batch. The result will be stored in batch.returns as a numpy
+    array.
"""
rew = batch.rew
if v_s_ is None:
@@ -157,7 +167,7 @@ class BasePolicy(ABC, nn.Module):
gamma: float = 0.99,
n_step: int = 1,
rew_norm: bool = False,
-) -> np.ndarray:
+) -> Batch:
r"""Compute n-step return for Q-learning targets:
.. math::
@@ -204,7 +214,7 @@ class BasePolicy(ABC, nn.Module):
returns[done[now] > 0] = 0
returns = (rew[now] - mean) / std + gamma * returns
terminal = (indice + n_step - 1) % buf_len
-target_q = target_q_fn(buffer, terminal).squeeze()
+target_q = target_q_fn(buffer, terminal).flatten()  # shape: [bsz, ]
target_q[gammas != n_step] = 0
returns = to_torch_as(returns, target_q)
gammas = to_torch_as(gamma ** gammas, target_q)
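
The `compute_episodic_return` docstring above refers to the Generalized Advantage Estimator (arXiv:1506.02438). For reference, a minimal NumPy sketch of the standard GAE recursion (the function name and arguments are illustrative, not the actual Tianshou implementation):

```python
import numpy as np

def gae_return(rew, done, v_s, v_s_, gamma=0.99, gae_lambda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    delta = rew + gamma * v_s_ * (1.0 - done) - v_s
    adv = np.zeros_like(rew)
    gae = 0.0
    # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, computed backwards
    for t in reversed(range(len(rew))):
        gae = delta[t] + gamma * gae_lambda * (1.0 - done[t]) * gae
        adv[t] = gae
    return adv + v_s  # lambda-return used as the regression target

# With gae_lambda=1 and zero value estimates this reduces to the discounted
# Monte Carlo return: [2.9701, 1.99, 1.0]
print(gae_return(np.array([1., 1., 1.]), np.array([0., 0., 1.]),
                 np.zeros(3), np.zeros(3), gae_lambda=1.0))
```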


@@ -105,11 +105,12 @@ class A2CPolicy(PGPolicy):
for b in batch.split(batch_size):
self.optim.zero_grad()
dist = self(b).dist
-v = self.critic(b.obs)
+v = self.critic(b.obs).squeeze(-1)
a = to_torch_as(b.act, v)
r = to_torch_as(b.returns, v)
-a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
-vf_loss = F.mse_loss(r[:, None], v)
+a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+    ).mean()
+vf_loss = F.mse_loss(r, v)
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()


@@ -142,9 +142,8 @@ class DDPGPolicy(BasePolicy):
return Batch(act=actions, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-current_q = self.critic(batch.obs, batch.act)
-target_q = to_torch_as(batch.returns, current_q)
-target_q = target_q[:, None]
+current_q = self.critic(batch.obs, batch.act).squeeze(-1)
+target_q = batch.returns
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()


@@ -130,11 +130,10 @@ class PPOPolicy(PGPolicy):
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
-batch.v = torch.cat(v, dim=0)  # old value
+batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
batch.act = to_torch_as(batch.act, v[0])
-batch.logp_old = torch.cat(old_log_prob, dim=0)
-batch.returns = to_torch_as(
-    batch.returns, v[0]).reshape(batch.v.shape)
+batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+batch.returns = to_torch_as(batch.returns, v[0])
if self._rew_norm:
mean, std = batch.returns.mean(), batch.returns.std()
if not np.isclose(std.item(), 0):
@@ -147,8 +146,9 @@ class PPOPolicy(PGPolicy):
for _ in range(repeat):
for b in batch.split(batch_size):
dist = self(b).dist
-value = self.critic(b.obs)
-ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+value = self.critic(b.obs).squeeze(-1)
+ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
+    ).exp().float()
surr1 = ratio * b.adv
surr2 = ratio.clamp(
1. - self._eps_clip, 1. + self._eps_clip) * b.adv


@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
-current_q1 = self.critic1(batch.obs, batch.act)
-target_q = to_torch_as(batch.returns, current_q1)[:, None]
+current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+target_q = batch.returns
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
-current_q2 = self.critic2(batch.obs, batch.act)
+current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
@@ -154,10 +154,10 @@ class SACPolicy(DDPGPolicy):
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act
-current_q1a = self.critic1(batch.obs, a)
-current_q2a = self.critic2(batch.obs, a)
-actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-    current_q1a, current_q2a)).mean()
+current_q1a = self.critic1(batch.obs, a).squeeze(-1)
+current_q2a = self.critic2(batch.obs, a).squeeze(-1)
+actor_loss = (self._alpha * obs_result.log_prob.reshape(
+    target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()


@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
-current_q1 = self.critic1(batch.obs, batch.act)
-target_q = batch.returns[:, None]
+current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+target_q = batch.returns
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
-current_q2 = self.critic2(batch.obs, batch.act)
+current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
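
Every critic update above reduces both the predicted Q value and `target_q` to flat `[bsz]` tensors before calling `F.mse_loss`. A minimal sketch of why that matters, using made-up tensors rather than the actual critic modules:

```python
import torch
import torch.nn.functional as F

bsz = 4
current_q = torch.randn(bsz, 1)  # typical critic output shape: [bsz, 1]
target_q = torch.randn(bsz)      # returns stored as a flat [bsz] tensor

# Mismatched shapes broadcast to [bsz, bsz] before the mean is taken, so the
# loss is computed over the wrong pairs (current PyTorch versions emit a
# UserWarning about the size mismatch here).
wrong = F.mse_loss(current_q, target_q)

# Squeezing the critic output keeps the comparison element-wise.
right = F.mse_loss(current_q.squeeze(-1), target_q)
print(wrong.item(), right.item())
```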