fix a bug
This commit is contained in:
parent
6632e47b9d
commit
4a1a7dd670
@ -1,18 +1,24 @@
|
|||||||
from tianshou.data import ReplayBuffer
|
from tianshou.data import ReplayBuffer
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from test_env import MyTestEnv
|
||||||
|
else:
|
||||||
from test.test_env import MyTestEnv
|
from test.test_env import MyTestEnv
|
||||||
|
|
||||||
|
|
||||||
def test_replaybuffer(bufsize=20):
|
def test_replaybuffer(size=10, bufsize=20):
|
||||||
env = MyTestEnv(10)
|
env = MyTestEnv(size)
|
||||||
buf = ReplayBuffer(bufsize)
|
buf = ReplayBuffer(bufsize)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
action_list = [1] * 5 + [0] * 10 + [1] * 9
|
action_list = [1] * 5 + [0] * 10 + [1] * 15
|
||||||
for i, a in enumerate(action_list):
|
for i, a in enumerate(action_list):
|
||||||
obs_next, rew, done, info = env.step(a)
|
obs_next, rew, done, info = env.step(a)
|
||||||
buf.add(obs, a, rew, done, obs_next, info)
|
buf.add(obs, a, rew, done, obs_next, info)
|
||||||
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
||||||
indice = buf.sample_indice(4)
|
indice = buf.sample_indice(4)
|
||||||
data = buf.sample(4)
|
data = buf.sample(4)
|
||||||
|
assert (indice < len(buf)).all()
|
||||||
|
assert (data.obs < size).all()
|
||||||
|
assert (0 <= data.done).all() and (data.done <= 1).all()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -59,11 +59,11 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
|
|||||||
VectorEnv(env_fns, reset_after_done=True),
|
VectorEnv(env_fns, reset_after_done=True),
|
||||||
SubprocVectorEnv(env_fns, reset_after_done=True),
|
SubprocVectorEnv(env_fns, reset_after_done=True),
|
||||||
]
|
]
|
||||||
if verbose:
|
if __name__ == '__main__':
|
||||||
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
|
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
|
||||||
for v in venv:
|
for v in venv:
|
||||||
v.seed()
|
v.seed()
|
||||||
action_list = [1] * 5 + [0] * 10 + [1] * 9
|
action_list = [1] * 5 + [0] * 10 + [1] * 15
|
||||||
if not verbose:
|
if not verbose:
|
||||||
o = [v.reset() for v in venv]
|
o = [v.reset() for v in venv]
|
||||||
for i, a in enumerate(action_list):
|
for i, a in enumerate(action_list):
|
||||||
|
14
tianshou/env/wrapper.py
vendored
14
tianshou/env/wrapper.py
vendored
@ -75,9 +75,9 @@ class VectorEnv(object):
|
|||||||
result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
|
result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
|
||||||
obs, rew, done, info = result
|
obs, rew, done, info = result
|
||||||
if self._reset_after_done and sum(done):
|
if self._reset_after_done and sum(done):
|
||||||
for i, e in enumerate(self.envs):
|
obs = np.stack(obs)
|
||||||
if done[i]:
|
for i in np.where(done)[0]:
|
||||||
e.reset()
|
obs[i] = self.envs[i].reset()
|
||||||
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
@ -198,6 +198,14 @@ class RayVectorEnv(object):
|
|||||||
assert len(action) == self.env_num
|
assert len(action) == self.env_num
|
||||||
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
|
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
|
||||||
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
|
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
|
||||||
|
if self._reset_after_done and sum(done):
|
||||||
|
obs = np.stack(obs)
|
||||||
|
index = np.where(done)[0]
|
||||||
|
result_obj = []
|
||||||
|
for i in range(len(index)):
|
||||||
|
result_obj.append(self.envs[index[i]].reset.remote())
|
||||||
|
for i in range(len(index)):
|
||||||
|
obs[index[i]] = ray.get(result_obj[i])
|
||||||
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user