env test \ ray
This commit is contained in:
parent
7533e5b0ac
commit
04557fdb82
4
.github/workflows/pytest.yml
vendored
4
.github/workflows/pytest.yml
vendored
@ -33,5 +33,5 @@ jobs:
|
||||
./flake8.sh
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pip install pytest
|
||||
pytest
|
||||
pip install pytest pytest-cov
|
||||
pytest --cov tianshou
|
||||
|
@ -1 +1,3 @@
|
||||
# Tianshou
|
||||
|
||||

|
||||
|
@ -27,9 +27,8 @@ class MyTestEnv(gym.Env):
|
||||
finished = self.index == self.size
|
||||
return self.index, int(finished), finished, {}
|
||||
|
||||
def test_framestack():
|
||||
k = 4
|
||||
size = 10
|
||||
|
||||
def test_framestack(k=4, size=10):
|
||||
env = MyTestEnv(size=size)
|
||||
fsenv = FrameStack(env, k)
|
||||
fsenv.seed()
|
||||
@ -53,5 +52,36 @@ def test_framestack():
|
||||
assert (rew, done) == (0, True)
|
||||
fsenv.close()
|
||||
|
||||
|
||||
def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
|
||||
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
|
||||
venv = [
|
||||
VectorEnv(env_fns, reset_after_done=True),
|
||||
SubprocVectorEnv(env_fns, reset_after_done=True),
|
||||
]
|
||||
if verbose:
|
||||
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
|
||||
for v in venv:
|
||||
v.seed()
|
||||
action_list = [1] * 5 + [0] * 10 + [1] * 9
|
||||
if not verbose:
|
||||
o = [v.reset() for v in venv]
|
||||
for i, a in enumerate(action_list):
|
||||
o = [v.step([a] * num) for v in venv]
|
||||
for i in zip(*o):
|
||||
for j in range(1, len(i)):
|
||||
assert (i[0] == i[j]).all()
|
||||
else:
|
||||
t = [0, 0, 0]
|
||||
for i, e in enumerate(venv):
|
||||
t[i] = time.time()
|
||||
e.reset()
|
||||
for a in action_list:
|
||||
e.step([a] * num)
|
||||
t[i] = time.time() - t[i]
|
||||
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_framestack()
|
||||
test_framestack()
|
||||
test_vecenv(True)
|
||||
|
78
tianshou/env/wrapper.py
vendored
78
tianshou/env/wrapper.py
vendored
@ -17,7 +17,7 @@ class EnvWrapper(object):
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
return self.env.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
if hasattr(self.env, 'seed'):
|
||||
@ -61,6 +61,7 @@ class VectorEnv(object):
|
||||
def __init__(self, env_fns, **kwargs):
|
||||
super().__init__()
|
||||
self.envs = [_() for _ in env_fns]
|
||||
self.env_num = len(self.envs)
|
||||
self._reset_after_done = kwargs.get('reset_after_done', False)
|
||||
|
||||
def __len__(self):
|
||||
@ -70,8 +71,9 @@ class VectorEnv(object):
|
||||
return np.stack([e.reset() for e in self.envs])
|
||||
|
||||
def step(self, action):
|
||||
result = zip(*[e.step(action[i]) for i, e in enumerate(self.envs)])
|
||||
obs, rew, done, info = zip(*result)
|
||||
assert len(action) == self.env_num
|
||||
result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
|
||||
obs, rew, done, info = result
|
||||
if self._reset_after_done and sum(done):
|
||||
for i, e in enumerate(self.envs):
|
||||
if done[i]:
|
||||
@ -79,9 +81,11 @@ class VectorEnv(object):
|
||||
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
||||
|
||||
def seed(self, seed=None):
|
||||
for e in self.envs:
|
||||
if np.isscalar(seed) or seed is None:
|
||||
seed = [seed for _ in range(self.env_num)]
|
||||
for e, s in zip(self.envs, seed):
|
||||
if hasattr(e, 'seed'):
|
||||
e.seed(seed)
|
||||
e.seed(s)
|
||||
|
||||
def render(self):
|
||||
for e in self.envs:
|
||||
@ -93,16 +97,40 @@ class VectorEnv(object):
|
||||
e.close()
|
||||
|
||||
|
||||
def worker(parent, p, env_fn_wrapper, kwargs):
|
||||
reset_after_done = kwargs.get('reset_after_done', True)
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
while True:
|
||||
cmd, data = p.recv()
|
||||
if cmd == 'step':
|
||||
obs, rew, done, info = env.step(data)
|
||||
if reset_after_done and done:
|
||||
# s_ is useless when episode finishes
|
||||
obs = env.reset()
|
||||
p.send([obs, rew, done, info])
|
||||
elif cmd == 'reset':
|
||||
p.send(env.reset())
|
||||
elif cmd == 'close':
|
||||
p.close()
|
||||
break
|
||||
elif cmd == 'render':
|
||||
p.send(env.render())
|
||||
elif cmd == 'seed':
|
||||
p.send(env.seed(data))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SubprocVectorEnv(object):
|
||||
"""docstring for SubProcVectorEnv"""
|
||||
def __init__(self, env_fns, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
self.env_num = len(env_fns)
|
||||
self.closed = False
|
||||
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
|
||||
self.processes = [
|
||||
Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
|
||||
Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
|
||||
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
|
||||
]
|
||||
for p in self.processes:
|
||||
@ -113,31 +141,8 @@ class SubprocVectorEnv(object):
|
||||
def __len__(self):
|
||||
return self.env_num
|
||||
|
||||
def worker(self, parent, p, env_fn_wrapper, **kwargs):
|
||||
reset_after_done = kwargs.get('reset_after_done', True)
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
while True:
|
||||
cmd, data = p.recv()
|
||||
if cmd == 'step':
|
||||
obs, rew, done, info = env.step(data)
|
||||
if reset_after_done and done:
|
||||
# s_ is useless when episode finishes
|
||||
obs = env.reset()
|
||||
p.send([obs, rew, done, info])
|
||||
elif cmd == 'reset':
|
||||
p.send(env.reset())
|
||||
elif cmd == 'close':
|
||||
p.close()
|
||||
break
|
||||
elif cmd == 'render':
|
||||
p.send(env.render())
|
||||
elif cmd == 'seed':
|
||||
p.send(env.seed(data))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self, action):
|
||||
assert len(action) == self.env_num
|
||||
for p, a in zip(self.parent_remote, action):
|
||||
p.send(['step', a])
|
||||
result = [p.recv() for p in self.parent_remote]
|
||||
@ -149,8 +154,8 @@ class SubprocVectorEnv(object):
|
||||
p.send(['reset', None])
|
||||
return np.stack([p.recv() for p in self.parent_remote])
|
||||
|
||||
def seed(self, seed):
|
||||
if np.isscalar(seed):
|
||||
def seed(self, seed=None):
|
||||
if np.isscalar(seed) or seed is None:
|
||||
seed = [seed for _ in range(self.env_num)]
|
||||
for p, s in zip(self.parent_remote, seed):
|
||||
p.send(['seed', s])
|
||||
@ -190,7 +195,8 @@ class RayVectorEnv(object):
|
||||
return self.env_num
|
||||
|
||||
def step(self, action):
|
||||
result_obj = [e.step.remote(action[i]) for i, e in enumerate(self.envs)]
|
||||
assert len(action) == self.env_num
|
||||
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
|
||||
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
|
||||
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
||||
|
||||
@ -198,8 +204,8 @@ class RayVectorEnv(object):
|
||||
result_obj = [e.reset.remote() for e in self.envs]
|
||||
return np.stack([ray.get(r) for r in result_obj])
|
||||
|
||||
def seed(self, seed):
|
||||
if np.isscalar(seed):
|
||||
def seed(self, seed=None):
|
||||
if np.isscalar(seed) or seed is None:
|
||||
seed = [seed for _ in range(self.env_num)]
|
||||
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
|
||||
for r in result_obj:
|
||||
|
Loading…
x
Reference in New Issue
Block a user