3 fix (#158)
- Fix 2 warnings in doctest.
- Change the minimum version of gym to 0.15.4 (to be aligned with OpenAI Baselines).
- Change squeeze and reshape to flatten (related to #155); I think flatten is better. A short shape sketch follows the commit metadata below.
This commit is contained in:
parent bd9c3c7f8d
commit 352a518399
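Why flatten: a minimal, illustrative shape sketch (not code from this commit). squeeze(-1) only drops a trailing axis of size 1, while flatten() always returns a 1-D tensor, which is the shape the learn() methods below expect.

import torch

# A critic head typically outputs shape (bsz, 1); the loss code wants (bsz,).
v = torch.zeros(4, 1)
print(v.squeeze(-1).shape)   # torch.Size([4])
print(v.flatten().shape)     # torch.Size([4])

# With an extra singleton axis, squeeze(-1) no longer yields a 1-D result,
# while flatten() still does.
v = torch.zeros(4, 1, 1)
print(v.squeeze(-1).shape)   # torch.Size([4, 1])
print(v.flatten().shape)     # torch.Size([4])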
@@ -79,8 +79,8 @@ html_logo = '_static/images/tianshou-logo.png'


 def setup(app):
-    app.add_javascript("js/copybutton.js")
-    app.add_stylesheet("css/style.css")
+    app.add_js_file("js/copybutton.js")
+    app.add_css_file("css/style.css")


 # -- Extension configuration -------------------------------------------------
setup.py
@@ -46,7 +46,7 @@ setup(
                       'examples', 'examples.*',
                       'docs', 'docs.*']),
     install_requires=[
-        'gym>=0.15.0',
+        'gym>=0.15.4',
         'tqdm',
         'numpy',
         'tensorboard',
@@ -23,7 +23,7 @@ def get_args():
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -4,7 +4,7 @@ from torch import nn
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union, Optional, Callable

-from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


 class BasePolicy(ABC, nn.Module):
@@ -138,15 +138,10 @@ class BasePolicy(ABC, nn.Module):
             Estimation, should be in [0, 1], defaults to 0.95.

         :return: a Batch. The result will be stored in batch.returns as a numpy
-            array.
+            array with shape (bsz, ).
         """
         rew = batch.rew
-        if v_s_ is None:
-            v_s_ = rew * 0.
-        else:
-            if not isinstance(v_s_, np.ndarray):
-                v_s_ = np.array(v_s_, np.float)
-            v_s_ = v_s_.reshape(rew.shape)
+        v_s_ = rew * 0. if v_s_ is None else to_numpy(v_s_).flatten()
         returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
         delta = rew + v_s_ * m - returns
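For context, the new one-liner collapses the old None/list/array branching. A rough behavioral sketch, using a stand-in for tianshou.data.to_numpy (the real helper is not reproduced here):

import numpy as np
import torch

def to_numpy_stub(x):
    # stand-in for tianshou.data.to_numpy, for illustration only
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

rew = np.zeros(3)
for v_s_ in (None, [1., 2., 3.], torch.ones(3, 1)):
    out = rew * 0. if v_s_ is None else to_numpy_stub(v_s_).flatten()
    print(out.shape)  # always (3,), the same 1-D shape as rew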
@@ -105,10 +105,10 @@ class A2CPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                v = self.critic(b.obs).squeeze(-1)
+                v = self.critic(b.obs).flatten()
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
+                a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
                            ).mean()
                 vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
@@ -142,8 +142,8 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=actions, state=h)

     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        current_q = self.critic(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q = self.critic(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
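Flattening batch.returns alongside the critic output matters because F.mse_loss broadcasts mismatched shapes instead of failing. A small illustration with hypothetical shapes (not taken from this commit):

import torch
import torch.nn.functional as F

q = torch.tensor([0., 1., 2., 3.])            # critic output, shape (4,)
ret = torch.tensor([[3.], [2.], [1.], [0.]])  # returns stored as (4, 1)

# (4,) vs (4, 1) broadcasts to a (4, 4) difference matrix, so the loss mixes
# unrelated pairs; PyTorch only emits a UserWarning about it.
print(F.mse_loss(q, ret).item())            # 2.5
print(F.mse_loss(q, ret.flatten()).item())  # 5.0 -- the intended elementwise loss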
@@ -159,7 +159,7 @@ class DQNPolicy(BasePolicy):
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch_as(batch.returns, q)
+        r = to_torch_as(batch.returns, q).flatten()
         if hasattr(batch, 'update_weight'):
             td = r - q
             batch.update_weight(batch.indice, to_numpy(td))
@@ -90,7 +90,7 @@ class PGPolicy(BasePolicy):
                 dist = self(b).dist
                 a = to_torch_as(b.act, dist.logits)
                 r = to_torch_as(b.returns, dist.logits)
-                loss = -(dist.log_prob(a) * r).sum()
+                loss = -(dist.log_prob(a).flatten() * r).sum()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
@@ -130,9 +130,9 @@ class PPOPolicy(PGPolicy):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
                     to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
+        batch.v = torch.cat(v, dim=0).flatten()  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
+        batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
         batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
@@ -146,12 +146,12 @@ class PPOPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
-                value = self.critic(b.obs).squeeze(-1)
-                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
-                         ).exp().float()
+                value = self.critic(b.obs).flatten()
+                ratio = (dist.log_prob(b.act).flatten()
+                         - b.logp_old).exp().float()
                 surr1 = ratio * b.adv
-                surr2 = ratio.clamp(
-                    1. - self._eps_clip, 1. + self._eps_clip) * b.adv
+                surr2 = ratio.clamp(1. - self._eps_clip,
+                                    1. + self._eps_clip) * b.adv
                 if self._dual_clip:
                     clip_loss = -torch.max(torch.min(surr1, surr2),
                                            self._dual_clip * b.adv).mean()
@@ -139,14 +139,14 @@ class SACPolicy(DDPGPolicy):

     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
@@ -154,10 +154,10 @@ class SACPolicy(DDPGPolicy):
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
-        current_q1a = self.critic1(batch.obs, a).squeeze(-1)
-        current_q2a = self.critic2(batch.obs, a).squeeze(-1)
-        actor_loss = (self._alpha * obs_result.log_prob.reshape(
-            target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
+        current_q1a = self.critic1(batch.obs, a).flatten()
+        current_q2a = self.critic2(batch.obs, a).flatten()
+        actor_loss = (self._alpha * obs_result.log_prob.flatten()
+                      - torch.min(current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
@@ -117,14 +117,14 @@ class TD3Policy(DDPGPolicy):

     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
-        current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
-        target_q = batch.returns
+        current_q1 = self.critic1(batch.obs, batch.act).flatten()
+        target_q = batch.returns.flatten()
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
-        current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
+        current_q2 = self.critic2(batch.obs, batch.act).flatten()
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
         critic2_loss.backward()