diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 26c2f68..fc9af2f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,10 +30,7 @@ jobs: - name: Lint with flake8 run: | pip install flake8 - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + ./flake_check.sh - name: Test with pytest run: | pip install pytest diff --git a/flake_check.sh b/flake_check.sh new file mode 100755 index 0000000..9d03862 --- /dev/null +++ b/flake_check.sh @@ -0,0 +1,3 @@ +#!/bin/sh +flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics +flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics diff --git a/setup.py b/setup.py index a82aef7..96f325f 100644 --- a/setup.py +++ b/setup.py @@ -40,12 +40,12 @@ setup( 'examples', 'examples.*', 'docs', 'docs.*']), install_requires=[ - 'numpy', - 'torch', - 'tensorboard', - 'tqdm', - # 'ray', - 'gym', + 'numpy', + 'torch', + 'tensorboard', + 'tqdm', + # 'ray', + 'gym', 'cloudpickle' ], -) \ No newline at end of file +) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index cafde94..d294ae7 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -20,7 +20,7 @@ class ReplayBuffer(object): self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) elif isinstance(inst, dict): self.__dict__[name] = np.array([{} for _ in range(self._maxsize)]) - else: # assume `inst` is a number + else: # assume `inst` is a number self.__dict__[name] = np.zeros([self._maxsize]) self.__dict__[name][self._index] = inst @@ -46,15 +46,21 @@ class ReplayBuffer(object): def sample(self, batch_size): indice = self.sample_index(batch_size) - return Batch(obs=self.obs[indice], act=self.act[indice], rew=self.rew[indice], - done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice]) + return Batch( + obs=self.obs[indice], + act=self.act[indice], + rew=self.rew[indice], + done=self.done[indice], + obs_next=self.obs_next[indice], + info=self.info[indice] + ) class PrioritizedReplayBuffer(ReplayBuffer): """docstring for PrioritizedReplayBuffer""" def __init__(self, size): super().__init__(size) - + def add(self, obs, act, rew, done, obs_next, info={}, weight=None): raise NotImplementedError diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py index f068a0e..c1b0c10 100644 --- a/tianshou/env/wrapper.py +++ b/tianshou/env/wrapper.py @@ -1,6 +1,10 @@ import numpy as np from collections import deque from multiprocessing import Process, Pipe +try: + import ray +except ImportError: + pass from tianshou.utils import CloudpickleWrapper @@ -11,10 +15,10 @@ class EnvWrapper(object): def step(self, action): return self.env.step(action) - + def reset(self): self.env.reset() - + def seed(self, seed=None): if hasattr(self.env, 'seed'): self.env.seed(seed) @@ -55,7 +59,7 @@ class VectorEnv(object): super().__init__() self.envs = [_() for _ in env_fns] self._reset_after_done = kwargs.get('reset_after_done', False) - + def __len__(self): return len(self.envs) @@ -89,12 +93,15 @@ class VectorEnv(object): class SubprocVectorEnv(object): """docstring for SubProcVectorEnv""" def __init__(self, env_fns, **kwargs): + super().__init__() self.env_num = len(env_fns) self.closed = False self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)]) - self.processes = [Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True) - for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)] + self.processes = [ + Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True) + for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns) + ] for p in self.processes: p.start() for c in self.child_remote: @@ -102,27 +109,27 @@ class SubprocVectorEnv(object): def __len__(self): return self.env_num - - def worker(parent, p, env_fn_wrapper, **kwargs): + + def worker(self, parent, p, env_fn_wrapper, **kwargs): reset_after_done = kwargs.get('reset_after_done', True) parent.close() env = env_fn_wrapper.data() while True: cmd, data = p.recv() - if cmd is 'step': + if cmd == 'step': obs, rew, done, info = env.step(data) if reset_after_done and done: # s_ is useless when episode finishes obs = env.reset() p.send([obs, rew, done, info]) - elif cmd is 'reset': + elif cmd == 'reset': p.send(env.reset()) - elif cmd is 'close': + elif cmd == 'close': p.close() break - elif cmd is 'render': + elif cmd == 'render': p.send(env.render()) - elif cmd is 'seed': + elif cmd == 'seed': p.send(env.seed(data)) else: raise NotImplementedError @@ -163,7 +170,6 @@ class SubprocVectorEnv(object): p.join() - class RayVectorEnv(object): """docstring for RayVectorEnv""" def __init__(self, env_fns, **kwargs): diff --git a/tianshou/utils/cloudpicklewrapper.py b/tianshou/utils/cloudpicklewrapper.py index c2221ca..2bf3cd5 100644 --- a/tianshou/utils/cloudpicklewrapper.py +++ b/tianshou/utils/cloudpicklewrapper.py @@ -4,7 +4,9 @@ import cloudpickle class CloudpickleWrapper(object): def __init__(self, data): self.data = data + def __getstate__(self): return cloudpickle.dumps(self.data) + def __setstate__(self, data): self.data = cloudpickle.loads(data)