diff --git a/.gitignore b/.gitignore index 762cc89..b1ec891 100644 --- a/.gitignore +++ b/.gitignore @@ -107,7 +107,6 @@ celerybeat.pid # Environments .env .venv -env/ venv/ ENV/ env.bak/ diff --git a/setup.py b/setup.py index 077023f..a82aef7 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,16 @@ setup( keywords='reinforcement learning platform', # You can just specify the packages manually here if your project is # simple. Or you can use find_packages(). - packages=find_packages(exclude=['__pycache__']), - install_requires=['numpy', 'torch', 'tensorboard', 'tqdm'], + packages=find_packages(exclude=['tests', 'tests.*', + 'examples', 'examples.*', + 'docs', 'docs.*']), + install_requires=[ + 'numpy', + 'torch', + 'tensorboard', + 'tqdm', + # 'ray', + 'gym', + 'cloudpickle' + ], ) \ No newline at end of file diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 6e11e27..7922766 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,4 +1,4 @@ -import os +from tianshou import data, env, utils -name = 'tianshou' __version__ = '0.2.0' +__all__ = ['data', 'env', 'utils'] diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py new file mode 100644 index 0000000..711feb8 --- /dev/null +++ b/tianshou/data/__init__.py @@ -0,0 +1,4 @@ +from tianshou.data.batch import Batch +from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer + +__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer'] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py new file mode 100644 index 0000000..937afa6 --- /dev/null +++ b/tianshou/data/batch.py @@ -0,0 +1,6 @@ +class Batch(object): + """Suggested keys: [obs, act, rew, done, obs_next, info]""" + + def __init__(self, **kwargs): + super().__init__() + self.__dict__.update(kwargs)  # expose every kwarg as an attribute (batch.obs, batch.act, ...) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py new file mode 100644 index 0000000..cafde94 --- /dev/null +++ b/tianshou/data/buffer.py @@ -0,0 +1,65 @@ +import numpy as np +from 
tianshou.data.batch import Batch + + +class ReplayBuffer(object): + """Flat FIFO replay buffer: lazily allocates one numpy array per stored key and overwrites the oldest transition once full.""" + def __init__(self, size): + super().__init__() + self._maxsize = size + self._index = self._size = 0 + + def __len__(self): + return self._size + + def _add_to_buffer(self, name, inst): + if inst is None: + return + if self.__dict__.get(name, None) is None: + if isinstance(inst, np.ndarray): + self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) + elif isinstance(inst, dict): + self.__dict__[name] = np.array([{} for _ in range(self._maxsize)]) + else: # assume `inst` is a number + self.__dict__[name] = np.zeros([self._maxsize]) + self.__dict__[name][self._index] = inst + + def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None): + ''' + weight: importance weights, disabled here + ''' + assert isinstance(info, dict), 'You should return a dict in the last argument of env.step function.' + self._add_to_buffer('obs', obs) + self._add_to_buffer('act', act) + self._add_to_buffer('rew', rew) + self._add_to_buffer('done', done) + self._add_to_buffer('obs_next', obs_next) + self._add_to_buffer('info', info) + self._size = min(self._size + 1, self._maxsize) + self._index = (self._index + 1) % self._maxsize + + def reset(self): + self._index = self._size = 0 + + def sample_indice(self, batch_size): + return np.random.choice(self._size, batch_size)  # NOTE(review): raises if the buffer is empty — confirm callers only sample after add + + def sample(self, batch_size): + indice = self.sample_indice(batch_size)  # bugfix: was self.sample_index, a method that does not exist (AttributeError) + return Batch(obs=self.obs[indice], act=self.act[indice], rew=self.rew[indice], + done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice]) + + +class PrioritizedReplayBuffer(ReplayBuffer): + """Placeholder for prioritized experience replay; all overrides are unimplemented.""" + def __init__(self, size): + super().__init__(size) + + def add(self, obs, act, rew, done, obs_next, info={}, weight=None): + raise NotImplementedError + + def sample_indice(self, batch_size): + raise NotImplementedError + + def sample(self, batch_size): + raise NotImplementedError 
diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py new file mode 100644 index 0000000..e9ff886 --- /dev/null +++ b/tianshou/env/__init__.py @@ -0,0 +1,3 @@ +from tianshou.env.wrapper import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv + +__all__ = ['FrameStack', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv'] diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py new file mode 100644 index 0000000..f068a0e --- /dev/null +++ b/tianshou/env/wrapper.py @@ -0,0 +1,208 @@ +import numpy as np +from collections import deque +from multiprocessing import Process, Pipe + +from tianshou.utils import CloudpickleWrapper + + +class EnvWrapper(object): + """Thin pass-through wrapper around a single gym-style env.""" + def __init__(self, env): + self.env = env + + def step(self, action): + return self.env.step(action) + + def reset(self): + return self.env.reset()  # bugfix: must return the initial observation — RayVectorEnv.reset() stacks this return value + + def seed(self, seed=None): + if hasattr(self.env, 'seed'): + return self.env.seed(seed)  # return the result so remote (ray) callers can retrieve it + + def render(self): + if hasattr(self.env, 'render'): + return self.env.render()  # return the result so remote (ray) callers can retrieve it + + def close(self): + self.env.close() + + +class FrameStack(EnvWrapper): + def __init__(self, env, stack_num): + """Stack last k frames.""" + super().__init__(env) + self.stack_num = stack_num + self._frames = deque([], maxlen=stack_num) + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self._frames.append(obs) + return self._get_obs(), reward, done, info + + def reset(self): + obs = self.env.reset() + for _ in range(self.stack_num): + self._frames.append(obs) + return self._get_obs() + + def _get_obs(self): + return np.concatenate(self._frames, axis=-1)  # stack along the last (channel) axis + + +class VectorEnv(object): + """In-process vectorized env: steps each wrapped env sequentially.""" + def __init__(self, env_fns, **kwargs): + super().__init__() + self.envs = [_() for _ in env_fns] + self._reset_after_done = kwargs.get('reset_after_done', False) + + def __len__(self): + return len(self.envs) + + def reset(self): + return np.stack([e.reset() for e in self.envs]) + + def step(self, action): + result = zip(*[e.step(action[i]) for i, e in 
enumerate(self.envs)]) + obs, rew, done, info = result  # bugfix: `result` already holds the four transposed groups; the old `zip(*result)` transposed a second time and scrambled the fields + if self._reset_after_done and sum(done): + for i, e in enumerate(self.envs): + if done[i]: + e.reset() + return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) + + def seed(self, seed=None): + for e in self.envs: + if hasattr(e, 'seed'): + e.seed(seed) + + def render(self): + for e in self.envs: + if hasattr(e, 'render'): + e.render() + + def close(self): + for e in self.envs: + e.close() + + +class SubprocVectorEnv(object): + """Vectorized env running each env in its own subprocess, driven over pipes.""" + def __init__(self, env_fns, **kwargs): + super().__init__() + self.env_num = len(env_fns) + self.closed = False + self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)]) + self.processes = [Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True) + for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)]  # bugfix: bare `worker` is not resolvable from inside a method; reference it through the instance + for p in self.processes: + p.start() + for c in self.child_remote: + c.close() + + def __len__(self): + return self.env_num + + @staticmethod + def worker(parent, p, env_fn_wrapper, kwargs):  # bugfix: kwargs arrives as ONE positional dict; the old `**kwargs` signature made Process raise TypeError on start + reset_after_done = kwargs.get('reset_after_done', False)  # bugfix: default False for consistency with VectorEnv and RayVectorEnv + parent.close() + env = env_fn_wrapper.data() + while True: + cmd, data = p.recv() + if cmd == 'step':  # bugfix: compare strings with ==, not `is` (identity only worked via CPython interning; SyntaxWarning on 3.8+) + obs, rew, done, info = env.step(data) + if reset_after_done and done: + # s_ is useless when episode finishes + obs = env.reset() + p.send([obs, rew, done, info]) + elif cmd == 'reset': + p.send(env.reset()) + elif cmd == 'close': + p.close() + break + elif cmd == 'render': + p.send(env.render()) + elif cmd == 'seed': + p.send(env.seed(data)) + else: + raise NotImplementedError + + def step(self, action): + for p, a in zip(self.parent_remote, action): + p.send(['step', a]) + result = [p.recv() for p in self.parent_remote] + obs, rew, done, info = zip(*result) + return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) + + def reset(self): + for p in self.parent_remote: + p.send(['reset', 
None]) + return np.stack([p.recv() for p in self.parent_remote]) + + def seed(self, seed): + if np.isscalar(seed): + seed = [seed for _ in range(self.env_num)] + for p, s in zip(self.parent_remote, seed): + p.send(['seed', s]) + for p in self.parent_remote: + p.recv() + + def render(self): + for p in self.parent_remote: + p.send(['render', None]) + for p in self.parent_remote: + p.recv() + + def close(self): + if self.closed: + return + for p in self.parent_remote: + p.send(['close', None]) + self.closed = True + for p in self.processes: + p.join() + + + +class RayVectorEnv(object): + """Vectorized env backed by one ray remote actor (an EnvWrapper) per env.""" + def __init__(self, env_fns, **kwargs): + super().__init__() + self.env_num = len(env_fns) + self._reset_after_done = kwargs.get('reset_after_done', False)  # NOTE(review): stored but never acted on — EnvWrapper does not auto-reset; confirm intended behavior + try: + import ray + except ImportError: + raise ImportError('Please install ray to support RayVectorEnv: pip3 install ray -U')  # bugfix: message named the wrong class (VectorEnv needs no ray) + if not ray.is_initialized(): + ray.init() + self.envs = [ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) for e in env_fns] + + def __len__(self): + return self.env_num + + def step(self, action): + import ray  # bugfix: the `import ray` in __init__ binds a LOCAL name only; without this, `ray` here is an unresolved global (NameError) + result_obj = [e.step.remote(action[i]) for i, e in enumerate(self.envs)] + obs, rew, done, info = zip(*[ray.get(r) for r in result_obj]) + return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) + + def reset(self): + import ray  # see step(): module is only imported locally in __init__ + result_obj = [e.reset.remote() for e in self.envs] + return np.stack([ray.get(r) for r in result_obj]) + + def seed(self, seed): + import ray  # see step() + if np.isscalar(seed): + seed = [seed for _ in range(self.env_num)] + result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)] + for r in result_obj: + ray.get(r) + + def render(self): + import ray  # see step() + result_obj = [e.render.remote() for e in self.envs] + for r in result_obj: + ray.get(r) + + def close(self): + import ray  # see step() + result_obj = [e.close.remote() for e in self.envs] + for r in result_obj: + ray.get(r) diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py new file mode 100644 index 0000000..128ba80 --- /dev/null +++ 
b/tianshou/utils/__init__.py @@ -0,0 +1,3 @@ +from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper + +__all__ = ['CloudpickleWrapper'] diff --git a/tianshou/utils/cloudpicklewrapper.py b/tianshou/utils/cloudpicklewrapper.py new file mode 100644 index 0000000..c2221ca --- /dev/null +++ b/tianshou/utils/cloudpicklewrapper.py @@ -0,0 +1,10 @@ +import cloudpickle + + +class CloudpickleWrapper(object):  # makes an arbitrary object (e.g. an env-creating closure) survive multiprocessing pickling + def __init__(self, data): + self.data = data  # wrapped payload + def __getstate__(self): + return cloudpickle.dumps(self.data)  # cloudpickle serializes lambdas/closures the stdlib pickler rejects + def __setstate__(self, data): + self.data = cloudpickle.loads(data)  # restore the payload on the receiving side