diff --git a/test/test_batch.py b/test/test_batch.py
index 205e1be..c3c62f5 100644
--- a/test/test_batch.py
+++ b/test/test_batch.py
@@ -1,3 +1,4 @@
+import pytest
 import numpy as np
 
 from tianshou.data import Batch
@@ -10,6 +11,9 @@ def test_batch():
     batch.append(batch)
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
+    assert batch[0].obs == batch[1].obs
+    with pytest.raises(IndexError):
+        batch[2]
 
 
 if __name__ == '__main__':
diff --git a/test/test_buffer.py b/test/test_buffer.py
index 41ab3e4..cd6d59a 100644
--- a/test/test_buffer.py
+++ b/test/test_buffer.py
@@ -8,16 +8,23 @@ else:  # pytest
 def test_replaybuffer(size=10, bufsize=20):
     env = MyTestEnv(size)
     buf = ReplayBuffer(bufsize)
+    buf2 = ReplayBuffer(bufsize)
     obs = env.reset()
-    action_list = [1] * 5 + [0] * 10 + [1] * 15
+    action_list = [1] * 5 + [0] * 10 + [1] * 10
     for i, a in enumerate(action_list):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
+        obs = obs_next
         assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
     data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()
+    assert len(buf) > len(buf2)
+    buf2.update(buf)
+    assert len(buf) == len(buf2)
+    assert buf2[0].obs == buf[5].obs
+    assert buf2[-1].obs == buf[4].obs
 
 
 if __name__ == '__main__':
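The two trailing asserts pin down the copy order of `buf2.update(buf)`: after 25 adds into a size-20 buffer, `buf._index` is 5, so the copy must start at the oldest slot (`buf[5]`) rather than slot 0. A minimal standalone sketch of that wrap-around traversal, using a plain Python list in place of the real `ReplayBuffer` (the actual implementation lands in the `tianshou/data/buffer.py` hunk further down; `oldest_first` is an illustrative helper, not part of the patch):

```python
def oldest_first(storage, next_index):
    """Yield a full ring buffer's slots in insertion (oldest-first) order.

    `storage` stands in for ReplayBuffer's flat storage; `next_index` is
    the slot that would be written next, i.e. the oldest entry once full.
    """
    i = begin = next_index % len(storage)
    while True:
        yield storage[i]
        i = (i + 1) % len(storage)
        if i == begin:
            break


# 25 transitions into a size-20 buffer: slots 0..4 were overwritten by
# steps 20..24, so the oldest surviving entry sits in slot 5.
storage = list(range(20, 25)) + list(range(5, 20))
assert list(oldest_first(storage, 25))[0] == 5    # matches buf2[0] == buf[5]
assert list(oldest_first(storage, 25))[-1] == 24  # matches buf2[-1] == buf[4]
```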
diff --git a/test/test_dqn.py b/test/test_dqn.py
index 59bda18..98d6404 100644
--- a/test/test_dqn.py
+++ b/test/test_dqn.py
@@ -1,4 +1,5 @@
 import gym
+import time
 import tqdm
 import torch
 import argparse
@@ -36,7 +37,7 @@ class Net(nn.Module):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -65,6 +66,7 @@ def test_dqn(args=get_args()):
     train_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)],
         reset_after_done=True)
+    # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.test_num)],
         reset_after_done=False)
@@ -124,11 +126,26 @@ def test_dqn(args=get_args()):
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
+    if __name__ == '__main__':
+        # let's watch its performance!
+        env = gym.make(args.task)
+        obs = env.reset()
+        done = False
+        total = 0
+        while not done:
+            q, _ = net([obs])
+            action = q.max(dim=1)[1]
+            obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
+            total += rew
+            env.render()
+            time.sleep(1 / 100)
+        env.close()
+        print(f'Final test: {total}')
     return best_reward
 
 
diff --git a/test/test_env.py b/test/test_env.py
index 087ec34..9f38606 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -1,5 +1,6 @@
 import gym
 import time
+import pytest
 import numpy as np
 
 from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
@@ -8,24 +9,28 @@ class MyTestEnv(gym.Env):
     def __init__(self, size, sleep=0):
         self.size = size
         self.sleep = sleep
-        self.index = 0
+        self.reset()
 
     def reset(self):
+        self.done = False
         self.index = 0
         return self.index
 
     def step(self, action):
+        if self.done:
+            raise ValueError('step after done !!!')
         if self.sleep > 0:
             time.sleep(self.sleep)
         if self.index == self.size:
+            self.done = True
            return self.index, 0, True, {}
         if action == 0:
             self.index = max(self.index - 1, 0)
             return self.index, 0, False, {}
         elif action == 1:
             self.index += 1
-            finished = self.index == self.size
-            return self.index, int(finished), finished, {}
+            self.done = self.index == self.size
+            return self.index, int(self.done), self.done, {}
 
 
 def test_framestack(k=4, size=10):
@@ -47,19 +52,21 @@ def test_framestack(k=4, size=10):
     obs, rew, done, info = fsenv.step(1)
     assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
     assert (rew, done) == (1, True)
-    obs, rew, done, info = fsenv.step(0)
-    assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
-    assert (rew, done) == (0, True)
+    with pytest.raises(ValueError):
+        obs, rew, done, info = fsenv.step(0)
+    # assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
+    # assert (rew, done) == (0, True)
     fsenv.close()
 
 
-def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
+def test_vecenv(size=10, num=8, sleep=0.001):
+    verbose = __name__ == '__main__'
     env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
     venv = [
         VectorEnv(env_fns, reset_after_done=True),
         SubprocVectorEnv(env_fns, reset_after_done=True),
     ]
-    if __name__ == '__main__':
+    if verbose:
         venv.append(RayVectorEnv(env_fns, reset_after_done=True))
     for v in venv:
         v.seed()
@@ -86,6 +93,40 @@
         v.close()
 
 
+def test_vecenv2():
+    verbose = __name__ == '__main__'
+    env_fns = [
+        lambda: MyTestEnv(size=1),
+        lambda: MyTestEnv(size=2),
+        lambda: MyTestEnv(size=3),
+        lambda: MyTestEnv(size=4),
+    ]
+    num = len(env_fns)
+    venv = [
+        VectorEnv(env_fns, reset_after_done=False),
+        SubprocVectorEnv(env_fns, reset_after_done=False),
+    ]
+    if verbose:
+        venv.append(RayVectorEnv(env_fns, reset_after_done=False))
+    for v in venv:
+        v.seed()
+    o = [v.reset() for v in venv]
+    action_list = [1] * 6
+    for i, a in enumerate(action_list):
+        o = [v.step([a] * num) for v in venv]
+        if verbose:
+            print(o[0])
+            print(o[1])
+            print(o[2])
+            print('---')
+        for i in zip(*o):
+            for j in range(1, len(i)):
+                assert (i[0] == i[j]).all()
+    for v in venv:
+        v.close()
+
+
 if __name__ == '__main__':
     test_framestack()
-    test_vecenv(True)
+    test_vecenv()
+    test_vecenv2()
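`test_vecenv2` drives four differently sized `MyTestEnv` instances in lockstep with `reset_after_done=False` and checks that every vector-env backend produces identical output. A rough pure-Python sketch of the expected trace, assuming the frozen-transition semantics the `tianshou/env/wrapper.py` hunks below implement (a finished sub-env keeps returning its last transition); `frozen_trace` is an illustrative helper, not part of the patch:

```python
def frozen_trace(sizes, n_steps):
    # One slot per sub-env; `last` caches each env's most recent transition.
    obs = [0] * len(sizes)
    last = [None] * len(sizes)
    for _ in range(n_steps):
        step = []
        for i, size in enumerate(sizes):
            if last[i] is not None and last[i][2]:  # done: repeat last result
                step.append(last[i])
            else:  # MyTestEnv with action 1: index += 1, done at `size`
                obs[i] += 1
                done = obs[i] == size
                step.append((obs[i], int(done), done, {}))
            last[i] = step[i]
        yield step


for step in frozen_trace([1, 2, 3, 4], 6):
    print(step)  # once an env finishes, its tuple never changes again
```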
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 07fdbe4..ecd192a 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -13,7 +13,7 @@ class Batch(object):
         b = Batch()
         for k in self.__dict__.keys():
             if self.__dict__[k] is not None:
-                b.update(k=self.__dict__[k][index])
+                b.update(**{k: self.__dict__[k][index]})
         return b
 
     def update(self, **kwargs):
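The one-line batch.py fix above is a classic keyword-argument pitfall: `update(k=v)` always creates the literal key `'k'`, whereas `update(**{k: v})` keys on the variable's value. A minimal demonstration with a toy `update` function (not the real `Batch` API):

```python
def update(**kwargs):
    return kwargs


k = 'obs'
assert update(k=1) == {'k': 1}           # bug: key is the literal name 'k'
assert update(**{k: 1}) == {'obs': 1}    # fix: key taken from the variable
```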
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 1f2a48f..27a315b 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -27,10 +27,14 @@ class ReplayBuffer(object):
             self.__dict__[name][self._index] = inst
 
     def update(self, buffer):
-        for i in range(len(buffer)):
+        i = begin = buffer._index % len(buffer)
+        while True:
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i],
                 buffer.done[i], buffer.obs_next[i], buffer.info[i])
+            i = (i + 1) % len(buffer)
+            if i == begin:
+                break
 
     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         '''
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index d09da86..b19a05f 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -1,3 +1,4 @@
+import time
 import torch
 import numpy as np
 from copy import deepcopy
@@ -70,13 +71,13 @@ class Collector(object):
         if hasattr(self.env, 'close'):
             self.env.close()
 
-    def _make_batch(data):
+    def _make_batch(self, data):
         if isinstance(data, np.ndarray):
             return data[None]
         else:
             return [data]
 
-    def collect(self, n_step=0, n_episode=0):
+    def collect(self, n_step=0, n_episode=0, render=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
@@ -98,7 +99,10 @@ class Collector(object):
             self.state = result.state if hasattr(result, 'state') else None
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
-                self._act)
+                self._act if self._multi_env else self._act[0])
+            if render > 0:
+                self.env.render()
+                time.sleep(render)
             self.length += 1
             self.reward += self._rew
             if self._multi_env:
@@ -147,6 +151,7 @@ class Collector(object):
                     self.stat_length.add(self.length)
                     self.reward, self.length = 0, 0
                     self.state = None
+                    self._obs = self.env.reset()
             if n_episode > 0 and cur_episode >= n_episode:
                 break
             if n_step > 0 and cur_step >= n_step:
diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py
index 00aee18..f456283 100644
--- a/tianshou/env/wrapper.py
+++ b/tianshou/env/wrapper.py
@@ -62,6 +62,7 @@ class BaseVectorEnv(ABC):
         self._env_fns = env_fns
         self.env_num = len(env_fns)
         self._reset_after_done = reset_after_done
+        self._done = np.zeros(self.env_num)
 
     def is_reset_after_done(self):
         return self._reset_after_done
@@ -98,17 +99,26 @@ class VectorEnv(BaseVectorEnv):
         self.envs = [_() for _ in env_fns]
 
     def reset(self):
-        return np.stack([e.reset() for e in self.envs])
+        self._done = np.zeros(self.env_num)
+        self._obs = np.stack([e.reset() for e in self.envs])
+        return self._obs
 
     def step(self, action):
         assert len(action) == self.env_num
-        result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
-        obs, rew, done, info = result
-        if self._reset_after_done and sum(done):
-            obs = np.stack(obs)
-            for i in np.where(done)[0]:
-                obs[i] = self.envs[i].reset()
-        return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+        result = []
+        for i, e in enumerate(self.envs):
+            if not self.is_reset_after_done() and self._done[i]:
+                result.append([
+                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
+            else:
+                result.append(e.step(action[i]))
+        self._obs, self._rew, self._done, self._info = zip(*result)
+        if self.is_reset_after_done() and sum(self._done):
+            self._obs = np.stack(self._obs)
+            for i in np.where(self._done)[0]:
+                self._obs[i] = self.envs[i].reset()
+        return np.stack(self._obs), np.stack(self._rew),\
+            np.stack(self._done), np.stack(self._info)
 
     def seed(self, seed=None):
         if np.isscalar(seed) or seed is None:
@@ -130,15 +140,18 @@
 def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
+    done = False
     while True:
         cmd, data = p.recv()
         if cmd == 'step':
-            obs, rew, done, info = env.step(data)
+            if reset_after_done or not done:
+                obs, rew, done, info = env.step(data)
             if reset_after_done and done:
                 # s_ is useless when episode finishes
                 obs = env.reset()
             p.send([obs, rew, done, info])
         elif cmd == 'reset':
+            done = False
             p.send(env.reset())
         elif cmd == 'close':
             p.close()
@@ -225,21 +238,36 @@ class RayVectorEnv(BaseVectorEnv):
 
     def step(self, action):
         assert len(action) == self.env_num
-        result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
-        obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
-        if self._reset_after_done and sum(done):
-            obs = np.stack(obs)
-            index = np.where(done)[0]
+        result_obj = []
+        for i, e in enumerate(self.envs):
+            if not self.is_reset_after_done() and self._done[i]:
+                result_obj.append(None)
+            else:
+                result_obj.append(e.step.remote(action[i]))
+        result = []
+        for i, r in enumerate(result_obj):
+            if r is None:
+                result.append([
+                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
+            else:
+                result.append(ray.get(r))
+        self._obs, self._rew, self._done, self._info = zip(*result)
+        if self.is_reset_after_done() and sum(self._done):
+            self._obs = np.stack(self._obs)
+            index = np.where(self._done)[0]
             result_obj = []
             for i in range(len(index)):
                 result_obj.append(self.envs[index[i]].reset.remote())
             for i in range(len(index)):
-                obs[index[i]] = ray.get(result_obj[i])
-        return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+                self._obs[index[i]] = ray.get(result_obj[i])
+        return np.stack(self._obs), np.stack(self._rew),\
+            np.stack(self._done), np.stack(self._info)
 
     def reset(self):
+        self._done = np.zeros(self.env_num)
         result_obj = [e.reset.remote() for e in self.envs]
-        return np.stack([ray.get(r) for r in result_obj])
+        self._obs = np.stack([ray.get(r) for r in result_obj])
+        return self._obs
 
     def seed(self, seed=None):
         if not hasattr(self.envs[0], 'seed'):
diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py
index b3f41e3..76c8ddc 100644
--- a/tianshou/policy/dqn.py
+++ b/tianshou/policy/dqn.py
@@ -35,8 +35,10 @@ class DQNPolicy(BasePolicy, nn.Module):
         q, h = model(obs, hidden_state=hidden_state, info=batch.info)
         act = q.max(dim=1)[1].detach().cpu().numpy()
         # add eps to act
+        if eps is None:
+            eps = self.eps
         for i in range(len(q)):
-            if np.random.rand() < self.eps:
+            if np.random.rand() < eps:
                 act[i] = np.random.randint(q.shape[1])
         return Batch(Q=q, act=act, state=h)
 
@@ -66,7 +68,7 @@ class DQNPolicy(BasePolicy, nn.Module):
         terminal = (indice + self._n_step - 1) % len(buffer)
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(buffer[terminal], input='obs_next').act
+            a = self(buffer[terminal], input='obs_next', eps=0).act
             target_q = self(
                 buffer[terminal], model='model_old', input='obs_next').Q
             if isinstance(target_q, torch.Tensor):
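The dqn.py change lets callers override exploration per forward pass: `self.eps` remains the default, but passing `eps=0` forces the greedy argmax when computing the double-DQN target. A small numpy-only sketch of that pattern (`select_actions` is illustrative, not the real `DQNPolicy.__call__`):

```python
import numpy as np


def select_actions(q_values, default_eps, eps=None):
    # Fall back to the policy-wide epsilon unless the caller overrides it.
    if eps is None:
        eps = default_eps
    act = q_values.argmax(axis=1)
    for i in range(len(act)):
        if np.random.rand() < eps:  # epsilon-greedy exploration
            act[i] = np.random.randint(q_values.shape[1])
    return act


q = np.array([[0.1, 0.9], [0.8, 0.2]])
# eps=0 disables exploration, e.g. for target-Q action selection.
assert (select_actions(q, default_eps=0.1, eps=0) == [1, 0]).all()
```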