- fix 2 warnings in doctest (the deprecated `app.add_javascript` / `app.add_stylesheet` calls in the Sphinx config)
- change the minimum version of gym to 0.15.4 (to be aligned with openai baselines)
- change squeeze and reshape to flatten (related to #155). I think flatten is better: a single call always yields a 1-D result, whether the input is (bsz,) or (bsz, 1), so it covers both the squeeze(-1) and the reshape cases (see the shape sketch below).
n+e 2020-07-23 15:12:02 +08:00 committed by GitHub
parent bd9c3c7f8d
commit 352a518399
11 changed files with 30 additions and 35 deletions
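
The shape sketch referenced above (illustrative only, not part of the diff; `bsz`, `q_2d` and `q_1d` are made-up names): for a typical critic output of shape (bsz, 1), `squeeze(-1)`, `reshape` and `flatten()` all give the same (bsz,) result, but `flatten()` also works unchanged when the input is already 1-D, so one uniform call can replace both of the older patterns.

import torch

bsz = 4
q_2d = torch.zeros(bsz, 1)   # typical critic output: (bsz, 1)
q_1d = torch.zeros(bsz)      # already flat, e.g. a returns vector

# On a (bsz, 1) input all three spellings agree ...
assert q_2d.squeeze(-1).shape == (bsz,)
assert q_2d.reshape(q_1d.shape).shape == (bsz,)
assert q_2d.flatten().shape == (bsz,)
# ... and flatten() is also a no-op on an already 1-D tensor, so the same
# call covers both cases without knowing the target shape in advance.
assert q_1d.flatten().shape == (bsz,)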

View File

@@ -79,8 +79,8 @@ html_logo = '_static/images/tianshou-logo.png'
 def setup(app):
-    app.add_javascript("js/copybutton.js")
-    app.add_stylesheet("css/style.css")
+    app.add_js_file("js/copybutton.js")
+    app.add_css_file("css/style.css")
 # -- Extension configuration -------------------------------------------------
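
For context: `app.add_javascript` and `app.add_stylesheet` are the older Sphinx names, deprecated in favour of `add_js_file` / `add_css_file` (available since Sphinx 1.8), which is presumably what the two doctest-build warnings in the commit message refer to. A hedged sketch of a version-tolerant `setup()`; the diff simply switches to the new names, and this fallback is only needed if a very old Sphinx must still be supported:

def setup(app):
    # Prefer the current Sphinx API, fall back to the deprecated names
    # only when running under an old Sphinx release.
    add_js = getattr(app, "add_js_file", None) or app.add_javascript
    add_css = getattr(app, "add_css_file", None) or app.add_stylesheet
    add_js("js/copybutton.js")
    add_css("css/style.css")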

View File

@@ -46,7 +46,7 @@ setup(
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
-        'gym>=0.15.0',
+        'gym>=0.15.4',
         'tqdm',
         'numpy',
         'tensorboard',
View File

@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=128)

View File

@@ -4,7 +4,7 @@ from torch import nn
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union, Optional, Callable
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 class BasePolicy(ABC, nn.Module):
@@ -138,15 +138,10 @@ class BasePolicy(ABC, nn.Module):
             Estimation, should be in [0, 1], defaults to 0.95.
         :return: a Batch. The result will be stored in batch.returns as a numpy
-            array.
+            array with shape (bsz, ).
         """
         rew = batch.rew
-        if v_s_ is None:
-            v_s_ = rew * 0.
-        else:
-            if not isinstance(v_s_, np.ndarray):
-                v_s_ = np.array(v_s_, np.float)
-            v_s_ = v_s_.reshape(rew.shape)
+        v_s_ = rew * 0. if v_s_ is None else to_numpy(v_s_).flatten()
         returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = rew + v_s_ * m - returns
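
A minimal sketch of what the new one-liner in compute_episodic_return accepts (illustrative only: it uses numpy/torch directly instead of `tianshou.data.to_numpy`, and `_as_flat_value` is a made-up name). `None`, a list, a (bsz,) array, or a (bsz, 1) tensor all end up as a 1-D array of shape (bsz,), which is what the old five-line branch did by hand:

import numpy as np
import torch

def _as_flat_value(v_s_, rew):
    # rough stand-in for: rew * 0. if v_s_ is None else to_numpy(v_s_).flatten()
    if v_s_ is None:
        return rew * 0.
    if isinstance(v_s_, torch.Tensor):
        v_s_ = v_s_.detach().cpu().numpy()
    return np.asarray(v_s_).flatten()

rew = np.ones(3)
for v in (None, [0.1, 0.2, 0.3], np.zeros((3, 1)), torch.zeros(3, 1)):
    assert _as_flat_value(v, rew).shape == (3,)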

View File

@@ -105,10 +105,10 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs).squeeze(-1)
+                v = self.critic(b.obs).flatten()
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
                            ).mean()
                 vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
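
The reason the per-sample tensors are forced to (bsz,) before multiplying: if the log-probabilities come back as (bsz, 1) while the advantage term is (bsz,), the elementwise product silently broadcasts to (bsz, bsz) and the subsequent .mean() averages the wrong quantity. A toy sketch (shapes and names are illustrative, not taken from the diff):

import torch

bsz = 5
log_prob = torch.randn(bsz, 1)   # e.g. a distribution that keeps a trailing dim
advantage = torch.randn(bsz)     # (bsz,)

bad = log_prob * advantage              # broadcasts to (bsz, bsz)
good = log_prob.flatten() * advantage   # stays (bsz,)
assert bad.shape == (bsz, bsz) and good.shape == (bsz,)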

View File

@@ -142,8 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q = self.critic(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
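
The same shape discipline applies to the MSE targets: if `current_q` is (bsz,) but the returns are (bsz, 1), `F.mse_loss` broadcasts the pair to (bsz, bsz), and recent PyTorch versions emit a UserWarning about the target size differing from the input size. Flattening both sides keeps the loss element-wise over the batch. A small sketch (tensor names are illustrative):

import torch
import torch.nn.functional as F

bsz = 8
current_q = torch.randn(bsz)     # critic output after .flatten()
returns = torch.randn(bsz, 1)    # what an un-flattened target could look like

loss_broadcast = F.mse_loss(current_q, returns)            # warns, averages a (bsz, bsz) grid
loss_intended = F.mse_loss(current_q, returns.flatten())   # plain per-sample MSE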

View File

@@ -159,7 +159,7 @@ class DQNPolicy(BasePolicy):
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch_as(batch.returns, q)
+        r = to_torch_as(batch.returns, q).flatten()
         if hasattr(batch, 'update_weight'):
             td = r - q
             batch.update_weight(batch.indice, to_numpy(td))

View File

@@ -90,7 +90,7 @@ class PGPolicy(BasePolicy):
                 dist = self(b).dist
                 a = to_torch_as(b.act, dist.logits)
                 r = to_torch_as(b.returns, dist.logits)
-                loss = -(dist.log_prob(a) * r).sum()
+                loss = -(dist.log_prob(a).flatten() * r).sum()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())

View File

@@ -130,9 +130,9 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
+        batch.v = torch.cat(v, dim=0).flatten()  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
         batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
@@ -146,12 +146,12 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs).squeeze(-1)
-                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
-                         ).exp().float()
+                value = self.critic(b.obs).flatten()
+                ratio = (dist.log_prob(b.act).flatten()
+                         - b.logp_old).exp().float()
                 surr1 = ratio * b.adv
-                surr2 = ratio.clamp(
-                    1. - self._eps_clip, 1. + self._eps_clip) * b.adv
+                surr2 = ratio.clamp(1. - self._eps_clip,
+                                    1. + self._eps_clip) * b.adv
                 if self._dual_clip:
                     clip_loss = -torch.max(torch.min(surr1, surr2),
                                            self._dual_clip * b.adv).mean()
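
With the value, old log-probabilities and returns all flattened up front, every intermediate of the PPO surrogate stays 1-D, so the ratio, both surrogate terms and the advantage line up element by element. A toy sketch of those shapes (variable names are illustrative, not from the diff):

import torch

bsz, eps_clip = 6, 0.2
logp_new = torch.randn(bsz)
logp_old = torch.randn(bsz)
adv = torch.randn(bsz)

ratio = (logp_new - logp_old).exp()
surr1 = ratio * adv
surr2 = ratio.clamp(1. - eps_clip, 1. + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()
assert ratio.shape == surr1.shape == surr2.shape == (bsz,)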

View File

@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@ class SACPolicy(DDPGPolicy):
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
-        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
-        actor_loss = (self._alpha * obs_result.log_prob.reshape(
-            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).flatten()
+        current_q2a = self.critic2(batch.obs, a).flatten()
+        actor_loss = (self._alpha * obs_result.log_prob.flatten()
+                      - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()

View File

@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()