From 352a5183995770748dbced04b737e9e89db556f5 Mon Sep 17 00:00:00 2001
From: n+e <463003665@qq.com>
Date: Thu, 23 Jul 2020 15:12:02 +0800
Subject: [PATCH] 3 fixes (#158)

- fix 2 warnings in doctest
- change the minimum version of gym (to be aligned with OpenAI Baselines)
- change squeeze and reshape to flatten (related to #155); I think flatten is
  better (a short sketch of the rationale is appended after the diff)
---
 docs/conf.py                      |  4 ++--
 setup.py                          |  2 +-
 test/continuous/test_ppo.py       |  2 +-
 tianshou/policy/base.py           | 11 +++--------
 tianshou/policy/modelfree/a2c.py  |  4 ++--
 tianshou/policy/modelfree/ddpg.py |  4 ++--
 tianshou/policy/modelfree/dqn.py  |  2 +-
 tianshou/policy/modelfree/pg.py   |  2 +-
 tianshou/policy/modelfree/ppo.py  | 14 +++++++-------
 tianshou/policy/modelfree/sac.py  | 14 +++++++-------
 tianshou/policy/modelfree/td3.py  |  6 +++---
 11 files changed, 30 insertions(+), 35 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 19e4a9f..eb6f65f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -79,8 +79,8 @@ html_logo = '_static/images/tianshou-logo.png'
 
 
 def setup(app):
-    app.add_javascript("js/copybutton.js")
-    app.add_stylesheet("css/style.css")
+    app.add_js_file("js/copybutton.js")
+    app.add_css_file("css/style.css")
 
 
 # -- Extension configuration -------------------------------------------------
diff --git a/setup.py b/setup.py
index 3fee82c..d1d1127 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@ setup(
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
-        'gym>=0.15.0',
+        'gym>=0.15.4',
         'tqdm',
         'numpy',
         'tensorboard',
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index daa7b06..73aad34 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=128)
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index f9799da..4690749 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -4,7 +4,7 @@ from torch import nn
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union, Optional, Callable
 
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
 
 class BasePolicy(ABC, nn.Module):
@@ -138,15 +138,10 @@ class BasePolicy(ABC, nn.Module):
             Estimation, should be in [0, 1], defaults to 0.95.
 
         :return: a Batch. The result will be stored in batch.returns as a numpy
-            array.
+            array with shape (bsz, ).
         """
         rew = batch.rew
-        if v_s_ is None:
-            v_s_ = rew * 0.
-        else:
-            if not isinstance(v_s_, np.ndarray):
-                v_s_ = np.array(v_s_, np.float)
-            v_s_ = v_s_.reshape(rew.shape)
+        v_s_ = rew * 0. if v_s_ is None else to_numpy(v_s_).flatten()
         returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = rew + v_s_ * m - returns
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 569d98a..8b98f70 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -105,10 +105,10 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs).squeeze(-1)
+                v = self.critic(b.obs).flatten()
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
                            ).mean()
                 vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 5dff17f..da4833a 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -142,8 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q = self.critic(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 214a453..1627199 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -159,7 +159,7 @@ class DQNPolicy(BasePolicy):
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch_as(batch.returns, q)
+        r = to_torch_as(batch.returns, q).flatten()
         if hasattr(batch, 'update_weight'):
             td = r - q
             batch.update_weight(batch.indice, to_numpy(td))
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 16a815c..09f90ea 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -90,7 +90,7 @@ class PGPolicy(BasePolicy):
                 dist = self(b).dist
                 a = to_torch_as(b.act, dist.logits)
                 r = to_torch_as(b.returns, dist.logits)
-                loss = -(dist.log_prob(a) * r).sum()
+                loss = -(dist.log_prob(a).flatten() * r).sum()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index dbf0f24..bae6907 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -130,9 +130,9 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
+        batch.v = torch.cat(v, dim=0).flatten()  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
         batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
@@ -146,12 +146,12 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs).squeeze(-1)
-                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
-                         ).exp().float()
+                value = self.critic(b.obs).flatten()
+                ratio = (dist.log_prob(b.act).flatten()
+                         - b.logp_old).exp().float()
                 surr1 = ratio * b.adv
-                surr2 = ratio.clamp(
-                    1. - self._eps_clip, 1. + self._eps_clip) * b.adv
+                surr2 = ratio.clamp(1. - self._eps_clip,
+                                    1. + self._eps_clip) * b.adv
                 if self._dual_clip:
                     clip_loss = -torch.max(torch.min(surr1, surr2),
                                            self._dual_clip * b.adv).mean()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 5fdf359..b67a95b 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
-        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
-        actor_loss = (self._alpha * obs_result.log_prob.reshape(
-            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).flatten()
+        current_q2a = self.critic2(batch.obs, a).flatten()
+        actor_loss = (self._alpha * obs_result.log_prob.flatten()
+                      - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 1ea3a90..d90f510 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
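
Note on the squeeze/reshape -> flatten change (not part of the diff above): the
snippet below is a minimal standalone sketch of the shape problem behind #155.
The (bsz, 1) critic output and the (bsz,) returns are illustrative assumptions,
not the exact shapes produced by tianshou's networks, but they show why calling
.flatten() on both sides of the loss is the simpler and safer choice.

# standalone sketch -- illustrative shapes, not tianshou code
import torch
import torch.nn.functional as F

bsz = 4
current_q = torch.randn(bsz, 1)   # critics often emit a (bsz, 1) tensor
target_q = torch.randn(bsz)       # stored returns are 1-D, shape (bsz,)

# Passing mismatched shapes to F.mse_loss broadcasts (bsz, 1) against (bsz,)
# into a (bsz, bsz) matrix; PyTorch prints a UserWarning about the size
# mismatch and the reduction is silently wrong:
#     F.mse_loss(current_q, target_q)

# .squeeze(-1) only fixes the side that actually has a trailing singleton dim,
# and .reshape(other.shape) couples the line to another tensor's shape.
# .flatten() yields shape (bsz,) for either layout, so both sides always
# line up:
loss = F.mse_loss(current_q.flatten(), target_q.flatten())
print(loss.shape)  # torch.Size([]) -- a proper scalar loss

The same reasoning applies to the dist.log_prob(...) and batch.returns lines
touched in the policies above.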