minor reformat (#2)

* update atari.py

* fix setup.py
pass the pytest

* fix setup.py
pass the pytest
Minghao Zhang 2020-03-26 09:01:20 +08:00 committed by GitHub
parent fdc969b830
commit 3c0a09fefd
14 changed files with 29 additions and 16 deletions

@@ -37,11 +37,19 @@ setup(
         'examples', 'examples.*',
         'docs', 'docs.*']),
     install_requires=[
-        'gym',
+        'gym>=0.15.0',
         'tqdm',
         'numpy',
         'cloudpickle',
         'tensorboard',
         'torch>=1.4.0',
     ],
+    extras_require={
+        'atari': [
+            'atari_py',
+        ],
+        'mujoco': [
+            'mujoco_py',
+        ]
+    },
 )
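
Note: extras_require above is the standard setuptools mechanism for optional dependencies, so the Atari and MuJoCo packages install only on request (e.g. pip install "tianshou[atari]"). A minimal sketch of the guarded-import pattern such extras usually pair with; the make_atari_env helper below is purely illustrative and not part of this commit:

    # Sketch only: degrade gracefully when the optional 'atari' extra is absent.
    try:
        import atari_py  # provided by: pip install "tianshou[atari]"
    except ImportError:
        atari_py = None

    def make_atari_env(name):
        # Hypothetical helper for illustration; fail with a clear message.
        if atari_py is None:
            raise ImportError(
                "atari_py is missing; install it via pip install 'tianshou[atari]'")
        ...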

@@ -1,4 +1,5 @@
 from tianshou.data import ReplayBuffer
 if __name__ == '__main__':
     from env import MyTestEnv
 else:  # pytest

@@ -11,6 +11,7 @@ else:  # pytest
 class MyPolicy(BasePolicy):
     """docstring for MyPolicy"""
     def __init__(self):
         super().__init__()

@@ -1,4 +1,4 @@
-from tianshou import data, env, utils, policy, trainer,\
+from tianshou import data, env, utils, policy, trainer, \
     exploration

 __version__ = '0.2.0'

@@ -37,7 +37,7 @@ class Batch(object):
             else:
                 raise TypeError(
                     'No support for append with type {} in class Batch.'
                     .format(type(batch.__dict__[k])))

     def split(self, size=None, permute=True):
         length = min([

@@ -47,7 +47,7 @@ class ReplayBuffer(object):
         '''
         weight: importance weights, disabled here
         '''
-        assert isinstance(info, dict),\
+        assert isinstance(info, dict), \
             'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
         self._add_to_buffer('act', act)
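
The assertion above enforces the Gym convention that env.step() returns its diagnostics as a dict. A hedged sketch of the calling pattern it checks; the keyword names follow the buffer fields in the diff, and the exact ReplayBuffer constructor and add() signature should be treated as assumptions:

    import gym
    from tianshou.data import ReplayBuffer

    env = gym.make('CartPole-v0')
    buffer = ReplayBuffer(size=100)
    obs = env.reset()
    act = env.action_space.sample()            # placeholder for a policy decision
    obs_next, rew, done, info = env.step(act)
    assert isinstance(info, dict)              # exactly what the assert above verifies
    buffer.add(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info)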

@@ -31,8 +31,8 @@ class Collector(object):
         if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
-                assert len(self.buffer) == self.env_num,\
-                    'The number of data buffer does not match the number of '\
+                assert len(self.buffer) == self.env_num, \
+                    'The number of data buffer does not match the number of ' \
                     'input env.'
                 self._multi_buf = True
             elif isinstance(self.buffer, ReplayBuffer):
@@ -87,7 +87,7 @@ class Collector(object):
         if not self._multi_env:
             n_episode = np.sum(n_episode)
         start_time = time.time()
-        assert sum([(n_step != 0), (n_episode != 0)]) == 1,\
+        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
             "One and only one collection number specification permitted!"
         cur_step = 0
         cur_episode = np.zeros(self.env_num) if self._multi_env else 0
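
The second assertion above means collect() accepts exactly one budget per call; a brief usage sketch (collector is assumed to be an already-constructed Collector):

    # Either collect by environment steps ...
    collector.collect(n_step=100)
    # ... or by whole episodes, but never both in the same call.
    collector.collect(n_episode=10)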

@@ -1,6 +1,6 @@
 from tianshou.env.utils import CloudpickleWrapper
 from tianshou.env.common import EnvWrapper, FrameStack
-from tianshou.env.vecenv import BaseVectorEnv, VectorEnv,\
+from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
     SubprocVectorEnv, RayVectorEnv

 __all__ = [

@@ -1,6 +1,7 @@
 import numpy as np
 from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
 except ImportError:
@@ -122,7 +123,7 @@ class SubprocVectorEnv(BaseVectorEnv):
             zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
             Process(target=worker, args=(
                 parent, child, CloudpickleWrapper(env_fn)), daemon=True)
             for (parent, child, env_fn) in zip(
                 self.parent_remote, self.child_remote, env_fns)
         ]
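
For context, the Process/Pipe wiring above follows the standard multiprocessing worker pattern; a self-contained sketch of that pattern (an echo worker, not tianshou's actual worker function):

    from multiprocessing import Process, Pipe

    def worker(parent, child):
        # Runs in the child process: serve commands over the pipe until 'close'.
        parent.close()
        while True:
            cmd, data = child.recv()
            if cmd == 'close':
                child.close()
                break
            child.send(('echo', data))

    if __name__ == '__main__':
        parent, child = Pipe()
        p = Process(target=worker, args=(parent, child), daemon=True)
        p.start()
        parent.send(('step', 1))
        print(parent.recv())        # ('echo', 1)
        parent.send(('close', None))
        p.join()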

@@ -14,7 +14,7 @@ class OUNoise(object):
         if self.x is None or self.x.shape != size:
             self.x = 0
         self.x = self.x + self.alpha * (mu - self.x) + \
             self.beta * np.random.normal(size=size)
         return self.x

     def reset(self):
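
The update above is the usual Euler discretisation of an Ornstein-Uhlenbeck process; a standalone sketch with illustrative coefficients (the alpha and beta values here are not tianshou's defaults):

    import numpy as np

    def ou_step(x, mu=0.0, alpha=0.15, beta=0.2, size=(1,)):
        # x_{t+1} = x_t + alpha * (mu - x_t) + beta * N(0, 1), element-wise.
        return x + alpha * (mu - x) + beta * np.random.normal(size=size)

    x = np.zeros(3)
    for _ in range(5):
        x = ou_step(x, size=x.shape)  # temporally correlated exploration noise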

@@ -40,8 +40,8 @@ class A2CPolicy(PGPolicy):
             vf_loss = F.mse_loss(r[:, None], v)
             ent_loss = dist.entropy().mean()
             loss = actor_loss \
                 + self._w_vf * vf_loss \
                 - self._w_ent * ent_loss
             loss.backward()
             if self._grad_norm:
                 nn.utils.clip_grad_norm_(
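
The loss composition and global-norm clipping above follow the usual actor-critic update; a self-contained toy sketch of the same pattern (the network sizes, the 0.5/0.01 weights, and max_norm are placeholders, not tianshou's settings):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Toy actor and critic heads, just to make the update pattern runnable.
    actor = nn.Linear(4, 2)
    critic = nn.Linear(4, 1)
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=1e-3)

    obs = torch.randn(8, 4)
    returns = torch.randn(8, 1)

    dist = torch.distributions.Categorical(logits=actor(obs))
    act = dist.sample()
    v = critic(obs)

    actor_loss = -(dist.log_prob(act) * (returns - v.detach()).squeeze(-1)).mean()
    vf_loss = F.mse_loss(returns, v)
    ent_loss = dist.entropy().mean()

    # Same composition and clipping pattern as the hunk above.
    loss = actor_loss + 0.5 * vf_loss - 0.01 * ent_loss
    optim.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(
        list(actor.parameters()) + list(critic.parameters()), max_norm=0.5)
    optim.step()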

@@ -5,6 +5,8 @@ import torch.nn.functional as F
 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
 # from tianshou.exploration import OUNoise

@@ -35,7 +35,7 @@ class PGPolicy(BasePolicy):
     def learn(self, batch, batch_size=None, repeat=1):
         losses = []
         batch.returns = (batch.returns - batch.returns.mean()) \
             / (batch.returns.std() + self._eps)
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
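
Normalising the returns to zero mean and unit variance, with a small epsilon guarding against a zero standard deviation, is a common variance-reduction step; a minimal sketch (the 1e-8 epsilon is illustrative, not necessarily the value of self._eps):

    import numpy as np

    def normalize_returns(returns, eps=1e-8):
        # (r - mean(r)) / (std(r) + eps), the same expression as in the hunk above.
        returns = np.asarray(returns, dtype=np.float64)
        return (returns - returns.mean()) / (returns.std() + eps)

    print(normalize_returns([1.0, 2.0, 3.0, 4.0]))  # zero mean, roughly unit variance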

@@ -59,7 +59,7 @@ class PPOPolicy(PGPolicy):
     def learn(self, batch, batch_size=None, repeat=1):
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         batch.returns = (batch.returns - batch.returns.mean()) \
             / (batch.returns.std() + self._eps)
         batch.act = torch.tensor(batch.act)
         batch.returns = torch.tensor(batch.returns)[:, None]
         for _ in range(repeat):
@@ -82,13 +82,13 @@ class PPOPolicy(PGPolicy):
                 ent_loss = dist.entropy().mean()
                 ent_losses.append(ent_loss.detach().cpu().numpy())
                 loss = clip_loss \
                     + self._w_vf * vf_loss - self._w_ent * ent_loss
                 losses.append(loss.detach().cpu().numpy())
                 self.optim.zero_grad()
                 loss.backward()
                 nn.utils.clip_grad_norm_(list(
                     self.actor.parameters()) + list(self.critic.parameters()),
                     self._max_grad_norm)
                 self.optim.step()
         self.sync_weight()
         return {