Fix shape inconsistency in A2CPolicy and PPOPolicy (#155)
- The shape of `r - v` in the original A2C implementation is wrong.
- The shape of log_prob differs between distributions: [bsz] for Categorical and [bsz, 1] for Normal. It must be reshaped manually so that it stays consistent with the other tensors.
parent 865ef6c693
commit 089b85b6a2
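For illustration, a minimal sketch (not part of the commit) of the log_prob shape difference the message describes: `Categorical.log_prob` returns a `[bsz]` tensor, while `Normal` over `[bsz, 1]` actions returns `[bsz, 1]`, and mixing the two silently broadcasts element-wise products up to `[bsz, bsz]`.

```python
# Minimal sketch of the log_prob shape difference (illustration only).
import torch
from torch.distributions import Categorical, Normal

bsz = 4
cat = Categorical(probs=torch.ones(bsz, 3) / 3)
norm = Normal(torch.zeros(bsz, 1), torch.ones(bsz, 1))

print(cat.log_prob(cat.sample()).shape)    # torch.Size([4])    -> [bsz]
print(norm.log_prob(norm.sample()).shape)  # torch.Size([4, 1]) -> [bsz, 1]

# Mixing the two conventions in an element-wise product broadcasts silently:
adv = torch.randn(bsz)                     # [bsz]
bad = norm.log_prob(norm.sample()) * adv   # [bsz, 1] * [bsz] -> [bsz, bsz]
print(bad.shape)                           # torch.Size([4, 4]) -- wrong loss
```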
@@ -105,15 +105,24 @@ class BasePolicy(ABC, nn.Module):
         """Update policy with a given batch of data.
 
         :return: A dict which includes loss and its corresponding label.
+
+        .. warning::
+
+            If you use ``torch.distributions.Normal`` and
+            ``torch.distributions.Categorical`` to calculate the log_prob,
+            please be careful about the shape: Categorical distribution gives
+            "[batch_size]" shape while Normal distribution gives "[batch_size,
+            1]" shape. The auto-broadcasting of numerical operation with torch
+            tensors will amplify this error.
         """
         pass
 
     @staticmethod
     def compute_episodic_return(
             batch: Batch,
             v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
             gamma: float = 0.99,
             gae_lambda: float = 0.95,
     ) -> Batch:
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@@ -128,7 +137,8 @@ class BasePolicy(ABC, nn.Module):
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
 
-        :return: a Batch. The result will be stored in batch.returns.
+        :return: a Batch. The result will be stored in batch.returns as a numpy
+            array.
         """
         rew = batch.rew
         if v_s_ is None:
@@ -157,7 +167,7 @@ class BasePolicy(ABC, nn.Module):
             gamma: float = 0.99,
             n_step: int = 1,
             rew_norm: bool = False,
-    ) -> np.ndarray:
+    ) -> Batch:
         r"""Compute n-step return for Q-learning targets:
 
         .. math::
@@ -204,7 +214,7 @@ class BasePolicy(ABC, nn.Module):
             returns[done[now] > 0] = 0
             returns = (rew[now] - mean) / std + gamma * returns
         terminal = (indice + n_step - 1) % buf_len
-        target_q = target_q_fn(buffer, terminal).squeeze()
+        target_q = target_q_fn(buffer, terminal).flatten()  # shape: [bsz, ]
         target_q[gammas != n_step] = 0
         returns = to_torch_as(returns, target_q)
         gammas = to_torch_as(gamma ** gammas, target_q)
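As an aside on the `squeeze()` → `flatten()` change in the hunk above, a sketch (illustration only, not from the repository) of the edge case that motivates it: with a batch of one, `squeeze()` collapses a `[1, 1]` Q-value tensor to a 0-dim scalar, while `flatten()` always keeps the `[bsz, ]` shape.

```python
# Why .flatten() rather than .squeeze() for target_q (illustration only).
import torch

q = torch.randn(1, 1)      # target_q_fn output for a batch of one
print(q.squeeze().shape)   # torch.Size([])  -- 0-dim, loses the batch axis
print(q.flatten().shape)   # torch.Size([1]) -- always [bsz, ]

q = torch.randn(5, 1)
print(q.squeeze().shape)   # torch.Size([5])
print(q.flatten().shape)   # torch.Size([5])
```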
@@ -105,11 +105,12 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs)
+                v = self.critic(b.obs).squeeze(-1)
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
-                vf_loss = F.mse_loss(r[:, None], v)
+                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                           ).mean()
+                vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
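A sketch (illustration only) of what the A2C hunk above fixes: without `squeeze(-1)` the critic output is `[bsz, 1]` while the returns are `[bsz]`, so `r - v` broadcasts to `[bsz, bsz]`; after the fix every term is `[bsz]`.

```python
# Shape of r - v before and after the fix (illustration only).
import torch

bsz = 4
v_old = torch.randn(bsz, 1)  # critic output without .squeeze(-1)
r = torch.randn(bsz)         # returns, [bsz]
print((r - v_old).shape)     # torch.Size([4, 4]) -- broadcast, wrong advantage

v_new = v_old.squeeze(-1)    # [bsz], as in the fixed code
print((r - v_new).shape)     # torch.Size([4])    -- one advantage per sample
```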
@@ -142,9 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act)
-        target_q = to_torch_as(batch.returns, current_q)
-        target_q = target_q[:, None]
+        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
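The critic losses in DDPG (and in SAC/TD3 below) follow the same convention. A sketch, not from the repository, of why both `current_q` and `target_q` are brought to `[bsz]` before `F.mse_loss`: with `[bsz, 1]` vs `[bsz]` inputs the loss broadcasts to `[bsz, bsz]` internally (recent PyTorch versions also emit a warning about the size mismatch).

```python
# F.mse_loss with mismatched shapes silently broadcasts (illustration only).
import torch
import torch.nn.functional as F

bsz = 4
current_q = torch.randn(bsz, 1)  # critic output, [bsz, 1]
target_q = torch.randn(bsz)      # returns, [bsz]

# Broadcasts to [bsz, bsz] internally -> averages over bsz * bsz terms.
loss_bad = F.mse_loss(current_q, target_q)

# After the fix both sides are [bsz].
loss_ok = F.mse_loss(current_q.squeeze(-1), target_q)
print(loss_bad.item(), loss_ok.item())  # generally different values
```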
@@ -130,11 +130,10 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0)  # old value
+        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(
-            batch.returns, v[0]).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
             if not np.isclose(std.item(), 0):
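Similarly, a sketch (illustration only) of why the PPO hunk above reshapes `logp_old` to `batch.v.shape`: the stored old log-probabilities must be `[bsz]` so that the ratio computed later stays one value per sample.

```python
# Keeping logp_old in the same [bsz] shape as the new log_prob (illustration only).
import torch

bsz = 4
logp_old = torch.randn(bsz, 1)  # e.g. Normal.log_prob over [bsz, 1] actions
logp_new = torch.randn(bsz)     # already flattened to [bsz]

print((logp_new - logp_old).exp().shape)                          # torch.Size([4, 4]) -- wrong
print((logp_new - logp_old.reshape(logp_new.shape)).exp().shape)  # torch.Size([4]) -- one ratio per sample
```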
@@ -147,8 +146,9 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs)
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                value = self.critic(b.obs).squeeze(-1)
+                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
+                         ).exp().float()
                 surr1 = ratio * b.adv
                 surr2 = ratio.clamp(
                     1. - self._eps_clip, 1. + self._eps_clip) * b.adv
@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act)
-        target_q = to_torch_as(batch.returns, current_q1)[:, None]
+        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act)
+        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@ class SACPolicy(DDPGPolicy):
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a)
-        current_q2a = self.critic2(batch.obs, a)
-        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-            current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
+        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
+        actor_loss = (self._alpha * obs_result.log_prob.reshape(
+            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act)
-        target_q = batch.returns[:, None]
+        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act)
+        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()