env test \ ray

This commit is contained in:
Trinkle23897 2020-03-11 16:14:53 +08:00
parent 7533e5b0ac
commit 04557fdb82
4 changed files with 80 additions and 42 deletions

View File

@ -33,5 +33,5 @@ jobs:
./flake8.sh
- name: Test with pytest
run: |
pip install pytest
pytest
pip install pytest pytest-cov
pytest --cov tianshou

View File

@ -1 +1,3 @@
# Tianshou
![Python package](https://github.com/Trinkle23897/tianshou/workflows/Python%20package/badge.svg)

View File

@ -27,9 +27,8 @@ class MyTestEnv(gym.Env):
finished = self.index == self.size
return self.index, int(finished), finished, {}
def test_framestack():
k = 4
size = 10
def test_framestack(k=4, size=10):
env = MyTestEnv(size=size)
fsenv = FrameStack(env, k)
fsenv.seed()
@ -53,5 +52,36 @@ def test_framestack():
assert (rew, done) == (0, True)
fsenv.close()
def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
venv = [
VectorEnv(env_fns, reset_after_done=True),
SubprocVectorEnv(env_fns, reset_after_done=True),
]
if verbose:
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
for v in venv:
v.seed()
action_list = [1] * 5 + [0] * 10 + [1] * 9
if not verbose:
o = [v.reset() for v in venv]
for i, a in enumerate(action_list):
o = [v.step([a] * num) for v in venv]
for i in zip(*o):
for j in range(1, len(i)):
assert (i[0] == i[j]).all()
else:
t = [0, 0, 0]
for i, e in enumerate(venv):
t[i] = time.time()
e.reset()
for a in action_list:
e.step([a] * num)
t[i] = time.time() - t[i]
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
if __name__ == '__main__':
test_framestack()
test_framestack()
test_vecenv(True)

View File

@ -17,7 +17,7 @@ class EnvWrapper(object):
return self.env.step(action)
def reset(self):
self.env.reset()
return self.env.reset()
def seed(self, seed=None):
if hasattr(self.env, 'seed'):
@ -61,6 +61,7 @@ class VectorEnv(object):
def __init__(self, env_fns, **kwargs):
super().__init__()
self.envs = [_() for _ in env_fns]
self.env_num = len(self.envs)
self._reset_after_done = kwargs.get('reset_after_done', False)
def __len__(self):
@ -70,8 +71,9 @@ class VectorEnv(object):
return np.stack([e.reset() for e in self.envs])
def step(self, action):
result = zip(*[e.step(action[i]) for i, e in enumerate(self.envs)])
obs, rew, done, info = zip(*result)
assert len(action) == self.env_num
result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
obs, rew, done, info = result
if self._reset_after_done and sum(done):
for i, e in enumerate(self.envs):
if done[i]:
@ -79,9 +81,11 @@ class VectorEnv(object):
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
def seed(self, seed=None):
for e in self.envs:
if np.isscalar(seed) or seed is None:
seed = [seed for _ in range(self.env_num)]
for e, s in zip(self.envs, seed):
if hasattr(e, 'seed'):
e.seed(seed)
e.seed(s)
def render(self):
for e in self.envs:
@ -93,16 +97,40 @@ class VectorEnv(object):
e.close()
def worker(parent, p, env_fn_wrapper, kwargs):
reset_after_done = kwargs.get('reset_after_done', True)
parent.close()
env = env_fn_wrapper.data()
while True:
cmd, data = p.recv()
if cmd == 'step':
obs, rew, done, info = env.step(data)
if reset_after_done and done:
# s_ is useless when episode finishes
obs = env.reset()
p.send([obs, rew, done, info])
elif cmd == 'reset':
p.send(env.reset())
elif cmd == 'close':
p.close()
break
elif cmd == 'render':
p.send(env.render())
elif cmd == 'seed':
p.send(env.seed(data))
else:
raise NotImplementedError
class SubprocVectorEnv(object):
"""docstring for SubProcVectorEnv"""
def __init__(self, env_fns, **kwargs):
super().__init__()
self.env_num = len(env_fns)
self.closed = False
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [
Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
]
for p in self.processes:
@ -113,31 +141,8 @@ class SubprocVectorEnv(object):
def __len__(self):
return self.env_num
def worker(self, parent, p, env_fn_wrapper, **kwargs):
reset_after_done = kwargs.get('reset_after_done', True)
parent.close()
env = env_fn_wrapper.data()
while True:
cmd, data = p.recv()
if cmd == 'step':
obs, rew, done, info = env.step(data)
if reset_after_done and done:
# s_ is useless when episode finishes
obs = env.reset()
p.send([obs, rew, done, info])
elif cmd == 'reset':
p.send(env.reset())
elif cmd == 'close':
p.close()
break
elif cmd == 'render':
p.send(env.render())
elif cmd == 'seed':
p.send(env.seed(data))
else:
raise NotImplementedError
def step(self, action):
assert len(action) == self.env_num
for p, a in zip(self.parent_remote, action):
p.send(['step', a])
result = [p.recv() for p in self.parent_remote]
@ -149,8 +154,8 @@ class SubprocVectorEnv(object):
p.send(['reset', None])
return np.stack([p.recv() for p in self.parent_remote])
def seed(self, seed):
if np.isscalar(seed):
def seed(self, seed=None):
if np.isscalar(seed) or seed is None:
seed = [seed for _ in range(self.env_num)]
for p, s in zip(self.parent_remote, seed):
p.send(['seed', s])
@ -190,7 +195,8 @@ class RayVectorEnv(object):
return self.env_num
def step(self, action):
result_obj = [e.step.remote(action[i]) for i, e in enumerate(self.envs)]
assert len(action) == self.env_num
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
@ -198,8 +204,8 @@ class RayVectorEnv(object):
result_obj = [e.reset.remote() for e in self.envs]
return np.stack([ray.get(r) for r in result_obj])
def seed(self, seed):
if np.isscalar(seed):
def seed(self, seed=None):
if np.isscalar(seed) or seed is None:
seed = [seed for _ in range(self.env_num)]
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
for r in result_obj: