From 04557fdb82d79d5ff6dafca231ed1be01338869f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 11 Mar 2020 16:14:53 +0800 Subject: [PATCH] env test \ ray --- .github/workflows/pytest.yml | 4 +- README.md | 2 + test/test_env.py | 38 ++++++++++++++++-- tianshou/env/wrapper.py | 78 +++++++++++++++++++----------------- 4 files changed, 80 insertions(+), 42 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 8e67963..91789db 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -33,5 +33,5 @@ jobs: ./flake8.sh - name: Test with pytest run: | - pip install pytest - pytest + pip install pytest pytest-cov + pytest --cov tianshou diff --git a/README.md b/README.md index b392c7f..c4cefb0 100644 --- a/README.md +++ b/README.md @@ -1 +1,3 @@ # Tianshou + +![Python package](https://github.com/Trinkle23897/tianshou/workflows/Python%20package/badge.svg) diff --git a/test/test_env.py b/test/test_env.py index e56afff..1c9cea1 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -27,9 +27,8 @@ class MyTestEnv(gym.Env): finished = self.index == self.size return self.index, int(finished), finished, {} -def test_framestack(): - k = 4 - size = 10 + +def test_framestack(k=4, size=10): env = MyTestEnv(size=size) fsenv = FrameStack(env, k) fsenv.seed() @@ -53,5 +52,36 @@ def test_framestack(): assert (rew, done) == (0, True) fsenv.close() + +def test_vecenv(verbose=False, size=10, num=8, sleep=0.001): + env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)] + venv = [ + VectorEnv(env_fns, reset_after_done=True), + SubprocVectorEnv(env_fns, reset_after_done=True), + ] + if verbose: + venv.append(RayVectorEnv(env_fns, reset_after_done=True)) + for v in venv: + v.seed() + action_list = [1] * 5 + [0] * 10 + [1] * 9 + if not verbose: + o = [v.reset() for v in venv] + for i, a in enumerate(action_list): + o = [v.step([a] * num) for v in venv] + for i in zip(*o): + for j in range(1, len(i)): + assert (i[0] == i[j]).all() + else: + t = [0, 0, 0] + for i, e in enumerate(venv): + t[i] = time.time() + e.reset() + for a in action_list: + e.step([a] * num) + t[i] = time.time() - t[i] + print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s') + + if __name__ == '__main__': - test_framestack() \ No newline at end of file + test_framestack() + test_vecenv(True) diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py index ca81f97..852ac4e 100644 --- a/tianshou/env/wrapper.py +++ b/tianshou/env/wrapper.py @@ -17,7 +17,7 @@ class EnvWrapper(object): return self.env.step(action) def reset(self): - self.env.reset() + return self.env.reset() def seed(self, seed=None): if hasattr(self.env, 'seed'): @@ -61,6 +61,7 @@ class VectorEnv(object): def __init__(self, env_fns, **kwargs): super().__init__() self.envs = [_() for _ in env_fns] + self.env_num = len(self.envs) self._reset_after_done = kwargs.get('reset_after_done', False) def __len__(self): @@ -70,8 +71,9 @@ class VectorEnv(object): return np.stack([e.reset() for e in self.envs]) def step(self, action): - result = zip(*[e.step(action[i]) for i, e in enumerate(self.envs)]) - obs, rew, done, info = zip(*result) + assert len(action) == self.env_num + result = zip(*[e.step(a) for e, a in zip(self.envs, action)]) + obs, rew, done, info = result if self._reset_after_done and sum(done): for i, e in enumerate(self.envs): if done[i]: @@ -79,9 +81,11 @@ class VectorEnv(object): return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) def seed(self, seed=None): - for e in self.envs: + if np.isscalar(seed) or seed is None: + seed = [seed for _ in range(self.env_num)] + for e, s in zip(self.envs, seed): if hasattr(e, 'seed'): - e.seed(seed) + e.seed(s) def render(self): for e in self.envs: @@ -93,16 +97,40 @@ class VectorEnv(object): e.close() +def worker(parent, p, env_fn_wrapper, kwargs): + reset_after_done = kwargs.get('reset_after_done', True) + parent.close() + env = env_fn_wrapper.data() + while True: + cmd, data = p.recv() + if cmd == 'step': + obs, rew, done, info = env.step(data) + if reset_after_done and done: + # s_ is useless when episode finishes + obs = env.reset() + p.send([obs, rew, done, info]) + elif cmd == 'reset': + p.send(env.reset()) + elif cmd == 'close': + p.close() + break + elif cmd == 'render': + p.send(env.render()) + elif cmd == 'seed': + p.send(env.seed(data)) + else: + raise NotImplementedError + + class SubprocVectorEnv(object): """docstring for SubProcVectorEnv""" def __init__(self, env_fns, **kwargs): - super().__init__() self.env_num = len(env_fns) self.closed = False self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)]) self.processes = [ - Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True) + Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True) for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns) ] for p in self.processes: @@ -113,31 +141,8 @@ class SubprocVectorEnv(object): def __len__(self): return self.env_num - def worker(self, parent, p, env_fn_wrapper, **kwargs): - reset_after_done = kwargs.get('reset_after_done', True) - parent.close() - env = env_fn_wrapper.data() - while True: - cmd, data = p.recv() - if cmd == 'step': - obs, rew, done, info = env.step(data) - if reset_after_done and done: - # s_ is useless when episode finishes - obs = env.reset() - p.send([obs, rew, done, info]) - elif cmd == 'reset': - p.send(env.reset()) - elif cmd == 'close': - p.close() - break - elif cmd == 'render': - p.send(env.render()) - elif cmd == 'seed': - p.send(env.seed(data)) - else: - raise NotImplementedError - def step(self, action): + assert len(action) == self.env_num for p, a in zip(self.parent_remote, action): p.send(['step', a]) result = [p.recv() for p in self.parent_remote] @@ -149,8 +154,8 @@ class SubprocVectorEnv(object): p.send(['reset', None]) return np.stack([p.recv() for p in self.parent_remote]) - def seed(self, seed): - if np.isscalar(seed): + def seed(self, seed=None): + if np.isscalar(seed) or seed is None: seed = [seed for _ in range(self.env_num)] for p, s in zip(self.parent_remote, seed): p.send(['seed', s]) @@ -190,7 +195,8 @@ class RayVectorEnv(object): return self.env_num def step(self, action): - result_obj = [e.step.remote(action[i]) for i, e in enumerate(self.envs)] + assert len(action) == self.env_num + result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)] obs, rew, done, info = zip(*[ray.get(r) for r in result_obj]) return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) @@ -198,8 +204,8 @@ class RayVectorEnv(object): result_obj = [e.reset.remote() for e in self.envs] return np.stack([ray.get(r) for r in result_obj]) - def seed(self, seed): - if np.isscalar(seed): + def seed(self, seed=None): + if np.isscalar(seed) or seed is None: seed = [seed for _ in range(self.env_num)] result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)] for r in result_obj: