add some docs
This commit is contained in:
parent
6cfa876591
commit
974ade8019
@ -29,11 +29,11 @@ Tianshou supports parallel workers for all algorithms as well. All of these algo
|
||||
|
||||
In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is a reinforcement learning platform. As we know, an RL agent does not learn from humans, so taking "Tianshou" means that there is no teacher to study with, but to learn by interacting with an environment.
|
||||
|
||||
“[天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)”意指上天所授,引申为与生具有的天赋。天授是强化学习平台,不是和人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。
|
||||
“天授”意指上天所授,引申为与生具有的天赋。天授是强化学习平台,而强化学习算法并不是向人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。
|
||||
|
||||
## Installation
|
||||
|
||||
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). You can simply install Tianshou with the following command:
|
||||
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
|
||||
|
||||
```bash
|
||||
pip3 install tianshou
|
||||
|
||||
@ -58,7 +58,7 @@ master_doc = 'index'
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
|
||||
autodoc_default_options = {'special-members': '__call__, __getitem__, __len__'}
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
Basic Concepts in Tianshou
|
||||
Basic concepts in Tianshou
|
||||
==========================
|
||||
|
||||
Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:
|
||||
@ -11,85 +11,19 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par
|
||||
Data Batch
|
||||
----------
|
||||
|
||||
Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a collector gives a :class:`~tianshou.data.Batch` to policy for learning. Here is its usage:
|
||||
::
|
||||
|
||||
>>> import numpy as np
|
||||
>>> from tianshou.data import Batch
|
||||
>>> data = Batch(a=4, b=[5, 5], c='2312312')
|
||||
>>> data.b
|
||||
[5, 5]
|
||||
>>> data.b = np.array([3, 4, 5])
|
||||
>>> len(data.b)
|
||||
3
|
||||
>>> data.b[-1]
|
||||
5
|
||||
|
||||
In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair.
|
||||
|
||||
The current implementation of Tianshou typically use 6 keys in :class:`~tianshou.data.Batch`:
|
||||
|
||||
* ``obs``: the observation of step :math:`t` ;
|
||||
* ``act``: the action of step :math:`t` ;
|
||||
* ``rew``: the reward of step :math:`t` ;
|
||||
* ``done``: the done flag of step :math:`t` ;
|
||||
* ``obs_next``: the observation of step :math:`t+1` ;
|
||||
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function return 4 arguments, and the last one is ``info``);
|
||||
|
||||
:class:`~tianshou.data.Batch` has other methods, including ``__getitem__``, ``append``, and ``split``:
|
||||
::
|
||||
|
||||
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
|
||||
>>> # here we test __getitem__
|
||||
>>> index = [2, 1]
|
||||
>>> data[index].obs
|
||||
array([22, 11])
|
||||
|
||||
>>> data.append(data) # how we use a list
|
||||
>>> data.obs
|
||||
array([0, 11, 22, 0, 11, 22])
|
||||
|
||||
>>> # split whole data into multiple small batch
|
||||
>>> for d in data.split(size=2, permute=False):
|
||||
... print(d.obs, d.rew)
|
||||
[ 0 11] [6 6]
|
||||
[22 0] [6 6]
|
||||
[11 22] [6 6]
|
||||
.. automodule:: tianshou.data.Batch
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
|
||||
Data Buffer
|
||||
-----------
|
||||
|
||||
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. It stores basically 6 types of data as mentioned above (7 types with importance weight in :class:`~tianshou.data.PrioritizedReplayBuffer`). Here is the :class:`~tianshou.data.ReplayBuffer`'s usage:
|
||||
::
|
||||
.. automodule:: tianshou.data.ReplayBuffer
|
||||
:members:
|
||||
:noindex:
|
||||
|
||||
>>> from tianshou.data import ReplayBuffer
|
||||
>>> buf = ReplayBuffer(size=20)
|
||||
>>> for i in range(3):
|
||||
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf.obs
|
||||
# since we set size = 20, len(buf.obs) == 20.
|
||||
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
0., 0., 0.])
|
||||
|
||||
>>> buf2 = ReplayBuffer(size=10)
|
||||
>>> for i in range(15):
|
||||
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf2.obs
|
||||
# since its size = 10, it only stores the last 10 steps' result.
|
||||
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
|
||||
|
||||
>>> # move buf2's result into buf (keep it chronologically meanwhile)
|
||||
>>> buf.update(buf2)
|
||||
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
|
||||
0., 0., 0., 0., 0., 0., 0.])
|
||||
|
||||
>>> # get a random sample from buffer, the batch_data is equal to buf[incide].
|
||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||
>>> batch_data.obs == buf[indice].obs
|
||||
array([ True, True, True, True])
|
||||
|
||||
The :class:`~tianshou.data.ReplayBuffer` is based on ``numpy.ndarray``. Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out the API documentation for more detail.
|
||||
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out the :doc:`/api/tianshou.data` API documentation for more detail.
|
||||
|
||||
|
||||
Policy
|
||||
@ -110,12 +44,12 @@ Take 2-step return DQN as an example. The 2-step return DQN compute each frame's
|
||||
|
||||
G_t = r_t + \gamma r_{t + 1} + \gamma^2 \max_a Q(s_{t + 2}, a)
|
||||
|
||||
Here is the pseudocode showing the training process **without Tianshou framework**:
|
||||
where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is the pseudocode showing the training process **without Tianshou framework**:
|
||||
::
|
||||
|
||||
# pseudocode, cannot work
|
||||
buffer = Buffer(size=10000)
|
||||
s = env.reset()
|
||||
buffer = Buffer(size=10000)
|
||||
agent = DQN()
|
||||
for i in range(int(1e6)):
|
||||
a = agent.compute_action(s)
|
||||
@ -155,8 +89,8 @@ For other method, you can check out the API documentation for more detail. We gi
|
||||
::
|
||||
|
||||
# pseudocode, cannot work # methods in tianshou
|
||||
buffer = Buffer(size=10000)
|
||||
s = env.reset()
|
||||
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
|
||||
agent = DQN() # done in policy.__init__(...)
|
||||
for i in range(int(1e6)): # done in trainer
|
||||
a = agent.compute_action(s) # done in policy.__call__(batch, ...)
|
||||
@ -174,7 +108,7 @@ For other method, you can check out the API documentation for more detail. We gi
|
||||
Collector
|
||||
---------
|
||||
|
||||
The collector enables the policy to interact with different types of environments conveniently.
|
||||
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
|
||||
|
||||
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
|
||||
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
|
||||
|
||||
@ -3,6 +3,8 @@ Train a model-free RL agent within 30s
|
||||
|
||||
This page summarizes some hyper-parameter tuning experience and code-level trick when training a model-free DRL agent.
|
||||
|
||||
You can also contribute to this page with your own tricks :)
|
||||
|
||||
|
||||
Avoid batch-size = 1
|
||||
--------------------
|
||||
@ -44,7 +46,7 @@ Here is an example of showing how inefficient it is:
|
||||
|
||||
The first test uses batch-size 128, and the second test uses batch-size = 1 for 128 times. In our test, the first is 70-80 times faster than the second.
|
||||
|
||||
So how could we avoid the case of batch-size = 1? The answer is synchronize sampling: we create multiple independent environments and sample simultaneously. It is similar to A2C, but other algorithms can also use this method. In our experiments, sampling from more environments benefits not only the sample speed but also the convergence of neural network (we guess it lowers the sample bias).
|
||||
So how could we avoid the case of batch-size = 1? The answer is synchronize sampling: we create multiple independent environments and sample simultaneously. It is similar to A2C, but other algorithms can also use this method. In our experiments, sampling from more environments benefits not only the sample speed but also the converge speed of neural network (we guess it lowers the sample bias).
|
||||
|
||||
By the way, A2C is better than A3C in some cases: A3C needs to act independently and sync the gradient to master, but, in a single node, using A3C to act with batch-size = 1 is quite resource-consuming.
|
||||
|
||||
@ -74,6 +76,9 @@ Jiayi: I write each line of code after quite a lot of time of consideration. Det
|
||||
Finally
|
||||
-------
|
||||
|
||||
With fast-speed sampling, we could use large batch-size and large learning rate
|
||||
With fast-speed sampling, we could use large batch-size and large learning rate for faster convergence.
|
||||
|
||||
RL algorithms are seed-sensitive. Try more seeds and pick the best. But for our demo, we just used seed = 0 and found it work surprisingly well on policy gradient, so we did not try other seed.
|
||||
|
||||
.. image:: ../_static/images/testpg.gif
|
||||
:align: center
|
||||
|
||||
@ -6,7 +6,7 @@ from tianshou.data import Batch
|
||||
|
||||
def test_batch():
|
||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||
batch.update(obs=[1])
|
||||
batch.obs = [1]
|
||||
assert batch.obs == [1]
|
||||
batch.append(batch)
|
||||
assert batch.obs == [1, 1]
|
||||
|
||||
@ -5,10 +5,10 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
from net import Actor, Critic
|
||||
@ -49,11 +49,12 @@ def test_ddpg(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -89,7 +90,7 @@ def test_ddpg(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -6,7 +6,7 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
@ -53,11 +53,12 @@ def _test_ppo(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -97,8 +98,7 @@ def _test_ppo(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
|
||||
task=args.task)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -5,10 +5,10 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
from net import ActorProb, Critic
|
||||
@ -49,11 +49,12 @@ def test_sac(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -94,7 +95,7 @@ def test_sac(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -5,10 +5,10 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
from net import Actor, Critic
|
||||
@ -52,11 +52,12 @@ def test_td3(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -98,7 +99,7 @@ def test_td3(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -6,7 +6,7 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
@ -45,14 +45,16 @@ def get_args():
|
||||
|
||||
|
||||
def test_a2c(args=get_args()):
|
||||
torch.set_num_threads(1) # for poor CPU
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -83,8 +85,7 @@ def test_a2c(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
|
||||
task=args.task)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -5,8 +5,8 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
@ -48,6 +48,7 @@ def test_dqn(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
@ -89,7 +90,7 @@ def test_dqn(args=get_args()):
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, writer=writer, task=args.task)
|
||||
stop_fn=stop_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
|
||||
@ -6,8 +6,8 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Batch, Collector, ReplayBuffer
|
||||
|
||||
@ -25,7 +25,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
|
||||
if not batch.done[i]:
|
||||
returns[i] += last * gamma
|
||||
last = returns[i]
|
||||
batch.update(returns=returns)
|
||||
batch.returns = returns
|
||||
return batch
|
||||
|
||||
|
||||
@ -99,6 +99,7 @@ def test_pg(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
@ -131,8 +132,7 @@ def test_pg(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
|
||||
task=args.task)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -5,8 +5,8 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
@ -46,14 +46,16 @@ def get_args():
|
||||
|
||||
|
||||
def test_ppo(args=get_args()):
|
||||
torch.set_num_threads(1) # for poor CPU
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -88,8 +90,7 @@ def test_ppo(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
|
||||
task=args.task)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
@ -3,23 +3,83 @@ import numpy as np
|
||||
|
||||
|
||||
class Batch(object):
|
||||
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
|
||||
"""
|
||||
Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
||||
structure to pass any kind of data to other methods, for example, a
|
||||
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
|
||||
Here is the usage:
|
||||
::
|
||||
|
||||
>>> import numpy as np
|
||||
>>> from tianshou.data import Batch
|
||||
>>> data = Batch(a=4, b=[5, 5], c='2312312')
|
||||
>>> data.b
|
||||
[5, 5]
|
||||
>>> data.b = np.array([3, 4, 5])
|
||||
>>> len(data.b)
|
||||
3
|
||||
>>> data.b[-1]
|
||||
5
|
||||
|
||||
In short, you can define a :class:`Batch` with any key-value pair. The
|
||||
current implementation of Tianshou typically use 6 keys in
|
||||
:class:`~tianshou.data.Batch`:
|
||||
|
||||
* ``obs``: the observation of step :math:`t` ;
|
||||
* ``act``: the action of step :math:`t` ;
|
||||
* ``rew``: the reward of step :math:`t` ;
|
||||
* ``done``: the done flag of step :math:`t` ;
|
||||
* ``obs_next``: the observation of step :math:`t+1` ;
|
||||
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
|
||||
function return 4 arguments, and the last one is ``info``);
|
||||
|
||||
:class:`~tianshou.data.Batch` has other methods, including
|
||||
:meth:`~tianshou.data.Batch.__getitem__`,
|
||||
:meth:`~tianshou.data.Batch.__len__`,
|
||||
:meth:`~tianshou.data.Batch.append`,
|
||||
and :meth:`~tianshou.data.Batch.split`:
|
||||
::
|
||||
|
||||
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
|
||||
>>> # here we test __getitem__
|
||||
>>> index = [2, 1]
|
||||
>>> data[index].obs
|
||||
array([22, 11])
|
||||
|
||||
>>> # here we test __len__
|
||||
>>> len(data)
|
||||
3
|
||||
|
||||
>>> data.append(data) # similar to list.append
|
||||
>>> data.obs
|
||||
array([0, 11, 22, 0, 11, 22])
|
||||
|
||||
>>> # split whole data into multiple small batch
|
||||
>>> for d in data.split(size=2, permute=False):
|
||||
... print(d.obs, d.rew)
|
||||
[ 0 11] [6 6]
|
||||
[22 0] [6 6]
|
||||
[11 22] [6 6]
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Return self[index].
|
||||
"""
|
||||
b = Batch()
|
||||
for k in self.__dict__.keys():
|
||||
if self.__dict__[k] is not None:
|
||||
b.update(**{k: self.__dict__[k][index]})
|
||||
b.__dict__.update(**{k: self.__dict__[k][index]})
|
||||
return b
|
||||
|
||||
def update(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def append(self, batch):
|
||||
"""
|
||||
Append a :class:`~tianshou.data.Batch` object to the end.
|
||||
"""
|
||||
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||
for k in batch.__dict__.keys():
|
||||
if batch.__dict__[k] is None:
|
||||
@ -40,10 +100,24 @@ class Batch(object):
|
||||
+ 'in class Batch.'
|
||||
raise TypeError(s)
|
||||
|
||||
def split(self, size=None, permute=True):
|
||||
length = min([
|
||||
def __len__(self):
|
||||
"""
|
||||
Return len(self).
|
||||
"""
|
||||
return min([
|
||||
len(self.__dict__[k]) for k in self.__dict__.keys()
|
||||
if self.__dict__[k] is not None])
|
||||
|
||||
def split(self, size=None, permute=True):
|
||||
"""
|
||||
Split whole data into multiple small batch.
|
||||
|
||||
:param size: if equals to ``None``, it does not split the data batch; \
|
||||
otherwise it will divide the data batch with the given size.
|
||||
:param permute: randomly shuffle the entire data batch if it equals to\
|
||||
``True``, otherwise remain in the same.
|
||||
"""
|
||||
length = len(self)
|
||||
if size is None:
|
||||
size = length
|
||||
temp = 0
|
||||
|
||||
@ -3,7 +3,40 @@ from tianshou.data.batch import Batch
|
||||
|
||||
|
||||
class ReplayBuffer(object):
|
||||
"""docstring for ReplayBuffer"""
|
||||
"""
|
||||
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction
|
||||
between the policy and environment. It stores basically 6 types of data, as
|
||||
mentioned in :class:`~tianshou.data.Batch`, based on ``numpy.ndarray``.
|
||||
Here is the usage:
|
||||
::
|
||||
|
||||
>>> from tianshou.data import ReplayBuffer
|
||||
>>> buf = ReplayBuffer(size=20)
|
||||
>>> for i in range(3):
|
||||
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf.obs
|
||||
# since we set size = 20, len(buf.obs) == 20.
|
||||
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
0., 0., 0., 0.])
|
||||
|
||||
>>> buf2 = ReplayBuffer(size=10)
|
||||
>>> for i in range(15):
|
||||
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf2.obs
|
||||
# since its size = 10, it only stores the last 10 steps' result.
|
||||
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
|
||||
|
||||
>>> # move buf2's result into buf (keep it chronologically meanwhile)
|
||||
>>> buf.update(buf2)
|
||||
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
|
||||
0., 0., 0., 0., 0., 0., 0.])
|
||||
|
||||
>>> # get a random sample from buffer
|
||||
>>> # the batch_data is equal to buf[incide].
|
||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||
>>> batch_data.obs == buf[indice].obs
|
||||
array([ True, True, True, True])
|
||||
"""
|
||||
|
||||
def __init__(self, size):
|
||||
super().__init__()
|
||||
@ -15,6 +48,9 @@ class ReplayBuffer(object):
|
||||
del self.__dict__[k]
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return len(self).
|
||||
"""
|
||||
return self._size
|
||||
|
||||
def _add_to_buffer(self, name, inst):
|
||||
@ -34,6 +70,9 @@ class ReplayBuffer(object):
|
||||
self.__dict__[name][self._index] = inst
|
||||
|
||||
def update(self, buffer):
|
||||
"""
|
||||
Move the data from the given buffer to self.
|
||||
"""
|
||||
i = begin = buffer._index % len(buffer)
|
||||
while True:
|
||||
self.add(
|
||||
@ -45,7 +84,7 @@ class ReplayBuffer(object):
|
||||
|
||||
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
|
||||
'''
|
||||
weight: importance weights, disabled here
|
||||
Add a batch of data into replay buffer.
|
||||
'''
|
||||
assert isinstance(info, dict), \
|
||||
'You should return a dict in the last argument of env.step().'
|
||||
@ -62,10 +101,18 @@ class ReplayBuffer(object):
|
||||
self._size = self._index = self._index + 1
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear all the data in replay buffer.
|
||||
"""
|
||||
self._index = self._size = 0
|
||||
self.indice = []
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""
|
||||
Get a random sample from buffer with size = ``batch_size``
|
||||
|
||||
:return: Sample data and its corresponding index inside the buffer.
|
||||
"""
|
||||
if batch_size > 0:
|
||||
indice = np.random.choice(self._size, batch_size)
|
||||
else:
|
||||
@ -76,6 +123,9 @@ class ReplayBuffer(object):
|
||||
return self[indice], indice
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Return a data batch: self[index].
|
||||
"""
|
||||
return Batch(
|
||||
obs=self.obs[index],
|
||||
act=self.act[index],
|
||||
@ -87,7 +137,12 @@ class ReplayBuffer(object):
|
||||
|
||||
|
||||
class ListReplayBuffer(ReplayBuffer):
|
||||
"""docstring for ListReplayBuffer"""
|
||||
"""
|
||||
The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same
|
||||
as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
|
||||
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(size=0)
|
||||
|
||||
@ -111,7 +166,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def __init__(self, size):
|
||||
super().__init__(size)
|
||||
|
||||
def add(self, obs, act, rew, done, obs_next, info={}, weight=None):
|
||||
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self, batch_size):
|
||||
|
||||
@ -2,7 +2,21 @@ import numpy as np
|
||||
|
||||
|
||||
class OUNoise(object):
|
||||
"""docstring for OUNoise"""
|
||||
"""
|
||||
Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
|
||||
Usage:
|
||||
::
|
||||
|
||||
# init
|
||||
self.noise = OUNoise()
|
||||
# generate noise
|
||||
noise = self.noise(logits.shape, eps)
|
||||
|
||||
For required parameters, you can refer to the stackoverflow page. However,
|
||||
our experiment result shows that (similar to OpenAI SpinningUp) using
|
||||
vanilla gaussian process has little difference from using the
|
||||
Ornstein-Uhlenbeck process.
|
||||
"""
|
||||
|
||||
def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
|
||||
self.alpha = theta * dt
|
||||
@ -11,6 +25,10 @@ class OUNoise(object):
|
||||
self.reset()
|
||||
|
||||
def __call__(self, size, mu=.1):
|
||||
"""
|
||||
Generate new noise. Return a ``numpy.ndarray`` which size is equal to
|
||||
``size``.
|
||||
"""
|
||||
if self.x is None or self.x.shape != size:
|
||||
self.x = 0
|
||||
r = self.beta * np.random.normal(size=size)
|
||||
@ -18,4 +36,7 @@ class OUNoise(object):
|
||||
return self.x
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset to the initial state.
|
||||
"""
|
||||
self.x = None
|
||||
|
||||
@ -46,10 +46,10 @@ class A2CPolicy(PGPolicy):
|
||||
nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), max_norm=self._grad_norm)
|
||||
self.optim.step()
|
||||
actor_losses.append(a_loss.detach().cpu().numpy())
|
||||
vf_losses.append(vf_loss.detach().cpu().numpy())
|
||||
ent_losses.append(ent_loss.detach().cpu().numpy())
|
||||
losses.append(loss.detach().cpu().numpy())
|
||||
actor_losses.append(a_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
ent_losses.append(ent_loss.item())
|
||||
losses.append(loss.item())
|
||||
return {
|
||||
'loss': losses,
|
||||
'loss/actor': actor_losses,
|
||||
|
||||
@ -107,6 +107,6 @@ class DDPGPolicy(BasePolicy):
|
||||
self.actor_optim.step()
|
||||
self.sync_weight()
|
||||
return {
|
||||
'loss/actor': actor_loss.detach().cpu().numpy(),
|
||||
'loss/critic': critic_loss.detach().cpu().numpy(),
|
||||
'loss/actor': actor_loss.item(),
|
||||
'loss/critic': critic_loss.item(),
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ class DQNPolicy(BasePolicy):
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
self.eps = 0
|
||||
assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
|
||||
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
|
||||
self._gamma = discount_factor
|
||||
assert estimation_step > 0, 'estimation_step should greater than 0'
|
||||
self._n_step = estimation_step
|
||||
@ -66,7 +66,7 @@ class DQNPolicy(BasePolicy):
|
||||
target_q = target_q.max(axis=1)
|
||||
target_q[gammas != self._n_step] = 0
|
||||
returns += (self._gamma ** gammas) * target_q
|
||||
batch.update(returns=returns)
|
||||
batch.returns = returns
|
||||
return batch
|
||||
|
||||
def __call__(self, batch, state=None,
|
||||
@ -96,4 +96,4 @@ class DQNPolicy(BasePolicy):
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
self._cnt += 1
|
||||
return {'loss': loss.detach().cpu().numpy()}
|
||||
return {'loss': loss.item()}
|
||||
|
||||
@ -20,9 +20,8 @@ class PGPolicy(BasePolicy):
|
||||
self._gamma = discount_factor
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
returns = self._vanilla_returns(batch)
|
||||
# returns = self._vectorized_returns(batch)
|
||||
batch.update(returns=returns)
|
||||
batch.returns = self._vanilla_returns(batch)
|
||||
# batch.returns = self._vectorized_returns(batch)
|
||||
return batch
|
||||
|
||||
def __call__(self, batch, state=None):
|
||||
@ -45,7 +44,7 @@ class PGPolicy(BasePolicy):
|
||||
loss = -(dist.log_prob(a) * r).sum()
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
losses.append(loss.detach().cpu().numpy())
|
||||
losses.append(loss.item())
|
||||
return {'loss': losses}
|
||||
|
||||
def _vanilla_returns(self, batch):
|
||||
|
||||
@ -76,19 +76,19 @@ class PPOPolicy(PGPolicy):
|
||||
surr2 = ratio.clamp(
|
||||
1. - self._eps_clip, 1. + self._eps_clip) * adv
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.detach().cpu().numpy())
|
||||
clip_losses.append(clip_loss.item())
|
||||
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
|
||||
vf_losses.append(vf_loss.detach().cpu().numpy())
|
||||
vf_losses.append(vf_loss.item())
|
||||
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.detach().cpu().numpy())
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
|
||||
losses.append(loss.detach().cpu().numpy())
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(list(
|
||||
self.actor.parameters()) + list(self.critic.parameters()),
|
||||
self._max_grad_norm)
|
||||
self._max_grad_norm)
|
||||
self.optim.step()
|
||||
self.sync_weight()
|
||||
return {
|
||||
|
||||
@ -99,7 +99,7 @@ class SACPolicy(DDPGPolicy):
|
||||
self.actor_optim.step()
|
||||
self.sync_weight()
|
||||
return {
|
||||
'loss/actor': actor_loss.detach().cpu().numpy(),
|
||||
'loss/critic1': critic1_loss.detach().cpu().numpy(),
|
||||
'loss/critic2': critic2_loss.detach().cpu().numpy(),
|
||||
'loss/actor': actor_loss.item(),
|
||||
'loss/critic1': critic1_loss.item(),
|
||||
'loss/critic2': critic2_loss.item(),
|
||||
}
|
||||
|
||||
@ -81,12 +81,12 @@ class TD3Policy(DDPGPolicy):
|
||||
batch.obs, self(batch, eps=0).act).mean()
|
||||
self.actor_optim.zero_grad()
|
||||
actor_loss.backward()
|
||||
self._last = actor_loss.detach().cpu().numpy()
|
||||
self._last = actor_loss.item()
|
||||
self.actor_optim.step()
|
||||
self.sync_weight()
|
||||
self._cnt += 1
|
||||
return {
|
||||
'loss/actor': self._last,
|
||||
'loss/critic1': critic1_loss.detach().cpu().numpy(),
|
||||
'loss/critic2': critic2_loss.detach().cpu().numpy(),
|
||||
'loss/critic1': critic1_loss.item(),
|
||||
'loss/critic2': critic2_loss.item(),
|
||||
}
|
||||
|
||||
@ -3,6 +3,9 @@ import numpy as np
|
||||
|
||||
|
||||
def test_episode(policy, collector, test_fn, epoch, n_episode):
|
||||
"""
|
||||
A simple wrapper of testing policy in collector.
|
||||
"""
|
||||
collector.reset_env()
|
||||
collector.reset_buffer()
|
||||
policy.eval()
|
||||
@ -17,6 +20,24 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
|
||||
|
||||
|
||||
def gather_info(start_time, train_c, test_c, best_reward):
|
||||
"""
|
||||
A simple wrapper of gathering information from collectors.
|
||||
|
||||
:return: A dictionary with following keys:
|
||||
|
||||
* ``train_step``: the total collected step of training collector;
|
||||
* ``train_episode``: the total collected episode of training collector;
|
||||
* ``train_time/collector``: the time for collecting frames in the\
|
||||
training collector;
|
||||
* ``train_time/model``: the time for training models;
|
||||
* ``train_speed``: the speed of training (frames per second);
|
||||
* ``test_step``: the total collected step of test collector;
|
||||
* ``test_episode``: the total collected episode of test collector;
|
||||
* ``test_time``: the time for testing;
|
||||
* ``test_speed``: the speed of testing (frames per second);
|
||||
* ``best_reward``: the best reward over the test results;
|
||||
* ``duration``: the total elapsed time.
|
||||
"""
|
||||
duration = time.time() - start_time
|
||||
model_time = duration - train_c.collect_time - test_c.collect_time
|
||||
train_speed = train_c.collect_step / (duration - test_c.collect_time)
|
||||
|
||||
@ -3,14 +3,35 @@ import numpy as np
|
||||
|
||||
|
||||
class MovAvg(object):
|
||||
"""
|
||||
Class for moving average. Usage:
|
||||
::
|
||||
|
||||
>>> stat = MovAvg(size=66)
|
||||
>>> stat.add(torch.tensor(5))
|
||||
5.0
|
||||
>>> stat.add(float('inf')) # which will not add to stat
|
||||
5.0
|
||||
>>> stat.add([6, 7, 8])
|
||||
6.5
|
||||
>>> stat.get()
|
||||
6.5
|
||||
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
|
||||
6.50±1.12
|
||||
"""
|
||||
def __init__(self, size=100):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.cache = []
|
||||
|
||||
def add(self, x):
|
||||
"""
|
||||
Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
|
||||
only one element, a python scalar, or a list of python scalar. It will
|
||||
exclude the infinity.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
x = x.item()
|
||||
if isinstance(x, list):
|
||||
for _ in x:
|
||||
if _ != np.inf:
|
||||
@ -22,14 +43,23 @@ class MovAvg(object):
|
||||
return self.get()
|
||||
|
||||
def get(self):
|
||||
"""
|
||||
Get the average.
|
||||
"""
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
return np.mean(self.cache)
|
||||
|
||||
def mean(self):
|
||||
"""
|
||||
Get the average. Same as :meth:`get`.
|
||||
"""
|
||||
return self.get()
|
||||
|
||||
def std(self):
|
||||
"""
|
||||
Get the standard deviation.
|
||||
"""
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
return np.std(self.cache)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user