This commit is contained in:
Trinkle23897 2020-03-13 21:47:17 +08:00
parent f16e05c0e7
commit 543e57cdbd
9 changed files with 17 additions and 18 deletions

3
.gitignore vendored
View File

@ -132,3 +132,6 @@ dmypy.json
# pytype static type analyzer
.pytype/
# customize
flake8.sh

View File

@ -1,3 +0,0 @@
#!/bin/sh
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics

View File

@ -2,8 +2,8 @@ from tianshou import data, env, utils, policy
__version__ = '0.2.0'
__all__ = [
'data',
'env',
'data',
'utils',
'policy'
]

View File

@ -25,5 +25,5 @@ class Batch(object):
self.__dict__[k] += batch.__dict__[k]
else:
raise TypeError(
'No support append method with {} in class Batch.'
'Do not support append with type {} in class Batch.'
.format(type(batch.__dict__[k])))

View File

@ -21,7 +21,7 @@ class Collector(object):
self.env_num = len(env)
if isinstance(self.buffer, list):
assert len(self.buffer) == self.env_num,\
'Data buffer number does not match the input env number.'
'# of data buffer does not match the # of input env.'
elif isinstance(self.buffer, ReplayBuffer):
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
else:

View File

@ -58,18 +58,18 @@ class FrameStack(EnvWrapper):
class BaseVectorEnv(ABC):
def __init__(self):
pass
def __init__(self, env_fns, reset_after_done):
self._env_fns = env_fns
self.env_num = len(env_fns)
self._reset_after_done = reset_after_done
class VectorEnv(BaseVectorEnv):
"""docstring for VectorEnv"""
def __init__(self, env_fns, reset_after_done=False):
super().__init__()
super().__init__(env_fns, reset_after_done)
self.envs = [_() for _ in env_fns]
self.env_num = len(self.envs)
self._reset_after_done = reset_after_done
def __len__(self):
return len(self.envs)
@ -132,8 +132,7 @@ class SubprocVectorEnv(BaseVectorEnv):
"""docstring for SubProcVectorEnv"""
def __init__(self, env_fns, reset_after_done=False):
super().__init__()
self.env_num = len(env_fns)
super().__init__(env_fns, reset_after_done)
self.closed = False
self.parent_remote, self.child_remote = \
zip(*[Pipe() for _ in range(self.env_num)])
@ -193,9 +192,7 @@ class RayVectorEnv(BaseVectorEnv):
"""docstring for RayVectorEnv"""
def __init__(self, env_fns, reset_after_done=False):
super().__init__()
self.env_num = len(env_fns)
self._reset_after_done = reset_after_done
super().__init__(env_fns, reset_after_done)
try:
if not ray.is_initialized():
ray.init()

View File

@ -1,5 +1,7 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.dqn import DQNPolicy
__all__ = [
'BasePolicy'
'BasePolicy',
'DQNPolicy',
]

View File

@ -9,7 +9,7 @@ class BasePolicy(ABC):
@abstractmethod
def act(self, batch, hidden_state=None):
# return {policy, action, hidden}
# return Batch(policy, action, hidden)
pass
def train(self):