This commit is contained in:
Trinkle23897 2020-06-08 22:20:52 +08:00
parent 560116d0b2
commit 513573ea82
4 changed files with 18 additions and 14 deletions

View File

@ -33,9 +33,10 @@
Here is Tianshou's other features: Here is Tianshou's other features:
- Elegant framework, using only ~2000 lines of code - Elegant framework, using only ~2000 lines of code
- Support parallel environment sampling for all algorithms - Support parallel environment sampling for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state (e.g. a dict, a self-defined class, ...) - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation for all Q-learning based algorithms - Support n-step returns estimation for all Q-learning based algorithms
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.

View File

@ -23,9 +23,10 @@ Welcome to Tianshou!
Here is Tianshou's other features: Here is Tianshou's other features:
* Elegant framework, using only ~2000 lines of code * Elegant framework, using only ~2000 lines of code
* Support parallel environment sampling for all algorithms * Support parallel environment sampling for all algorithms: :ref:`parallel_sampling`
* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) * Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
* Support any type of environment state (e.g. a dict, a self-defined class, ...) * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
* Support customized training process: :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
中文文档位于 https://tianshou.readthedocs.io/zh/latest/ 中文文档位于 https://tianshou.readthedocs.io/zh/latest/

View File

@ -19,13 +19,15 @@ Build New Policy
See :class:`~tianshou.policy.BasePolicy`. See :class:`~tianshou.policy.BasePolicy`.
.. _parallel_sampling: .. _customize_training:
Customize Training Process Customize Training Process
-------------------------- --------------------------
See :ref:`customized_trainer`. See :ref:`customized_trainer`.
.. _parallel_sampling:
Parallel Sampling Parallel Sampling
----------------- -----------------

View File

@ -97,15 +97,15 @@ class BaseVectorEnv(ABC, gym.Env):
pass pass
@abstractmethod @abstractmethod
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
"""Set the seed for all environments. """Set the seed for all environments.
Accept ``None``, an int (which will extend ``i`` to Accept ``None``, an int (which will extend ``i`` to
``[i, i + 1, i + 2, ...]``) or a list. ``[i, i + 1, i + 2, ...]``) or a list.
:return: The list of seeds used in this env's random number generators. :return: The list of seeds used in this env's random number \
The first value in the list should be the "main" seed, or the value generators. The first value in the list should be the "main" seed, or \
which a reproducer should pass to 'seed'. the value which a reproducer pass to "seed".
""" """
pass pass
@ -162,7 +162,7 @@ class VectorEnv(BaseVectorEnv):
self._info = np.stack(self._info) self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info return self._obs, self._rew, self._done, self._info
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if np.isscalar(seed): if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)] seed = [seed + _ for _ in range(self.env_num)]
elif seed is None: elif seed is None:
@ -269,7 +269,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._obs[i] = self.parent_remote[i].recv() self._obs[i] = self.parent_remote[i].recv()
return self._obs return self._obs
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if np.isscalar(seed): if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)] seed = [seed + _ for _ in range(self.env_num)]
elif seed is None: elif seed is None:
@ -347,7 +347,7 @@ class RayVectorEnv(BaseVectorEnv):
self._obs[i] = ray.get(result_obj[_]) self._obs[i] = ray.get(result_obj[_])
return self._obs return self._obs
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if not hasattr(self.envs[0], 'seed'): if not hasattr(self.envs[0], 'seed'):
return return
if np.isscalar(seed): if np.isscalar(seed):