Fix shape inconsistency in A2CPolicy and PPOPolicy (#155)
- The original shape of `r - v` in A2C is wrong.
- The shape of log_prob differs between distributions: [bsz] for Categorical and [bsz, 1] for Normal. The shape has to be made consistent with the other tensors manually.
parent 865ef6c693
commit 089b85b6a2
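
As a quick illustration of the shape mismatch described in the commit message, here is a minimal, hypothetical sketch (not part of this commit); the batch size and distribution parameters are made up:

    import torch
    from torch.distributions import Categorical, Normal

    bsz = 4
    cat = Categorical(logits=torch.zeros(bsz, 3))           # batch of 3-way categoricals
    norm = Normal(torch.zeros(bsz, 1), torch.ones(bsz, 1))  # batch of 1-d Gaussians

    # Categorical.log_prob returns [bsz]; Normal.log_prob keeps the trailing 1.
    print(cat.log_prob(torch.zeros(bsz, dtype=torch.long)).shape)  # torch.Size([4])
    print(norm.log_prob(torch.zeros(bsz, 1)).shape)                # torch.Size([4, 1])
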
@@ -105,15 +105,24 @@ class BasePolicy(ABC, nn.Module):
         """Update policy with a given batch of data.
 
         :return: A dict which includes loss and its corresponding label.
+
+        .. warning::
+
+            If you use ``torch.distributions.Normal`` and
+            ``torch.distributions.Categorical`` to calculate the log_prob,
+            please be careful about the shape: Categorical distribution gives
+            "[batch_size]" shape while Normal distribution gives "[batch_size,
+            1]" shape. The auto-broadcasting of numerical operation with torch
+            tensors will amplify this error.
         """
         pass
 
     @staticmethod
     def compute_episodic_return(
         batch: Batch,
         v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
     ) -> Batch:
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimator (arXiv:1506.02438).
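
The warning added above is about silent auto-broadcasting. A small, hypothetical demonstration (dummy tensors, not code from the repository) of how a [bsz, 1] log_prob blows up an element-wise product:

    import torch

    bsz = 4
    log_prob = torch.zeros(bsz, 1)   # shape produced by a Normal distribution
    adv = torch.ones(bsz)            # advantage, shape [bsz]

    bad = log_prob * adv             # broadcasts silently to [bsz, bsz]
    good = log_prob.reshape(adv.shape) * adv
    print(bad.shape, good.shape)     # torch.Size([4, 4]) torch.Size([4])
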
@@ -128,7 +137,8 @@ class BasePolicy(ABC, nn.Module):
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
 
-        :return: a Batch. The result will be stored in batch.returns.
+        :return: a Batch. The result will be stored in batch.returns as a numpy
+            array.
         """
         rew = batch.rew
         if v_s_ is None:
@@ -157,7 +167,7 @@ class BasePolicy(ABC, nn.Module):
         gamma: float = 0.99,
         n_step: int = 1,
         rew_norm: bool = False,
-    ) -> np.ndarray:
+    ) -> Batch:
         r"""Compute n-step return for Q-learning targets:
 
         .. math::
@@ -204,7 +214,7 @@ class BasePolicy(ABC, nn.Module):
             returns[done[now] > 0] = 0
             returns = (rew[now] - mean) / std + gamma * returns
         terminal = (indice + n_step - 1) % buf_len
-        target_q = target_q_fn(buffer, terminal).squeeze()
+        target_q = target_q_fn(buffer, terminal).flatten()  # shape: [bsz, ]
         target_q[gammas != n_step] = 0
         returns = to_torch_as(returns, target_q)
         gammas = to_torch_as(gamma ** gammas, target_q)
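
A plausible reason for preferring .flatten() over .squeeze() in the hunk above, shown with dummy tensors (this snippet is illustrative, not from the commit): .squeeze() removes every size-1 dimension, so a batch of size 1 would lose its batch dimension as well, while .flatten() always yields the [bsz, ] shape noted in the comment.

    import torch

    q = torch.zeros(1, 1)        # e.g. critic output when the batch holds one sample
    print(q.squeeze().shape)     # torch.Size([])  -- batch dimension gone
    print(q.flatten().shape)     # torch.Size([1]) -- still [bsz, ]

    q = torch.zeros(4, 1)
    print(q.squeeze().shape, q.flatten().shape)  # torch.Size([4]) torch.Size([4])
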
@@ -105,11 +105,12 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs)
+                v = self.critic(b.obs).squeeze(-1)
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
-                vf_loss = F.mse_loss(r[:, None], v)
+                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                           ).mean()
+                vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
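
A self-contained sketch of the corrected A2C loss computation above, with dummy tensors standing in for the critic output and returns (the names v, r and a mirror the hunk; everything else is made up):

    import torch
    import torch.nn.functional as F
    from torch.distributions import Normal

    bsz = 4
    v = torch.randn(bsz, 1).squeeze(-1)             # critic output squeezed to [bsz]
    r = torch.randn(bsz)                            # returns, [bsz]
    dist = Normal(torch.zeros(bsz, 1), torch.ones(bsz, 1))
    a = torch.zeros(bsz, 1)                         # actions, [bsz, 1]

    log_prob = dist.log_prob(a).reshape(v.shape)    # [bsz, 1] -> [bsz]
    a_loss = -(log_prob * (r - v).detach()).mean()  # every factor is [bsz]
    vf_loss = F.mse_loss(r, v)                      # shapes match, no broadcasting
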
@@ -142,9 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act)
-        target_q = to_torch_as(batch.returns, current_q)
-        target_q = target_q[:, None]
+        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
@@ -130,11 +130,10 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0)  # old value
+        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(
-            batch.returns, v[0]).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
             if not np.isclose(std.item(), 0):
@@ -147,8 +146,9 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs)
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                value = self.critic(b.obs).squeeze(-1)
+                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
+                         ).exp().float()
                 surr1 = ratio * b.adv
                 surr2 = ratio.clamp(
                     1. - self._eps_clip, 1. + self._eps_clip) * b.adv
@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act)
-        target_q = to_torch_as(batch.returns, current_q1)[:, None]
+        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act)
+        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@ class SACPolicy(DDPGPolicy):
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a)
-        current_q2a = self.critic2(batch.obs, a)
-        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-            current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
+        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
+        actor_loss = (self._alpha * obs_result.log_prob.reshape(
+            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act)
-        target_q = batch.returns[:, None]
+        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act)
+        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
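
For the critic updates in DDPGPolicy, SACPolicy and TD3Policy above, keeping current_q and target_q at the same [bsz] shape matters because F.mse_loss broadcasts mismatched shapes instead of failing, which changes the loss value (recent PyTorch versions also emit a UserWarning about the size mismatch). A hypothetical check with dummy values:

    import torch
    import torch.nn.functional as F

    current_q = torch.tensor([1., 2., 3., 4.])         # [bsz]
    target_q = torch.tensor([[1.], [2.], [3.], [4.]])  # [bsz, 1] -- the old, unsqueezed shape

    print(F.mse_loss(current_q, target_q))              # tensor(2.5000), via a [bsz, bsz] broadcast
    print(F.mse_loss(current_q, target_q.squeeze(-1)))  # tensor(0.), the intended element-wise loss
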