minor reformat (#2)

* update atari.py

* fix setup.py
pass the pytest

* fix setup.py
pass the pytest
This commit is contained in:
Minghao Zhang 2020-03-26 09:01:20 +08:00 committed by GitHub
parent fdc969b830
commit 3c0a09fefd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 29 additions and 16 deletions

View File

@ -37,11 +37,19 @@ setup(
'examples', 'examples.*',
'docs', 'docs.*']),
install_requires=[
'gym',
'gym>=0.15.0',
'tqdm',
'numpy',
'cloudpickle',
'tensorboard',
'torch>=1.4.0',
],
extras_require={
'atari': [
'atari_py',
],
'mujoco': [
'mujoco_py',
]
},
)

View File

@ -1,4 +1,5 @@
from tianshou.data import ReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
else: # pytest

View File

@ -11,6 +11,7 @@ else: # pytest
class MyPolicy(BasePolicy):
"""docstring for MyPolicy"""
def __init__(self):
super().__init__()

View File

@ -1,4 +1,4 @@
from tianshou import data, env, utils, policy, trainer,\
from tianshou import data, env, utils, policy, trainer, \
exploration
__version__ = '0.2.0'

View File

@ -47,7 +47,7 @@ class ReplayBuffer(object):
'''
weight: importance weights, disabled here
'''
assert isinstance(info, dict),\
assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().'
self._add_to_buffer('obs', obs)
self._add_to_buffer('act', act)

View File

@ -31,8 +31,8 @@ class Collector(object):
if self._multi_env:
self.env_num = len(env)
if isinstance(self.buffer, list):
assert len(self.buffer) == self.env_num,\
'The number of data buffer does not match the number of '\
assert len(self.buffer) == self.env_num, \
'The number of data buffer does not match the number of ' \
'input env.'
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer):
@ -87,7 +87,7 @@ class Collector(object):
if not self._multi_env:
n_episode = np.sum(n_episode)
start_time = time.time()
assert sum([(n_step != 0), (n_episode != 0)]) == 1,\
assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
"One and only one collection number specification permitted!"
cur_step = 0
cur_episode = np.zeros(self.env_num) if self._multi_env else 0

View File

@ -1,6 +1,6 @@
from tianshou.env.utils import CloudpickleWrapper
from tianshou.env.common import EnvWrapper, FrameStack
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv,\
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
SubprocVectorEnv, RayVectorEnv
__all__ = [

View File

@ -1,6 +1,7 @@
import numpy as np
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
try:
import ray
except ImportError:

View File

@ -5,6 +5,8 @@ import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise