diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index c976362..f9799da 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -105,15 +105,24 @@ class BasePolicy(ABC, nn.Module):
         """Update policy with a given batch of data.
 
         :return: A dict which includes loss and its corresponding label.
+
+        .. warning::
+
+            If you use ``torch.distributions.Normal`` and
+            ``torch.distributions.Categorical`` to calculate the log_prob,
+            please be careful about the shape: Categorical distribution gives
+            "[batch_size]" shape while Normal distribution gives "[batch_size,
+            1]" shape. The auto-broadcasting of numerical operation with torch
+            tensors will amplify this error.
         """
         pass
 
     @staticmethod
     def compute_episodic_return(
-            batch: Batch,
-            v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
-            gamma: float = 0.99,
-            gae_lambda: float = 0.95,
+        batch: Batch,
+        v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        gamma: float = 0.99,
+        gae_lambda: float = 0.95,
     ) -> Batch:
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@@ -128,7 +137,8 @@ class BasePolicy(ABC, nn.Module):
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
 
-        :return: a Batch. The result will be stored in batch.returns.
+        :return: a Batch. The result will be stored in batch.returns as a numpy
+            array.
         """
         rew = batch.rew
         if v_s_ is None:
@@ -157,7 +167,7 @@ class BasePolicy(ABC, nn.Module):
             gamma: float = 0.99,
             n_step: int = 1,
             rew_norm: bool = False,
-    ) -> np.ndarray:
+    ) -> Batch:
         r"""Compute n-step return for Q-learning targets:
 
         .. math::
@@ -204,7 +214,7 @@ class BasePolicy(ABC, nn.Module):
             returns[done[now] > 0] = 0
             returns = (rew[now] - mean) / std + gamma * returns
         terminal = (indice + n_step - 1) % buf_len
-        target_q = target_q_fn(buffer, terminal).squeeze()
+        target_q = target_q_fn(buffer, terminal).flatten()  # shape: [bsz, ]
         target_q[gammas != n_step] = 0
         returns = to_torch_as(returns, target_q)
         gammas = to_torch_as(gamma ** gammas, target_q)
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index f4de285..569d98a 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -105,11 +105,12 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs)
+                v = self.critic(b.obs).squeeze(-1)
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
-                vf_loss = F.mse_loss(r[:, None], v)
+                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                           ).mean()
+                vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 0944e8d..5dff17f 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -142,9 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act)
-        target_q = to_torch_as(batch.returns, current_q)
-        target_q = target_q[:, None]
+        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index a7abec5..dbf0f24 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -130,11 +130,10 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0)  # old value
+        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(
-            batch.returns, v[0]).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
             if not np.isclose(std.item(), 0):
@@ -147,8 +146,9 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs)
-                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                value = self.critic(b.obs).squeeze(-1)
+                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
+                         ).exp().float()
                 surr1 = ratio * b.adv
                 surr2 = ratio.clamp(
                     1. - self._eps_clip, 1. + self._eps_clip) * b.adv
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index ab0a633..5fdf359 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act)
-        target_q = to_torch_as(batch.returns, current_q1)[:, None]
+        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act)
+        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a)
-        current_q2a = self.critic2(batch.obs, a)
-        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-            current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
+        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
+        actor_loss = (self._alpha * obs_result.log_prob.reshape(
+            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 5d9bc37..1ea3a90 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -117,14 +117,14 @@
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act)
-        target_q = batch.returns[:, None]
+        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
+        target_q = batch.returns
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act)
+        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
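
Note: every change above replaces an implicit mix of [batch_size] and [batch_size, 1] tensors with an explicit squeeze(-1)/reshape, as described in the warning added to BasePolicy.learn. The following standalone sketch (plain PyTorch, not part of the patch; tensor names are illustrative) shows how the log_prob shape mismatch arises and why broadcasting silently amplifies it:

import torch
from torch.distributions import Categorical, Normal

batch_size = 4
returns = torch.rand(batch_size)          # per-sample return, shape [4]

# Categorical.log_prob returns shape [batch_size].
cat = Categorical(probs=torch.ones(batch_size, 3) / 3)
log_p_cat = cat.log_prob(torch.zeros(batch_size, dtype=torch.long))
print(log_p_cat.shape)                    # torch.Size([4])

# Normal with batch shape [batch_size, 1] returns log_prob of shape [batch_size, 1].
norm = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
log_p_norm = norm.log_prob(torch.zeros(batch_size, 1))
print(log_p_norm.shape)                   # torch.Size([4, 1])

# Broadcasting turns [4, 1] * [4] into a [4, 4] matrix, so a loss built this
# way mixes every sample with every other sample instead of pairing them.
wrong = log_p_norm * returns
print(wrong.shape)                        # torch.Size([4, 4])

# The pattern used in the patch: reshape log_prob to the value/return shape
# before combining, which keeps the product element-wise.
right = log_p_norm.reshape(returns.shape) * returns
print(right.shape)                        # torch.Size([4])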