From f16e05c0e77d0c26a0993ca4bd9acac3dac20bb9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 13 Mar 2020 17:49:22 +0800 Subject: [PATCH] maybe finished collector? --- .github/workflows/pytest.yml | 2 +- flake8.sh | 2 +- setup.py | 2 +- test/test_batch.py | 16 ++++++++ test/test_buffer.py | 3 +- test/test_env.py | 4 +- tianshou/__init__.py | 9 +++- tianshou/data/__init__.py | 7 +++- tianshou/data/batch.py | 22 +++++++++- tianshou/data/buffer.py | 23 +++++------ tianshou/data/collector.py | 79 +++++++++++++++++++++++++++++------- tianshou/env/__init__.py | 12 +++++- tianshou/env/wrapper.py | 22 +++++++--- tianshou/policy/__init__.py | 6 ++- tianshou/policy/base.py | 5 ++- 15 files changed, 165 insertions(+), 49 deletions(-) create mode 100644 test/test_batch.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1715738..7cf30a2 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -33,7 +33,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics - name: Test with pytest run: | pip install pytest pytest-cov diff --git a/flake8.sh b/flake8.sh index 9d03862..5f3cff1 100755 --- a/flake8.sh +++ b/flake8.sh @@ -1,3 +1,3 @@ #!/bin/sh flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics diff --git a/setup.py b/setup.py index 96665cc..ae5ce99 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ setup( 'tqdm', 'numpy', 'torch', - 'cloudpickle' + 'cloudpickle', 'tensorboard', ], ) diff --git a/test/test_batch.py b/test/test_batch.py new file mode 100644 index 0000000..205e1be --- /dev/null +++ b/test/test_batch.py @@ -0,0 +1,16 @@ +import numpy as np + +from tianshou.data import Batch + + +def test_batch(): + batch = Batch(obs=[0], np=np.zeros([3, 4])) + batch.update(obs=[1]) + assert batch.obs == [1] + batch.append(batch) + assert batch.obs == [1, 1] + assert batch.np.shape == (6, 4) + + +if __name__ == '__main__': + test_batch() diff --git a/test/test_buffer.py b/test/test_buffer.py index 749dc35..4477482 100644 --- a/test/test_buffer.py +++ b/test/test_buffer.py @@ -14,8 +14,7 @@ def test_replaybuffer(size=10, bufsize=20): obs_next, rew, done, info = env.step(a) buf.add(obs, a, rew, done, obs_next, info) assert len(buf) == min(bufsize, i + 1), print(len(buf), i) - indice = buf.sample_indice(4) - data = buf.sample(4) + data, indice = buf.sample(4) assert (indice < len(buf)).all() assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() diff --git a/test/test_env.py b/test/test_env.py index 76f2996..087ec34 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -79,7 +79,9 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001): for a in action_list: e.step([a] * num) t[i] = time.time() - t[i] - print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s') + print(f'VectorEnv: {t[0]:.6f}s') + print(f'SubprocVectorEnv: {t[1]:.6f}s') + print(f'RayVectorEnv: {t[2]:.6f}s') for v in venv: v.close() diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 7922766..3e97de1 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,4 +1,9 @@ -from tianshou import data, env, utils +from tianshou import data, env, utils, policy __version__ = '0.2.0' -__all__ = ['data', 'env', 'utils'] +__all__ = [ + 'data', + 'env', + 'utils', + 'policy' +] diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index b3153fe..ea536ae 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -2,4 +2,9 @@ from tianshou.data.batch import Batch from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer from tianshou.data.collector import Collector -__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector'] +__all__ = [ + 'Batch', + 'ReplayBuffer', + 'PrioritizedReplayBuffer', + 'Collector' +] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 60cdcb3..d254217 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,9 +1,29 @@ +import numpy as np + + class Batch(object): """Suggested keys: [obs, act, rew, done, obs_next, info]""" + def __init__(self, **kwargs): super().__init__() - self.obs_next = None self.__dict__.update(kwargs) def update(self, **kwargs): self.__dict__.update(kwargs) + + def append(self, batch): + assert isinstance(batch, Batch), 'Only append Batch is allowed!' + for k in batch.__dict__.keys(): + if batch.__dict__[k] is None: + continue + if not hasattr(self, k) or self.__dict__[k] is None: + self.__dict__[k] = batch.__dict__[k] + elif isinstance(batch.__dict__[k], np.ndarray): + self.__dict__[k] = np.concatenate([ + self.__dict__[k], batch.__dict__[k]]) + elif isinstance(batch.__dict__[k], list): + self.__dict__[k] += batch.__dict__[k] + else: + raise TypeError( + 'No support append method with {} in class Batch.' + .format(type(batch.__dict__[k]))) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index dfaffc8..0cf5168 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -4,6 +4,7 @@ from tianshou.data.batch import Batch class ReplayBuffer(object): """docstring for ReplayBuffer""" + def __init__(self, size): super().__init__() self._maxsize = size @@ -19,7 +20,8 @@ class ReplayBuffer(object): if isinstance(inst, np.ndarray): self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) elif isinstance(inst, dict): - self.__dict__[name] = np.array([{} for _ in range(self._maxsize)]) + self.__dict__[name] = np.array( + [{} for _ in range(self._maxsize)]) else: # assume `inst` is a number self.__dict__[name] = np.zeros([self._maxsize]) self.__dict__[name][self._index] = inst @@ -28,7 +30,8 @@ class ReplayBuffer(object): ''' weight: importance weights, disabled here ''' - assert isinstance(info, dict), 'You should return a dict in the last argument of env.step function.' + assert isinstance(info, dict),\ + 'You should return a dict in the last argument of env.step().' self._add_to_buffer('obs', obs) self._add_to_buffer('act', act) self._add_to_buffer('rew', rew) @@ -42,18 +45,11 @@ class ReplayBuffer(object): self._index = self._size = 0 self.indice = [] - def sample_indice(self, batch_size): + def sample(self, batch_size): if batch_size > 0: - self.indice = np.random.choice(self._size, batch_size) + indice = np.random.choice(self._size, batch_size) else: - self.indice = np.arange(self._size) - return self.indice - - def sample(self, batch_size, indice=None): - if indice is None: - indice = self.sample_indice(batch_size) - else: - self.indice = indice + indice = np.arange(self._size) return Batch( obs=self.obs[indice], act=self.act[indice], @@ -61,11 +57,12 @@ class ReplayBuffer(object): done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice] - ) + ), indice class PrioritizedReplayBuffer(ReplayBuffer): """docstring for PrioritizedReplayBuffer""" + def __init__(self, size): super().__init__(size) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1f43a5e..8378d53 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -5,8 +5,10 @@ from tianshou.env import BaseVectorEnv from tianshou.data import Batch, ReplayBuffer from tianshou.utils import MovAvg + class Collector(object): """docstring for Collector""" + def __init__(self, policy, env, buffer): super().__init__() self.env = env @@ -18,15 +20,18 @@ class Collector(object): if self.multi_env: self.env_num = len(env) if isinstance(self.buffer, list): - assert len(self.buffer) == self.env_num, 'The data buffer number does not match the input env number.' + assert len(self.buffer) == self.env_num,\ + 'Data buffer number does not match the input env number.' elif isinstance(self.buffer, ReplayBuffer): self.buffer = [deepcopy(buffer) for _ in range(self.env_num)] else: raise TypeError('The buffer in data collector is invalid!') self.reset_env() self.clear_buffer() - # state over batch is either a list, an np.ndarray, or torch.Tensor (hasattr 'shape') + # state over batch is either a list, an np.ndarray, or torch.Tensor self.state = None + self.stat_reward = MovAvg() + self.stat_length = MovAvg() def clear_buffer(self): if self.multi_env: @@ -38,39 +43,64 @@ class Collector(object): def reset_env(self): self._obs = self.env.reset() self._act = self._rew = self._done = self._info = None + if self.multi_env: + self.reward = np.zeros(self.env_num) + self.length = np.zeros(self.env_num) + else: + self.reward, self.length = 0, 0 - def collect(self, n_step=0, n_episode=0, tqdm_hook=None): - assert sum([(n_step > 0), (n_episode > 0)]) == 1, "One and only one collection number specification permitted!" + def collect(self, n_step=0, n_episode=0): + assert sum([(n_step > 0), (n_episode > 0)]) == 1,\ + "One and only one collection number specification permitted!" cur_step = 0 cur_episode = np.zeros(self.env_num) if self.multi_env else 0 while True: if self.multi_env: - batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, info=self._info) + batch_data = Batch( + obs=self._obs, act=self._act, rew=self._rew, + done=self._done, obs_next=None, info=self._info) else: - batch_data = Batch(obs=[self._obs], act=[self._act], rew=[self._rew], done=[self._done], info=[self_info]) + batch_data = Batch( + obs=[self._obs], act=[self._act], rew=[self._rew], + done=[self._done], obs_next=None, info=[self._info]) result = self.policy.act(batch_data, self.state) self.state = result.state self._act = result.act - obs_next, self._rew, self._done, self._info = self.env.step(self._act) + obs_next, self._rew, self._done, self._info = self.env.step( + self._act) cur_step += 1 + self.length += 1 + self.reward += self._rew if self.multi_env: for i in range(self.env_num): - if n_episode > 0 and cur_episode[i] < n_episode or n_episode == 0: - self.buffer[i].add(self._obs[i], self._act[i], self._rew[i], self._done[i], obs_next[i], self._info[i]) + if n_episode > 0 and \ + cur_episode[i] < n_episode or n_episode == 0: + self.buffer[i].add( + self._obs[i], self._act[i], self._rew[i], + self._done[i], obs_next[i], self._info[i]) if self._done[i]: cur_episode[i] += 1 + self.stat_reward.add(self.reward[i]) + self.stat_length.add(self.length[i]) + self.reward[i], self.length[i] = 0, 0 if isinstance(self.state, list): self.state[i] = None else: self.state[i] = self.state[i] * 0 - if hasattr(self.state, 'detach'): # remove count in torch + if hasattr(self.state, 'detach'): + # remove ref in torch self.state = self.state.detach() if n_episode > 0 and (cur_episode >= n_episode).all(): break else: - self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info) + self.buffer.add( + self._obs, self._act[0], self._rew, + self._done, obs_next, self._info) if self._done: cur_episode += 1 + self.stat_reward.add(self.reward) + self.stat_length.add(self.length) + self.reward, self.length = 0, 0 self.state = None if n_episode > 0 and cur_episode >= n_episode: break @@ -79,8 +109,29 @@ class Collector(object): self._obs = obs_next self._obs = obs_next - def sample(self): - pass + def sample(self, batch_size): + if self.multi_env: + if batch_size > 0: + lens = [len(b) for b in self.buffer] + total = sum(lens) + ib = np.random.choice( + total, batch_size, p=np.array(lens) / total) + else: + ib = np.array([]) + batch_data = Batch() + for i, b in enumerate(self.buffer): + cur_batch = (ib == i).sum() + if batch_size and cur_batch or batch_size <= 0: + batch, indice = b.sample(cur_batch) + batch = self.process_fn(batch, b, indice) + batch_data.append(batch) + else: + batch_data, indice = self.buffer.sample(batch_size) + batch_data = self.process_fn(batch_data, self.buffer, indice) + return batch_data def stat(self): - pass + return { + 'reward': self.stat_reward.get(), + 'length': self.stat_length.get(), + } diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index eb8faa2..940e749 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,3 +1,11 @@ -from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv +from tianshou.env.wrapper import FrameStack,\ + BaseVectorEnv, VectorEnv, SubprocVectorEnv,\ + RayVectorEnv -__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv'] +__all__ = [ + 'FrameStack', + 'BaseVectorEnv', + 'VectorEnv', + 'SubprocVectorEnv', + 'RayVectorEnv' +] diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py index 7c6e40c..5a4fbb5 100644 --- a/tianshou/env/wrapper.py +++ b/tianshou/env/wrapper.py @@ -1,6 +1,6 @@ import numpy as np +from abc import ABC from collections import deque -from abc import ABC, abstractmethod from multiprocessing import Process, Pipe try: import ray @@ -64,6 +64,7 @@ class BaseVectorEnv(ABC): class VectorEnv(BaseVectorEnv): """docstring for VectorEnv""" + def __init__(self, env_fns, reset_after_done=False): super().__init__() self.envs = [_() for _ in env_fns] @@ -129,14 +130,19 @@ def worker(parent, p, env_fn_wrapper, reset_after_done): class SubprocVectorEnv(BaseVectorEnv): """docstring for SubProcVectorEnv""" + def __init__(self, env_fns, reset_after_done=False): super().__init__() self.env_num = len(env_fns) self.closed = False - self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)]) + self.parent_remote, self.child_remote = \ + zip(*[Pipe() for _ in range(self.env_num)]) self.processes = [ - Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True) - for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns) + Process(target=worker, args=( + parent, child, + CloudpickleWrapper(env_fn), reset_after_done), daemon=True) + for (parent, child, env_fn) in zip( + self.parent_remote, self.child_remote, env_fns) ] for p in self.processes: p.start() @@ -185,6 +191,7 @@ class SubprocVectorEnv(BaseVectorEnv): class RayVectorEnv(BaseVectorEnv): """docstring for RayVectorEnv""" + def __init__(self, env_fns, reset_after_done=False): super().__init__() self.env_num = len(env_fns) @@ -193,8 +200,11 @@ class RayVectorEnv(BaseVectorEnv): if not ray.is_initialized(): ray.init() except NameError: - raise ImportError('Please install ray to support VectorEnv: pip3 install ray -U') - self.envs = [ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) for e in env_fns] + raise ImportError( + 'Please install ray to support VectorEnv: pip3 install ray -U') + self.envs = [ + ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) + for e in env_fns] def __len__(self): return self.env_num diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index ba28e5d..d7a3fb5 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,3 +1,5 @@ -from tianshou.policy import BasePolicy +from tianshou.policy.base import BasePolicy -__all__ = ['BasePolicy'] +__all__ = [ + 'BasePolicy' +] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index d3d130a..de2eb9f 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod class BasePolicy(ABC): """docstring for BasePolicy""" + def __init__(self): super().__init__() @@ -21,8 +22,8 @@ class BasePolicy(ABC): pass @staticmethod - def process_fn(batch, buffer, index): - pass + def process_fn(batch, buffer, indice): + return batch def exploration(self): pass