diff --git a/.gitignore b/.gitignore index b1ec891..6cc5860 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,6 @@ dmypy.json # pytype static type analyzer .pytype/ + +# customize +flake8.sh diff --git a/flake8.sh b/flake8.sh deleted file mode 100755 index 5f3cff1..0000000 --- a/flake8.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh -flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 3e97de1..38edeca 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -2,8 +2,8 @@ from tianshou import data, env, utils, policy __version__ = '0.2.0' __all__ = [ - 'data', 'env', + 'data', 'utils', 'policy' ] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index d254217..7dd2ca8 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -25,5 +25,5 @@ class Batch(object): self.__dict__[k] += batch.__dict__[k] else: raise TypeError( - 'No support append method with {} in class Batch.' + 'Do not support append with type {} in class Batch.' .format(type(batch.__dict__[k]))) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 8378d53..c5da90c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -21,7 +21,7 @@ class Collector(object): self.env_num = len(env) if isinstance(self.buffer, list): assert len(self.buffer) == self.env_num,\ - 'Data buffer number does not match the input env number.' + '# of data buffer does not match the # of input env.' elif isinstance(self.buffer, ReplayBuffer): self.buffer = [deepcopy(buffer) for _ in range(self.env_num)] else: diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py index 5a4fbb5..bef3aab 100644 --- a/tianshou/env/wrapper.py +++ b/tianshou/env/wrapper.py @@ -58,18 +58,18 @@ class FrameStack(EnvWrapper): class BaseVectorEnv(ABC): - def __init__(self): - pass + def __init__(self, env_fns, reset_after_done): + self._env_fns = env_fns + self.env_num = len(env_fns) + self._reset_after_done = reset_after_done class VectorEnv(BaseVectorEnv): """docstring for VectorEnv""" def __init__(self, env_fns, reset_after_done=False): - super().__init__() + super().__init__(env_fns, reset_after_done) self.envs = [_() for _ in env_fns] - self.env_num = len(self.envs) - self._reset_after_done = reset_after_done def __len__(self): return len(self.envs) @@ -132,8 +132,7 @@ class SubprocVectorEnv(BaseVectorEnv): """docstring for SubProcVectorEnv""" def __init__(self, env_fns, reset_after_done=False): - super().__init__() - self.env_num = len(env_fns) + super().__init__(env_fns, reset_after_done) self.closed = False self.parent_remote, self.child_remote = \ zip(*[Pipe() for _ in range(self.env_num)]) @@ -193,9 +192,7 @@ class RayVectorEnv(BaseVectorEnv): """docstring for RayVectorEnv""" def __init__(self, env_fns, reset_after_done=False): - super().__init__() - self.env_num = len(env_fns) - self._reset_after_done = reset_after_done + super().__init__(env_fns, reset_after_done) try: if not ray.is_initialized(): ray.init() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index d7a3fb5..37c8423 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,5 +1,7 @@ from tianshou.policy.base import BasePolicy +from tianshou.policy.dqn import DQNPolicy __all__ = [ - 'BasePolicy' + 'BasePolicy', + 'DQNPolicy', ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index de2eb9f..afbeacd 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -9,7 +9,7 @@ class BasePolicy(ABC): @abstractmethod def act(self, batch, hidden_state=None): - # return {policy, action, hidden} + # return Batch(policy, action, hidden) pass def train(self): diff --git a/tianshou/policy/reward_processor.py b/tianshou/policy/reward_processor.py deleted file mode 100644 index e69de29..0000000