- fix 2 warnings in doctest
- bump the minimum gym version to 0.15.4 (to align with OpenAI Baselines)
- change squeeze and reshape to flatten (related to #155); I think flatten is better (a short sketch of the difference follows the commit metadata below)
Author: n+e
Date: 2020-07-23 15:12:02 +08:00 (committed by GitHub)
Commit: 352a518399
Parent: bd9c3c7f8d
11 changed files with 30 additions and 35 deletions
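
Note (not part of the commit): a minimal, standalone sketch of the shape pitfall that motivates switching to flatten(). The tensor names and sizes below are illustrative only; the snippet assumes nothing beyond torch and torch.nn.functional.

import torch
import torch.nn.functional as F

bsz = 4
current_q = torch.rand(bsz)      # critic output already flat: shape (bsz,)
target_q = torch.rand(bsz, 1)    # returns kept as a column: shape (bsz, 1)

# Mismatched shapes silently broadcast to (bsz, bsz) inside mse_loss, giving a
# wrong loss value (and a UserWarning in recent PyTorch releases).
bad_loss = F.mse_loss(current_q, target_q)

# Flattening both sides guarantees matching (bsz,) shapes.
good_loss = F.mse_loss(current_q, target_q.flatten())

# .squeeze(-1) only helps when the trailing dim happens to be 1, whereas
# .flatten() handles (bsz,), (bsz, 1), and extra singleton dims alike.
assert target_q.flatten().shape == (bsz,)
assert current_q.squeeze(-1).shape == (bsz,)   # no-op: nothing to squeeze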


@@ -79,8 +79,8 @@ html_logo = '_static/images/tianshou-logo.png'
 def setup(app):
-    app.add_javascript("js/copybutton.js")
-    app.add_stylesheet("css/style.css")
+    app.add_js_file("js/copybutton.js")
+    app.add_css_file("css/style.css")

 # -- Extension configuration -------------------------------------------------


@@ -46,7 +46,7 @@ setup(
                               'examples', 'examples.*',
                               'docs', 'docs.*']),
     install_requires=[
-        'gym>=0.15.0',
+        'gym>=0.15.4',
         'tqdm',
         'numpy',
         'tensorboard',


@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=128)


@@ -4,7 +4,7 @@ from torch import nn
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union, Optional, Callable

-from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


 class BasePolicy(ABC, nn.Module):
@@ -138,15 +138,10 @@ class BasePolicy(ABC, nn.Module):
             Estimation, should be in [0, 1], defaults to 0.95.
         :return: a Batch. The result will be stored in batch.returns as a numpy
-            array.
+            array with shape (bsz, ).
         """
         rew = batch.rew
-        if v_s_ is None:
-            v_s_ = rew * 0.
-        else:
-            if not isinstance(v_s_, np.ndarray):
-                v_s_ = np.array(v_s_, np.float)
-            v_s_ = v_s_.reshape(rew.shape)
+        v_s_ = rew * 0. if v_s_ is None else to_numpy(v_s_).flatten()
         returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = rew + v_s_ * m - returns
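
Note (not part of the commit): a standalone sketch of what the new one-liner does with its two possible inputs. next_state_values is a hypothetical helper, and to_numpy is approximated with .detach().cpu().numpy().

import numpy as np
import torch

def next_state_values(v_s_, rew):
    # None -> zeros shaped like rew; tensor/array -> flat numpy array of shape (bsz,).
    if v_s_ is None:
        return rew * 0.
    v = v_s_.detach().cpu().numpy() if torch.is_tensor(v_s_) else np.asarray(v_s_)
    return v.flatten()

rew = np.ones(3)
print(next_state_values(None, rew))              # [0. 0. 0.]
print(next_state_values(torch.ones(3, 1), rew))  # [1. 1. 1.] -- (3, 1) flattened to (3,)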


@@ -105,10 +105,10 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs).squeeze(-1)
+                v = self.critic(b.obs).flatten()
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
                            ).mean()
                 vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()


@@ -142,8 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)

     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q = self.critic(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()


@@ -159,7 +159,7 @@ class DQNPolicy(BasePolicy):
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch_as(batch.returns, q)
+        r = to_torch_as(batch.returns, q).flatten()
         if hasattr(batch, 'update_weight'):
             td = r - q
             batch.update_weight(batch.indice, to_numpy(td))


@@ -90,7 +90,7 @@ class PGPolicy(BasePolicy):
                 dist = self(b).dist
                 a = to_torch_as(b.act, dist.logits)
                 r = to_torch_as(b.returns, dist.logits)
-                loss = -(dist.log_prob(a) * r).sum()
+                loss = -(dist.log_prob(a).flatten() * r).sum()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())


@@ -130,9 +130,9 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
+        batch.v = torch.cat(v, dim=0).flatten()  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
         batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
@@ -146,12 +146,12 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs).squeeze(-1)
-                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
-                         ).exp().float()
+                value = self.critic(b.obs).flatten()
+                ratio = (dist.log_prob(b.act).flatten()
+                         - b.logp_old).exp().float()
                 surr1 = ratio * b.adv
-                surr2 = ratio.clamp(
-                    1. - self._eps_clip, 1. + self._eps_clip) * b.adv
+                surr2 = ratio.clamp(1. - self._eps_clip,
+                                    1. + self._eps_clip) * b.adv
                 if self._dual_clip:
                     clip_loss = -torch.max(torch.min(surr1, surr2),
                                            self._dual_clip * b.adv).mean()


@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@ class SACPolicy(DDPGPolicy):
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
-        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
-        actor_loss = (self._alpha * obs_result.log_prob.reshape(
-            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).flatten()
+        current_q2a = self.critic2(batch.obs, a).flatten()
+        actor_loss = (self._alpha * obs_result.log_prob.flatten()
+                      - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()


@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()