env test \ ray
This commit is contained in:
parent
7533e5b0ac
commit
04557fdb82
4
.github/workflows/pytest.yml
vendored
4
.github/workflows/pytest.yml
vendored
@ -33,5 +33,5 @@ jobs:
|
|||||||
./flake8.sh
|
./flake8.sh
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
pip install pytest
|
pip install pytest pytest-cov
|
||||||
pytest
|
pytest --cov tianshou
|
||||||
|
@ -1 +1,3 @@
|
|||||||
# Tianshou
|
# Tianshou
|
||||||
|
|
||||||
|

|
||||||
|
@ -27,9 +27,8 @@ class MyTestEnv(gym.Env):
|
|||||||
finished = self.index == self.size
|
finished = self.index == self.size
|
||||||
return self.index, int(finished), finished, {}
|
return self.index, int(finished), finished, {}
|
||||||
|
|
||||||
def test_framestack():
|
|
||||||
k = 4
|
def test_framestack(k=4, size=10):
|
||||||
size = 10
|
|
||||||
env = MyTestEnv(size=size)
|
env = MyTestEnv(size=size)
|
||||||
fsenv = FrameStack(env, k)
|
fsenv = FrameStack(env, k)
|
||||||
fsenv.seed()
|
fsenv.seed()
|
||||||
@ -53,5 +52,36 @@ def test_framestack():
|
|||||||
assert (rew, done) == (0, True)
|
assert (rew, done) == (0, True)
|
||||||
fsenv.close()
|
fsenv.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
|
||||||
|
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
|
||||||
|
venv = [
|
||||||
|
VectorEnv(env_fns, reset_after_done=True),
|
||||||
|
SubprocVectorEnv(env_fns, reset_after_done=True),
|
||||||
|
]
|
||||||
|
if verbose:
|
||||||
|
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
|
||||||
|
for v in venv:
|
||||||
|
v.seed()
|
||||||
|
action_list = [1] * 5 + [0] * 10 + [1] * 9
|
||||||
|
if not verbose:
|
||||||
|
o = [v.reset() for v in venv]
|
||||||
|
for i, a in enumerate(action_list):
|
||||||
|
o = [v.step([a] * num) for v in venv]
|
||||||
|
for i in zip(*o):
|
||||||
|
for j in range(1, len(i)):
|
||||||
|
assert (i[0] == i[j]).all()
|
||||||
|
else:
|
||||||
|
t = [0, 0, 0]
|
||||||
|
for i, e in enumerate(venv):
|
||||||
|
t[i] = time.time()
|
||||||
|
e.reset()
|
||||||
|
for a in action_list:
|
||||||
|
e.step([a] * num)
|
||||||
|
t[i] = time.time() - t[i]
|
||||||
|
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_framestack()
|
test_framestack()
|
||||||
|
test_vecenv(True)
|
||||||
|
78
tianshou/env/wrapper.py
vendored
78
tianshou/env/wrapper.py
vendored
@ -17,7 +17,7 @@ class EnvWrapper(object):
|
|||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.env.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
if hasattr(self.env, 'seed'):
|
if hasattr(self.env, 'seed'):
|
||||||
@ -61,6 +61,7 @@ class VectorEnv(object):
|
|||||||
def __init__(self, env_fns, **kwargs):
|
def __init__(self, env_fns, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.envs = [_() for _ in env_fns]
|
self.envs = [_() for _ in env_fns]
|
||||||
|
self.env_num = len(self.envs)
|
||||||
self._reset_after_done = kwargs.get('reset_after_done', False)
|
self._reset_after_done = kwargs.get('reset_after_done', False)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -70,8 +71,9 @@ class VectorEnv(object):
|
|||||||
return np.stack([e.reset() for e in self.envs])
|
return np.stack([e.reset() for e in self.envs])
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
result = zip(*[e.step(action[i]) for i, e in enumerate(self.envs)])
|
assert len(action) == self.env_num
|
||||||
obs, rew, done, info = zip(*result)
|
result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
|
||||||
|
obs, rew, done, info = result
|
||||||
if self._reset_after_done and sum(done):
|
if self._reset_after_done and sum(done):
|
||||||
for i, e in enumerate(self.envs):
|
for i, e in enumerate(self.envs):
|
||||||
if done[i]:
|
if done[i]:
|
||||||
@ -79,9 +81,11 @@ class VectorEnv(object):
|
|||||||
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
for e in self.envs:
|
if np.isscalar(seed) or seed is None:
|
||||||
|
seed = [seed for _ in range(self.env_num)]
|
||||||
|
for e, s in zip(self.envs, seed):
|
||||||
if hasattr(e, 'seed'):
|
if hasattr(e, 'seed'):
|
||||||
e.seed(seed)
|
e.seed(s)
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
for e in self.envs:
|
for e in self.envs:
|
||||||
@ -93,16 +97,40 @@ class VectorEnv(object):
|
|||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
|
|
||||||
|
def worker(parent, p, env_fn_wrapper, kwargs):
|
||||||
|
reset_after_done = kwargs.get('reset_after_done', True)
|
||||||
|
parent.close()
|
||||||
|
env = env_fn_wrapper.data()
|
||||||
|
while True:
|
||||||
|
cmd, data = p.recv()
|
||||||
|
if cmd == 'step':
|
||||||
|
obs, rew, done, info = env.step(data)
|
||||||
|
if reset_after_done and done:
|
||||||
|
# s_ is useless when episode finishes
|
||||||
|
obs = env.reset()
|
||||||
|
p.send([obs, rew, done, info])
|
||||||
|
elif cmd == 'reset':
|
||||||
|
p.send(env.reset())
|
||||||
|
elif cmd == 'close':
|
||||||
|
p.close()
|
||||||
|
break
|
||||||
|
elif cmd == 'render':
|
||||||
|
p.send(env.render())
|
||||||
|
elif cmd == 'seed':
|
||||||
|
p.send(env.seed(data))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class SubprocVectorEnv(object):
|
class SubprocVectorEnv(object):
|
||||||
"""docstring for SubProcVectorEnv"""
|
"""docstring for SubProcVectorEnv"""
|
||||||
def __init__(self, env_fns, **kwargs):
|
def __init__(self, env_fns, **kwargs):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
|
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
|
||||||
self.processes = [
|
self.processes = [
|
||||||
Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
|
Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
|
||||||
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
|
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
|
||||||
]
|
]
|
||||||
for p in self.processes:
|
for p in self.processes:
|
||||||
@ -113,31 +141,8 @@ class SubprocVectorEnv(object):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.env_num
|
return self.env_num
|
||||||
|
|
||||||
def worker(self, parent, p, env_fn_wrapper, **kwargs):
|
|
||||||
reset_after_done = kwargs.get('reset_after_done', True)
|
|
||||||
parent.close()
|
|
||||||
env = env_fn_wrapper.data()
|
|
||||||
while True:
|
|
||||||
cmd, data = p.recv()
|
|
||||||
if cmd == 'step':
|
|
||||||
obs, rew, done, info = env.step(data)
|
|
||||||
if reset_after_done and done:
|
|
||||||
# s_ is useless when episode finishes
|
|
||||||
obs = env.reset()
|
|
||||||
p.send([obs, rew, done, info])
|
|
||||||
elif cmd == 'reset':
|
|
||||||
p.send(env.reset())
|
|
||||||
elif cmd == 'close':
|
|
||||||
p.close()
|
|
||||||
break
|
|
||||||
elif cmd == 'render':
|
|
||||||
p.send(env.render())
|
|
||||||
elif cmd == 'seed':
|
|
||||||
p.send(env.seed(data))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
assert len(action) == self.env_num
|
||||||
for p, a in zip(self.parent_remote, action):
|
for p, a in zip(self.parent_remote, action):
|
||||||
p.send(['step', a])
|
p.send(['step', a])
|
||||||
result = [p.recv() for p in self.parent_remote]
|
result = [p.recv() for p in self.parent_remote]
|
||||||
@ -149,8 +154,8 @@ class SubprocVectorEnv(object):
|
|||||||
p.send(['reset', None])
|
p.send(['reset', None])
|
||||||
return np.stack([p.recv() for p in self.parent_remote])
|
return np.stack([p.recv() for p in self.parent_remote])
|
||||||
|
|
||||||
def seed(self, seed):
|
def seed(self, seed=None):
|
||||||
if np.isscalar(seed):
|
if np.isscalar(seed) or seed is None:
|
||||||
seed = [seed for _ in range(self.env_num)]
|
seed = [seed for _ in range(self.env_num)]
|
||||||
for p, s in zip(self.parent_remote, seed):
|
for p, s in zip(self.parent_remote, seed):
|
||||||
p.send(['seed', s])
|
p.send(['seed', s])
|
||||||
@ -190,7 +195,8 @@ class RayVectorEnv(object):
|
|||||||
return self.env_num
|
return self.env_num
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
result_obj = [e.step.remote(action[i]) for i, e in enumerate(self.envs)]
|
assert len(action) == self.env_num
|
||||||
|
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
|
||||||
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
|
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
|
||||||
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
||||||
|
|
||||||
@ -198,8 +204,8 @@ class RayVectorEnv(object):
|
|||||||
result_obj = [e.reset.remote() for e in self.envs]
|
result_obj = [e.reset.remote() for e in self.envs]
|
||||||
return np.stack([ray.get(r) for r in result_obj])
|
return np.stack([ray.get(r) for r in result_obj])
|
||||||
|
|
||||||
def seed(self, seed):
|
def seed(self, seed=None):
|
||||||
if np.isscalar(seed):
|
if np.isscalar(seed) or seed is None:
|
||||||
seed = [seed for _ in range(self.env_num)]
|
seed = [seed for _ in range(self.env_num)]
|
||||||
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
|
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
|
||||||
for r in result_obj:
|
for r in result_obj:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user