diff --git a/test/test_buffer.py b/test/test_buffer.py
index ea6c9b2..ae0267e 100644
--- a/test/test_buffer.py
+++ b/test/test_buffer.py
@@ -1,19 +1,25 @@
 from tianshou.data import ReplayBuffer
-from test.test_env import MyTestEnv
+if __name__ == '__main__':
+    from test_env import MyTestEnv
+else:
+    from test.test_env import MyTestEnv
 
 
-def test_replaybuffer(bufsize=20):
-    env = MyTestEnv(10)
+def test_replaybuffer(size=10, bufsize=20):
+    env = MyTestEnv(size)
     buf = ReplayBuffer(bufsize)
     obs = env.reset()
-    action_list = [1] * 5 + [0] * 10 + [1] * 9
+    action_list = [1] * 5 + [0] * 10 + [1] * 15
     for i, a in enumerate(action_list):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
     indice = buf.sample_indice(4)
     data = buf.sample(4)
+    assert (indice < len(buf)).all()
+    assert (data.obs < size).all()
+    assert (0 <= data.done).all() and (data.done <= 1).all()
 
 
 if __name__ == '__main__':
-    test_replaybuffer()
\ No newline at end of file
+    test_replaybuffer()
diff --git a/test/test_env.py b/test/test_env.py
index 67912e0..76f2996 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -59,11 +59,11 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
         VectorEnv(env_fns, reset_after_done=True),
         SubprocVectorEnv(env_fns, reset_after_done=True),
     ]
-    if verbose:
+    if __name__ == '__main__':
         venv.append(RayVectorEnv(env_fns, reset_after_done=True))
     for v in venv:
         v.seed()
-    action_list = [1] * 5 + [0] * 10 + [1] * 9
+    action_list = [1] * 5 + [0] * 10 + [1] * 15
     if not verbose:
         o = [v.reset() for v in venv]
         for i, a in enumerate(action_list):
diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py
index 852ac4e..0b03292 100644
--- a/tianshou/env/wrapper.py
+++ b/tianshou/env/wrapper.py
@@ -75,9 +75,9 @@ class VectorEnv(object):
         result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
         obs, rew, done, info = result
         if self._reset_after_done and sum(done):
-            for i, e in enumerate(self.envs):
-                if done[i]:
-                    e.reset()
+            obs = np.stack(obs)
+            for i in np.where(done)[0]:
+                obs[i] = self.envs[i].reset()
         return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
 
     def seed(self, seed=None):
@@ -198,6 +198,14 @@ class RayVectorEnv(object):
         assert len(action) == self.env_num
         result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
         obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
+        if self._reset_after_done and sum(done):
+            obs = np.stack(obs)
+            index = np.where(done)[0]
+            result_obj = []
+            for i in range(len(index)):
+                result_obj.append(self.envs[index[i]].reset.remote())
+            for i in range(len(index)):
+                obs[index[i]] = ray.get(result_obj[i])
         return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
 
     def reset(self):