diff --git a/README.md b/README.md
index c4cefb0..67c86fc 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,4 @@
 # Tianshou
 
 ![Python package](https://github.com/Trinkle23897/tianshou/workflows/Python%20package/badge.svg)
+
diff --git a/setup.py b/setup.py
index 87f41ce..96665cc 100644
--- a/setup.py
+++ b/setup.py
@@ -20,14 +20,12 @@ setup(
         #   4 - Beta
         #   5 - Production/Stable
         'Development Status :: 3 - Alpha',
-        # Indicate who your project is intended for
         'Intended Audience :: Science/Research',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development :: Libraries :: Python Modules',
 
         # Pick your license as you wish (should match "license" above)
         'License :: OSI Approved :: MIT License',
 
-        # Specify the Python versions you support here. In particular, ensure
         # that you indicate whether you support Python 2, Python 3 or both.
         'Programming Language :: Python :: 3.6',
@@ -35,18 +33,16 @@ setup(
         'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
-    # You can just specify the packages manually here if your project is
-    # simple. Or you can use find_packages().
    packages=find_packages(exclude=['tests', 'tests.*',
                                    'examples', 'examples.*',
                                    'docs', 'docs.*']),
    install_requires=[
-        'numpy',
-        'torch',
-        'tensorboard',
-        'tqdm',
        # 'ray',
        'gym',
-        'cloudpickle'
+        'tqdm',
+        'numpy',
+        'torch',
+        'cloudpickle',
+        'tensorboard',
    ],
 )
diff --git a/test/test_buffer.py b/test/test_buffer.py
index ae0267e..749dc35 100644
--- a/test/test_buffer.py
+++ b/test/test_buffer.py
@@ -1,7 +1,7 @@
 from tianshou.data import ReplayBuffer
 
 if __name__ == '__main__':
     from test_env import MyTestEnv
-else:
+else:  # pytest
     from test.test_env import MyTestEnv
 
diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py
index 711feb8..b3153fe 100644
--- a/tianshou/data/__init__.py
+++ b/tianshou/data/__init__.py
@@ -1,4 +1,5 @@
 from tianshou.data.batch import Batch
 from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.collector import Collector
 
-__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer']
+__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector']
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index ee16eb0..60cdcb3 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -2,4 +2,8 @@ class Batch(object):
     """Suggested keys: [obs, act, rew, done, obs_next, info]"""
     def __init__(self, **kwargs):
         super().__init__()
+        self.obs_next = None
+        self.__dict__.update(kwargs)
+
+    def update(self, **kwargs):
         self.__dict__.update(kwargs)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 20e66d3..dfaffc8 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -40,12 +40,20 @@ class ReplayBuffer(object):
 
     def reset(self):
         self._index = self._size = 0
+        self.indice = []
 
     def sample_indice(self, batch_size):
-        return np.random.choice(self._size, batch_size)
+        if batch_size > 0:
+            self.indice = np.random.choice(self._size, batch_size)
+        else:
+            self.indice = np.arange(self._size)
+        return self.indice
 
-    def sample(self, batch_size):
-        indice = self.sample_indice(batch_size)
+    def sample(self, batch_size, indice=None):
+        if indice is None:
+            indice = self.sample_indice(batch_size)
+        else:
+            self.indice = indice
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],
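A minimal sketch of how the revised buffer API can be exercised, assuming the buffer is constructed with its maximum size and that `add()` takes the six positional arguments the new collector passes (obs, act, rew, done, obs_next, info):

```python
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(20)                        # assumed: constructor takes the max size
for i in range(5):
    buf.add(i, 0, 1.0, False, i + 1, {})      # (obs, act, rew, done, obs_next, info)

batch = buf.sample(4)                         # random indices, remembered in buf.indice
everything = buf.sample(0)                    # batch_size <= 0 now returns all stored steps
chosen = buf.sample(0, indice=np.array([0, 2]))  # explicit indices bypass sampling
print(batch.obs, everything.act, chosen.obs)
```

Keeping the drawn indices around (and accepting them back through `indice`) is presumably what lets a `process_fn(batch, buffer, index)` revisit exactly the transitions that were just sampled.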
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
new file mode 100644
index 0000000..1f43a5e
--- /dev/null
+++ b/tianshou/data/collector.py
@@ -0,0 +1,86 @@
+import numpy as np
+from copy import deepcopy
+
+from tianshou.env import BaseVectorEnv
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.utils import MovAvg
+
+class Collector(object):
+    """docstring for Collector"""
+    def __init__(self, policy, env, buffer):
+        super().__init__()
+        self.env = env
+        self.env_num = 1
+        self.buffer = buffer
+        self.policy = policy
+        self.process_fn = policy.process_fn
+        self.multi_env = isinstance(env, BaseVectorEnv)
+        if self.multi_env:
+            self.env_num = len(env)
+            if isinstance(self.buffer, list):
+                assert len(self.buffer) == self.env_num, 'The data buffer number does not match the input env number.'
+            elif isinstance(self.buffer, ReplayBuffer):
+                self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
+            else:
+                raise TypeError('The buffer in data collector is invalid!')
+        self.reset_env()
+        self.clear_buffer()
+        # state over batch is either a list, an np.ndarray, or a torch.Tensor (hasattr 'shape')
+        self.state = None
+
+    def clear_buffer(self):
+        if self.multi_env:
+            for b in self.buffer:
+                b.reset()
+        else:
+            self.buffer.reset()
+
+    def reset_env(self):
+        self._obs = self.env.reset()
+        self._act = self._rew = self._done = self._info = None
+
+    def collect(self, n_step=0, n_episode=0, tqdm_hook=None):
+        assert sum([(n_step > 0), (n_episode > 0)]) == 1, "One and only one collection number specification permitted!"
+        cur_step = 0
+        cur_episode = np.zeros(self.env_num) if self.multi_env else 0
+        while True:
+            if self.multi_env:
+                batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, info=self._info)
+            else:
+                batch_data = Batch(obs=[self._obs], act=[self._act], rew=[self._rew], done=[self._done], info=[self._info])
+            result = self.policy.act(batch_data, self.state)
+            self.state = result.state
+            self._act = result.act
+            obs_next, self._rew, self._done, self._info = self.env.step(self._act)
+            cur_step += 1
+            if self.multi_env:
+                for i in range(self.env_num):
+                    if (n_episode > 0 and cur_episode[i] < n_episode) or n_episode == 0:
+                        self.buffer[i].add(self._obs[i], self._act[i], self._rew[i], self._done[i], obs_next[i], self._info[i])
+                        if self._done[i]:
+                            cur_episode[i] += 1
+                            if isinstance(self.state, list):
+                                self.state[i] = None
+                            elif self.state is not None:
+                                self.state[i] = self.state[i] * 0
+                            if hasattr(self.state, 'detach'):  # detach the hidden state from the torch graph
+                                self.state = self.state.detach()
+                if n_episode > 0 and (cur_episode >= n_episode).all():
+                    break
+            else:
+                self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info)
+                if self._done:
+                    cur_episode += 1
+                    self.state = None
+                if n_episode > 0 and cur_episode >= n_episode:
+                    break
+            if n_step > 0 and cur_step >= n_step:
+                break
+            self._obs = obs_next
+        self._obs = obs_next
+
+    def sample(self):
+        pass
+
+    def stat(self):
+        pass
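An end-to-end sketch of driving the new `Collector`, using the vector env and policy base class that appear further down in this diff. Assumptions: gym's `CartPole-v0` is available, `ReplayBuffer` takes its maximum size, `VectorEnv.step` accepts one action per sub-environment, `reset_after_done=True` restarts finished sub-environments automatically, and `RandomPolicy` is a throwaway name for illustration only.

```python
import gym
import numpy as np

from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Stateless policy: one random CartPole action per sub-environment."""
    def act(self, batch, hidden_state=None):
        return Batch(act=np.random.randint(2, size=len(batch.obs)), state=None)


envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)],
                 reset_after_done=True)
collector = Collector(RandomPolicy(), envs, ReplayBuffer(2000))
collector.collect(n_step=200)              # exactly one of n_step / n_episode
data = collector.buffer[0].sample(0)       # per-env buffer (deepcopied internally)
print(data.obs.shape, data.act[:5])
```

The only contract the collector relies on is that `policy.act` returns something exposing `act` and `state` attributes and that `policy.process_fn` exists; the `BasePolicy` stub below already provides the latter.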
diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py
index e9ff886..eb8faa2 100644
--- a/tianshou/env/__init__.py
+++ b/tianshou/env/__init__.py
@@ -1,3 +1,3 @@
-from tianshou.env.wrapper import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv
 
-__all__ = ['FrameStack', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
+__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py
index 0b03292..7c6e40c 100644
--- a/tianshou/env/wrapper.py
+++ b/tianshou/env/wrapper.py
@@ -1,5 +1,6 @@
 import numpy as np
 from collections import deque
+from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
@@ -56,13 +57,18 @@ class FrameStack(EnvWrapper):
         return np.stack(self._frames, axis=-1)
 
 
-class VectorEnv(object):
+class BaseVectorEnv(ABC):
+    def __init__(self):
+        pass
+
+
+class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""
-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.envs = [_() for _ in env_fns]
         self.env_num = len(self.envs)
-        self._reset_after_done = kwargs.get('reset_after_done', False)
+        self._reset_after_done = reset_after_done
 
     def __len__(self):
         return len(self.envs)
@@ -97,8 +103,7 @@ class VectorEnv(object):
             e.close()
 
 
-def worker(parent, p, env_fn_wrapper, kwargs):
-    reset_after_done = kwargs.get('reset_after_done', True)
+def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     while True:
@@ -115,22 +120,22 @@ def worker(parent, p, env_fn_wrapper, kwargs):
             p.close()
             break
         elif cmd == 'render':
-            p.send(env.render())
+            p.send(env.render() if hasattr(env, 'render') else None)
         elif cmd == 'seed':
-            p.send(env.seed(data))
+            p.send(env.seed(data) if hasattr(env, 'seed') else None)
         else:
             raise NotImplementedError
 
 
-class SubprocVectorEnv(object):
+class SubprocVectorEnv(BaseVectorEnv):
     """docstring for SubProcVectorEnv"""
-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
         self.closed = False
         self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
-            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
+            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
             for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
         ]
         for p in self.processes:
@@ -178,12 +183,12 @@ class SubprocVectorEnv(object):
             p.join()
 
 
-class RayVectorEnv(object):
+class RayVectorEnv(BaseVectorEnv):
     """docstring for RayVectorEnv"""
-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
-        self._reset_after_done = kwargs.get('reset_after_done', False)
+        self._reset_after_done = reset_after_done
         try:
             if not ray.is_initialized():
                 ray.init()
@@ -213,6 +218,8 @@
         return np.stack([ray.get(r) for r in result_obj])
 
     def seed(self, seed=None):
+        if not hasattr(self.envs[0], 'seed'):
+            return
         if np.isscalar(seed) or seed is None:
             seed = [seed for _ in range(self.env_num)]
         result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
@@ -220,6 +227,8 @@
             ray.get(r)
 
     def render(self):
+        if not hasattr(self.envs[0], 'render'):
+            return
         result_obj = [e.render.remote() for e in self.envs]
         for r in result_obj:
             ray.get(r)
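A short sketch of the reworked vector envs, where `reset_after_done` is now an explicit keyword forwarded to the worker processes instead of being pulled out of `**kwargs`. It assumes gym's `CartPole-v0` is installed and that `reset()`/`step()` behave as the collector expects (stacked observations, one action per sub-environment):

```python
import gym
import numpy as np
from tianshou.env import SubprocVectorEnv

envs = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)],
                        reset_after_done=True)   # flag is handed straight to each worker
obs = envs.reset()                               # one row of observations per sub-env
for _ in range(10):
    obs, rew, done, info = envs.step(np.zeros(4, dtype=int))
envs.close()
```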
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
new file mode 100644
index 0000000..ba28e5d
--- /dev/null
+++ b/tianshou/policy/__init__.py
@@ -0,0 +1,3 @@
+from tianshou.policy.base import BasePolicy
+
+__all__ = ['BasePolicy']
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
new file mode 100644
index 0000000..d3d130a
--- /dev/null
+++ b/tianshou/policy/base.py
@@ -0,0 +1,28 @@
+from abc import ABC, abstractmethod
+
+
+class BasePolicy(ABC):
+    """docstring for BasePolicy"""
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def act(self, batch, hidden_state=None):
+        # return {policy, action, hidden}
+        pass
+
+    def train(self):
+        pass
+
+    def eval(self):
+        pass
+
+    def reset(self):
+        pass
+
+    @staticmethod
+    def process_fn(batch, buffer, index):
+        pass
+
+    def exploration(self):
+        pass
diff --git a/tianshou/policy/reward_processor.py b/tianshou/policy/reward_processor.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index 266c2b2..89be458 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -1,4 +1,9 @@
-from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
 from tianshou.utils.config import tqdm_config
+from tianshou.utils.moving_average import MovAvg
+from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
 
-__all__ = ['CloudpickleWrapper', 'tqdm_config']
+__all__ = [
+    'CloudpickleWrapper',
+    'tqdm_config',
+    'MovAvg'
+]
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
new file mode 100644
index 0000000..01d0c91
--- /dev/null
+++ b/tianshou/utils/moving_average.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+class MovAvg(object):
+    def __init__(self, size=100):
+        super().__init__()
+        self.size = size
+        self.cache = []
+
+    def add(self, x):
+        if hasattr(x, 'detach'):
+            # x is a torch.Tensor; pull it back to numpy first
+            x = x.detach().cpu().numpy()
+        if x != np.inf:
+            self.cache.append(x)
+        if self.size > 0 and len(self.cache) > self.size:
+            self.cache = self.cache[-self.size:]
+        return self.get()
+
+    def get(self):
+        if len(self.cache) == 0:
+            return 0
+        return np.mean(self.cache)
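Finally, a quick sketch of the new `MovAvg` helper: it keeps at most `size` recent values, silently drops `np.inf`, and reports 0 until something has been recorded.

```python
from tianshou.utils import MovAvg

stat = MovAvg(size=3)
print(stat.get())        # 0, nothing recorded yet
print(stat.add(2.0))     # 2.0
print(stat.add(4.0))     # 3.0
stat.add(float('inf'))   # ignored: np.inf values are filtered out
print(stat.add(6.0))     # 4.0, the mean of the last three values (2, 4, 6)
```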