diff --git a/docs/index.rst b/docs/index.rst
index 628f698..4cde13c 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -49,7 +49,8 @@ If no error occurs, you have successfully installed Tianshou.
 
    tutorials/dqn
    tutorials/concepts
-   tutorials/trick.rst
+   tutorials/trick
+   tutorials/tabular
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index 5567aae..90ec59a 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -3,7 +3,7 @@ Basic concepts in Tianshou
 
 Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:
 
-.. image:: ../_static/images/concepts_arch.png
+.. image:: /_static/images/concepts_arch.png
    :align: center
    :height: 300
 
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index c509d74..96e5f06 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -211,8 +211,9 @@ No problem! Tianshou supports user-defined training code. Here is the usage:
         # train policy with a sampled batch data
         losses = policy.learn(train_collector.sample(batch_size=64))
 
+For further usage, you can refer to :doc:`/tutorials/tabular`.
 
 .. rubric:: References
 
-.. bibliography:: ../refs.bib
+.. bibliography:: /refs.bib
    :style: unsrtalpha
diff --git a/docs/tutorials/tabular.rst b/docs/tutorials/tabular.rst
new file mode 100644
index 0000000..fba4b2e
--- /dev/null
+++ b/docs/tutorials/tabular.rst
@@ -0,0 +1,11 @@
+Tabular Q Learning Implementation
+=================================
+
+This tutorial shows how to use Tianshou to develop new algorithms.
+
+
+Background
+----------
+
+TODO
+
diff --git a/docs/tutorials/trick.rst b/docs/tutorials/trick.rst
index 8c616e9..6cfd4b6 100644
--- a/docs/tutorials/trick.rst
+++ b/docs/tutorials/trick.rst
@@ -80,5 +80,5 @@ With fast-speed sampling, we could use large batch-size and large learning rate
 
 RL algorithms are seed-sensitive. Try more seeds and pick the best. But for our demo, we just used seed = 0 and found it work surprisingly well on policy gradient, so we did not try other seed.
 
-.. image:: ../_static/images/testpg.gif
+.. image:: /_static/images/testpg.gif
    :align: center
diff --git a/test/base/test_env.py b/test/base/test_env.py
index ae49c58..68f55f3 100644
--- a/test/base/test_env.py
+++ b/test/base/test_env.py
@@ -1,7 +1,6 @@
 import time
-import pytest
 import numpy as np
-from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -9,32 +8,6 @@ else:  # pytest
     from test.base.env import MyTestEnv
 
 
-def test_framestack(k=4, size=10):
-    env = MyTestEnv(size=size)
-    fsenv = FrameStack(env, k)
-    fsenv.seed()
-    obs = fsenv.reset()
-    assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0
-    for i in range(5):
-        obs, rew, done, info = fsenv.step(1)
-    assert abs(obs - np.array([2, 3, 4, 5])).sum() == 0
-    for i in range(10):
-        obs, rew, done, info = fsenv.step(0)
-    assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0
-    for i in range(9):
-        obs, rew, done, info = fsenv.step(1)
-    assert abs(obs - np.array([6, 7, 8, 9])).sum() == 0
-    assert (rew, done) == (0, False)
-    obs, rew, done, info = fsenv.step(1)
-    assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
-    assert (rew, done) == (1, True)
-    with pytest.raises(ValueError):
-        obs, rew, done, info = fsenv.step(0)
-    # assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
-    # assert (rew, done) == (0, True)
-    fsenv.close()
-
-
 def test_vecenv(size=10, num=8, sleep=0.001):
     verbose = __name__ == '__main__'
     env_fns = [
@@ -86,5 +59,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):
 
 
 if __name__ == '__main__':
-    test_framestack()
     test_vecenv()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 370ec16..e6d7b90 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -67,9 +67,7 @@ class Batch(object):
         self.__dict__.update(kwargs)
 
     def __getitem__(self, index):
-        """
-        Return self[index].
-        """
+        """Return self[index]."""
         b = Batch()
         for k in self.__dict__.keys():
             if self.__dict__[k] is not None:
@@ -77,9 +75,7 @@ class Batch(object):
         return b
 
     def append(self, batch):
-        """
-        Append a :class:`~tianshou.data.Batch` object to the end.
-        """
+        """Append a :class:`~tianshou.data.Batch` object to the end."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
         for k in batch.__dict__.keys():
             if batch.__dict__[k] is None:
@@ -101,9 +97,7 @@ class Batch(object):
             raise TypeError(s)
 
     def __len__(self):
-        """
-        Return len(self).
-        """
+        """Return len(self)."""
         return min([
             len(self.__dict__[k]) for k in self.__dict__.keys()
             if self.__dict__[k] is not None])
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 16ffaa1..cf96808 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -43,14 +43,8 @@ class ReplayBuffer(object):
         self._maxsize = size
         self.reset()
 
-    def __del__(self):
-        for k in list(self.__dict__.keys()):
-            del self.__dict__[k]
-
     def __len__(self):
-        """
-        Return len(self).
-        """
+        """Return len(self)."""
         return self._size
 
     def _add_to_buffer(self, name, inst):
@@ -70,9 +64,7 @@ class ReplayBuffer(object):
         self.__dict__[name][self._index] = inst
 
     def update(self, buffer):
-        """
-        Move the data from the given buffer to self.
-        """
+        """Move the data from the given buffer to self."""
         i = begin = buffer._index % len(buffer)
         while True:
             self.add(
@@ -83,9 +75,7 @@ class ReplayBuffer(object):
             break
 
     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
-        '''
-        Add a batch of data into replay buffer.
-        '''
+        """Add a batch of data into replay buffer."""
         assert isinstance(info, dict), \
             'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
@@ -101,9 +91,7 @@ class ReplayBuffer(object):
         self._size = self._index = self._index + 1
 
     def reset(self):
-        """
-        Clear all the data in replay buffer.
-        """
+        """Clear all the data in replay buffer."""
         self._index = self._size = 0
         self.indice = []
 
@@ -123,9 +111,7 @@ class ReplayBuffer(object):
         return self[indice], indice
 
     def __getitem__(self, index):
-        """
-        Return a data batch: self[index].
-        """
+        """Return a data batch: self[index]."""
         return Batch(
             obs=self.obs[index],
             act=self.act[index],
diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py
index 81f2014..98a09a4 100644
--- a/tianshou/env/__init__.py
+++ b/tianshou/env/__init__.py
@@ -1,14 +1,9 @@
-from tianshou.env.utils import CloudpickleWrapper
-from tianshou.env.common import EnvWrapper, FrameStack
 from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
     SubprocVectorEnv, RayVectorEnv
 
 __all__ = [
-    'EnvWrapper',
-    'FrameStack',
     'BaseVectorEnv',
     'VectorEnv',
     'SubprocVectorEnv',
     'RayVectorEnv',
-    'CloudpickleWrapper',
 ]
diff --git a/tianshou/env/common.py b/tianshou/env/common.py
deleted file mode 100644
index 1a6f2f6..0000000
--- a/tianshou/env/common.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import numpy as np
-from collections import deque
-
-
-class EnvWrapper(object):
-    def __init__(self, env):
-        self.env = env
-
-    def step(self, action):
-        return self.env.step(action)
-
-    def reset(self):
-        return self.env.reset()
-
-    def seed(self, seed=None):
-        if hasattr(self.env, 'seed'):
-            return self.env.seed(seed)
-
-    def render(self, **kwargs):
-        if hasattr(self.env, 'render'):
-            return self.env.render(**kwargs)
-
-    def close(self):
-        self.env.close()
-
-
-class FrameStack(EnvWrapper):
-    def __init__(self, env, stack_num):
-        """Stack last k frames."""
-        super().__init__(env)
-        self.stack_num = stack_num
-        self._frames = deque([], maxlen=stack_num)
-
-    def step(self, action):
-        obs, reward, done, info = self.env.step(action)
-        self._frames.append(obs)
-        return self._get_obs(), reward, done, info
-
-    def reset(self):
-        obs = self.env.reset()
-        for _ in range(self.stack_num):
-            self._frames.append(obs)
-        return self._get_obs()
-
-    def _get_obs(self):
-        try:
-            return np.concatenate(self._frames, axis=-1)
-        except ValueError:
-            return np.stack(self._frames, axis=-1)
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py
index 2bf3cd5..41b9ede 100644
--- a/tianshou/env/utils.py
+++ b/tianshou/env/utils.py
@@ -2,6 +2,8 @@ import cloudpickle
 
 
 class CloudpickleWrapper(object):
+    """A cloudpickle wrapper used in :class:`~tianshou.env.SubprocVectorEnv`."""
+
     def __init__(self, data):
         self.data = data
 
diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py
index b5e25f9..4dbeeb2 100644
--- a/tianshou/env/vecenv.py
+++ b/tianshou/env/vecenv.py
@@ -1,3 +1,4 @@
+import gym
 import numpy as np
 from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
@@ -7,40 +8,98 @@ try:
 except ImportError:
     pass
 
-from tianshou.env import EnvWrapper, CloudpickleWrapper
+from tianshou.env.utils import CloudpickleWrapper
 
 
-class BaseVectorEnv(ABC):
+class BaseVectorEnv(ABC, gym.Wrapper):
+    """
+    Base class for vectorized environment wrappers. Usage:
+    ::
+
+        env_num = 8
+        envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)])
+
+    It accepts a list of environment generators. In other words, an environment
+    generator ``efn`` of a specific task means that ``efn()`` returns the
+    environment of the given task, for example, ``gym.make(task)``.
+
+    All vectorized environment classes must inherit
+    :class:`~tianshou.env.BaseVectorEnv`. Here are some other usages:
+    ::
+
+        envs.seed(2)  # which is equal to the next line
+        envs.seed([2, 3, 4, 5, 6, 7, 8, 9])  # set specific seed for each env
+        obs = envs.reset()  # reset all environments
+        obs = envs.reset([0, 5, 7])  # reset 3 specific environments
+        obs, rew, done, info = envs.step([1] * 8)  # step synchronously
+        envs.render()  # render all environments
+        envs.close()  # close all environments
+    """
+
     def __init__(self, env_fns):
         self._env_fns = env_fns
         self.env_num = len(env_fns)
 
     def __len__(self):
+        """Return len(self), which is the number of environments."""
         return self.env_num
 
     @abstractmethod
-    def reset(self):
+    def reset(self, id=None):
+        """
+        Reset the state of all the environments and return the initial
+        observations if id is ``None``; otherwise reset the specific
+        environments with the given id, either an int or a list.
+        """
         pass
 
     @abstractmethod
     def step(self, action):
+        """
+        Run one timestep of all the environments' dynamics. When the end of
+        episode is reached, you are responsible for calling reset(id) to
+        reset this environment's state.
+
+        Accept a batch of actions and return a tuple (obs, rew, done, info).
+
+        :args:
+            action (numpy.ndarray): a batch of actions provided by the agent
+
+        :return:
+            * obs (numpy.ndarray): agent's observation of current environments
+            * rew (numpy.ndarray): amount of rewards returned after previous \
+                actions
+            * done (numpy.ndarray): whether these episodes have ended, in \
+                which case further step() calls will return undefined results
+            * info (numpy.ndarray): contains auxiliary diagnostic information \
+                (helpful for debugging, and sometimes learning)
+        """
         pass
 
     @abstractmethod
     def seed(self, seed=None):
+        """
+        Set the seed for all environments. Accept ``None``, an int ``i``
+        (which will be extended to ``[i, i + 1, i + 2, ...]``), or a list.
+        """
        pass
 
     @abstractmethod
     def render(self, **kwargs):
+        """Render all of the environments."""
         pass
 
     @abstractmethod
     def close(self):
+        """Close all of the environments."""
         pass
 
 
 class VectorEnv(BaseVectorEnv):
-    """docstring for VectorEnv"""
+    """
+    Dummy vectorized environment wrapper, implemented with a for-loop. See \
+    :class:`~tianshou.env.BaseVectorEnv` for usage.
+    """
 
     def __init__(self, env_fns):
         super().__init__(env_fns)
@@ -85,8 +144,7 @@ class VectorEnv(BaseVectorEnv):
         return result
 
     def close(self):
-        for e in self.envs:
-            e.close()
+        return [e.close() for e in self.envs]
 
 
 def worker(parent, p, env_fn_wrapper):
@@ -100,6 +158,7 @@ def worker(parent, p, env_fn_wrapper):
         elif cmd == 'reset':
             p.send(env.reset())
         elif cmd == 'close':
+            p.send(env.close())
             p.close()
             break
         elif cmd == 'render':
@@ -114,7 +173,10 @@ def worker(parent, p, env_fn_wrapper):
 
 
 class SubprocVectorEnv(BaseVectorEnv):
-    """docstring for SubProcVectorEnv"""
+    """
+    Vectorized environment wrapper based on subprocesses. See \
+    :class:`~tianshou.env.BaseVectorEnv` for usage.
+    """
 
     def __init__(self, env_fns):
         super().__init__(env_fns)
+ """ def __init__(self, env_fns): super().__init__(env_fns) @@ -178,13 +240,20 @@ class SubprocVectorEnv(BaseVectorEnv): return for p in self.parent_remote: p.send(['close', None]) + result = [p.recv() for p in self.parent_remote] self.closed = True for p in self.processes: p.join() + return result class RayVectorEnv(BaseVectorEnv): - """docstring for RayVectorEnv""" + """ + Vectorized environment wrapper based on \ + `ray `_. However, according to our \ + test, it is slower than :class:`~tianshou.env.SubprocVectorEnv`. The usage\ + is in :class:`~tianshou.env.BaseVectorEnv`. + """ def __init__(self, env_fns): super().__init__(env_fns) @@ -195,7 +264,7 @@ class RayVectorEnv(BaseVectorEnv): raise ImportError( 'Please install ray to support RayVectorEnv: pip3 install ray') self.envs = [ - ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) + ray.remote(gym.Wrapper).options(num_cpus=0).remote(e()) for e in env_fns] def step(self, action): diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 00100f1..26af788 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -3,9 +3,7 @@ import numpy as np def test_episode(policy, collector, test_fn, epoch, n_episode): - """ - A simple wrapper of testing policy in collector. - """ + """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() policy.eval() diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index 1cbb197..acbc176 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -43,23 +43,17 @@ class MovAvg(object): return self.get() def get(self): - """ - Get the average. - """ + """Get the average.""" if len(self.cache) == 0: return 0 return np.mean(self.cache) def mean(self): - """ - Get the average. Same as :meth:`get`. - """ + """Get the average. Same as :meth:`get`.""" return self.get() def std(self): - """ - Get the standard deviation. - """ + """Get the standard deviation.""" if len(self.cache) == 0: return 0 return np.std(self.cache)