Docs/overhaul (#999)
Closes #916 This PR presents an overhaul of how the docs are built and presented 1. Notebooks are no longer just links in some drive. They are checked in without their outputs, executed in CI, and thereby serve as integration tests as well as tutorials. They have been adjusted to work with the current master branch 2. Execution of notebooks is cached, so it's very fast 3. The api docs are generated automatically with a custom script. Previously this was only done for the highlevel module 4. The build is happening with jupyter-book (which still uses sphinx in the backend). It is using the default jupyter book theme, which I think looks very nice and adds useful navigation to the right side of the screen 5. Customized api docs rendering for better appearance 6. The toc of the docs is built automatically with jupyter-book. The api docs generation script has been adjusted accordingly 7. The viewcode and linkcode extensions add source code and links to it to the docs 8. A bunch of docstrings have been adjusted to better reflect the configured rules 9. Several typing issues improved to make mypy happy It was quite a piece of work, I hope you like the result :)
3
.gitignore
vendored
@ -153,6 +153,9 @@ videos/
|
||||
# might be needed for IDE plugins that can't read ruff config
|
||||
.flake8
|
||||
|
||||
docs/notebooks/_build/
|
||||
docs/conf.py
|
||||
|
||||
# temporary scripts (for ad-hoc testing), temp folder
|
||||
/temp
|
||||
/temp*.py
|
||||
@ -28,8 +28,8 @@ repos:
|
||||
pass_filenames: false
|
||||
- id: poetry-lock-check
|
||||
name: poetry lock check
|
||||
entry: poetry lock
|
||||
args: [--check]
|
||||
entry: poetry check
|
||||
args: [--lock]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
- id: mypy
|
||||
|
||||
@ -10,15 +10,14 @@ build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.11"
|
||||
jobs:
|
||||
pre_build:
|
||||
- pip install .
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
# We recommend specifying your dependencies to enable reproducible builds:
|
||||
# https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
|
||||
python:
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
commands:
|
||||
- mkdir -p $READTHEDOCS_OUTPUT/html
|
||||
- curl -sSL https://install.python-poetry.org | python -
|
||||
# - ~/.local/bin/poetry config virtualenvs.create false
|
||||
- ~/.local/bin/poetry install --with dev
|
||||
## Same as poe tasks, but unfortunately poe doesn't work with poetry not creating virtualenvs
|
||||
- ~/.local/bin/poetry run python docs/autogen_rst.py
|
||||
- ~/.local/bin/poetry run which jupyter-book
|
||||
- ~/.local/bin/poetry run python docs/create_toc.py
|
||||
- ~/.local/bin/poetry run jupyter-book config sphinx docs/
|
||||
- ~/.local/bin/poetry run sphinx-build -W -b html docs $READTHEDOCS_OUTPUT/html
|
||||
|
||||
6
docs/.gitignore
vendored
@ -1,2 +1,4 @@
|
||||
# auto-generated content
|
||||
/api/tianshou.highlevel
|
||||
/03_api/*
|
||||
jupyter_execute
|
||||
_toc.yml
|
||||
.jupyter_cache
|
||||
|
||||
@ -308,7 +308,7 @@ Tianshou supports user-defined training code. Here is the code snippet:
|
||||
# train policy with a sampled batch data from buffer
|
||||
losses = policy.update(64, train_collector.buffer)
|
||||
|
||||
For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.
|
||||
For further usage, you can refer to the :doc:`/01_tutorials/07_cheatsheet`.
|
||||
|
||||
.. rubric:: References
|
||||
|
||||
@ -339,7 +339,7 @@ Thus, we need a time-related interface for calculating the 2-step return. :meth:
|
||||
|
||||
This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.policy.BasePolicy.process_fn`.
|
||||
|
||||
For other method, you can check out :doc:`/api/tianshou.policy`. We give the usage of policy class a high-level explanation in :ref:`pseudocode`.
|
||||
For other method, you can check out :doc:`/03_api/policy/index`. We give the usage of policy class a high-level explanation in :ref:`pseudocode`.
|
||||
|
||||
|
||||
Collector
|
||||
@ -382,7 +382,7 @@ Trainer
|
||||
|
||||
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
|
||||
|
||||
Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.
|
||||
Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/03_api/trainer/index` for the usage.
|
||||
|
||||
We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
|
||||
::
|
||||
@ -126,7 +126,7 @@ The figure in the right gives an intuitive comparison among synchronous/asynchro
|
||||
.. note::
|
||||
|
||||
The async simulation collector would cause some exceptions when used as
|
||||
``test_collector`` in :doc:`/api/tianshou.trainer` (related to
|
||||
``test_collector`` in :doc:`/03_api/trainer/index` (related to
|
||||
`Issue 700 <https://github.com/thu-ml/tianshou/issues/700>`_). Please use
|
||||
sync version for ``test_collector`` instead.
|
||||
|
||||
@ -478,4 +478,4 @@ By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we
|
||||
act = policy(state_)
|
||||
next_state_, reward = env.step(act)
|
||||
|
||||
Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
|
||||
Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/01_tutorials/04_tictactoe`.
|
||||
2
docs/01_tutorials/index.rst
Normal file
@ -0,0 +1,2 @@
|
||||
Tutorials
|
||||
=========
|
||||
4
docs/02_notebooks/0_intro.md
Normal file
@ -0,0 +1,4 @@
|
||||
# Notebook Tutorials
|
||||
|
||||
Here is a collection of executable tutorials for Tianshou. You can run them
|
||||
directly in colab, or download them and run them locally.
|
||||
236
docs/02_notebooks/L0_overview.ipynb
Normal file
@ -0,0 +1,236 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"id": "r7aE6Rq3cAEE",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"In this tutorial, we use guide you step by step to show you how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1_mLTSEIcY2c"
|
||||
},
|
||||
"source": [
|
||||
"## Run the code\n",
|
||||
"Before we get started, we must first install Tianshou's library and Gym environment by running the commands below. Here I choose a specific version of Tianshou(0.4.8) which is the latest as of the time writing this tutorial. APIs in different versions may vary a little bit but most are the same. Feel free to use other versions in your own project."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "IcFNmCjYeIIU"
|
||||
},
|
||||
"source": [
|
||||
"Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v1\n",
|
||||
"problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n",
|
||||
"exactly what this tutorial is for.\n",
|
||||
"\n",
|
||||
"If the script ends normally, you will see the evaluation result printed out before the first\n",
|
||||
"epoch is done."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"is_executing": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"hide-cell",
|
||||
"remove-output"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gymnasium as gym\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PPOPolicy\n",
|
||||
"from tianshou.trainer import OnpolicyTrainer\n",
|
||||
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
||||
"\n",
|
||||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"is_executing": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# environments\n",
|
||||
"env = gym.make(\"CartPole-v1\")\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])\n",
|
||||
"\n",
|
||||
"# model & optimizer\n",
|
||||
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
||||
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
|
||||
"critic = Critic(net, device=device).to(device)\n",
|
||||
"actor_critic = ActorCritic(actor, critic)\n",
|
||||
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n",
|
||||
"\n",
|
||||
"# PPO policy\n",
|
||||
"dist = torch.distributions.Categorical\n",
|
||||
"policy = PPOPolicy(\n",
|
||||
" actor=actor,\n",
|
||||
" critic=critic,\n",
|
||||
" optim=optim,\n",
|
||||
" dist_fn=dist,\n",
|
||||
" action_space=env.action_space,\n",
|
||||
" action_scaling=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# collector\n",
|
||||
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
|
||||
"test_collector = Collector(policy, test_envs)\n",
|
||||
"\n",
|
||||
"# trainer\n",
|
||||
"result = OnpolicyTrainer(\n",
|
||||
" policy=policy,\n",
|
||||
" batch_size=256,\n",
|
||||
" train_collector=train_collector,\n",
|
||||
" test_collector=test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=50000,\n",
|
||||
" repeat_per_collect=10,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
||||
")\n",
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "G9YEQptYvCgx",
|
||||
"is_executing": true,
|
||||
"outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's watch its performance!\n",
|
||||
"policy.eval()\n",
|
||||
"result = test_collector.collect(n_episode=1, render=False)\n",
|
||||
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "xFYlcPo8fpPU"
|
||||
},
|
||||
"source": [
|
||||
"## Tutorial Introduction\n",
|
||||
"\n",
|
||||
"A common DRL experiment as is shown above may require many components to work together. The agent, the\n",
|
||||
"environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n",
|
||||
"training task.\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
|
||||
"\n",
|
||||
"</div>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kV_uOyimj-bk"
|
||||
},
|
||||
"source": [
|
||||
"In Tianshou, all of these main components are factored out as different building blocks, which you\n",
|
||||
"can use to create your own algorithm and finish your own experiment.\n",
|
||||
"\n",
|
||||
"Building blocks may include:\n",
|
||||
"- Batch\n",
|
||||
"- Replay Buffer\n",
|
||||
"- Vectorized Environment Wrapper\n",
|
||||
"- Policy (the agent and the training algorithm)\n",
|
||||
"- Data Collector\n",
|
||||
"- Trainer\n",
|
||||
"- Logger\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Check this [webpage](https://tianshou.readthedocs.io/en/master/tutorials/dqn.html) to find jupyter-notebook-style tutorials that will guide you through all these\n",
|
||||
"modules one by one. You can also read the [documentation](https://tianshou.readthedocs.io/en/master/) of Tianshou for more detailed explanation and\n",
|
||||
"advanced usages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "S0mNKwH9i6Ek"
|
||||
},
|
||||
"source": [
|
||||
"## Further reading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "M3NPSUnAov4L"
|
||||
},
|
||||
"source": [
|
||||
"### What if I am not familiar with the PPO algorithm itself?\n",
|
||||
"As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n",
|
||||
"plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's tutorials, we will\n",
|
||||
"focus on the usages of different modules, but not the algorithms themselves."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
399
docs/02_notebooks/L1_Batch.ipynb
Normal file
@ -0,0 +1,399 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "69y6AHvq1S3f"
|
||||
},
|
||||
"source": [
|
||||
"# Batch\n",
|
||||
"In this tutorial, we will introduce the **Batch** to you, which is the most basic data structure in Tianshou. You can simply considered Batch as a numpy version of python dictionary."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"editable": true,
|
||||
"id": "NkfiIe_y2FI-",
|
||||
"outputId": "5008275f-8f77-489a-af64-b35af4448589",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"remove-output",
|
||||
"hide-cell"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from tianshou.data import Batch\n",
|
||||
"import torch\n",
|
||||
"import pickle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = Batch(a=4, b=[5, 5], c=\"2312312\", d=(\"a\", -2, -3))\n",
|
||||
"print(data)\n",
|
||||
"print(data.b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "S6e6OuXe3UT-"
|
||||
},
|
||||
"source": [
|
||||
"A batch is simply a dictionary which stores all passed in data as key-value pairs, and automatically turns the value into a numpy array if possible.\n",
|
||||
"\n",
|
||||
"## Why we need Batch in Tianshou?\n",
|
||||
"The motivation behind the implementation of Batch module is simple. In DRL, you need to handle a lot of dictionary-format data. For instance most algorithms would require you to store state, action, and reward data for every step when interacting with the environment. All these data can be organized as a dictionary and a Batch module helps Tianshou unify the interface of a diverse set of algorithms. Plus, Batch supports advanced indexing, concatenation and splitting, formatting print just like any other numpy array, which may be very helpful for developers.\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/concepts_arch.png\", title=\"Data flow is converted into a Batch in Tianshou\">\n",
|
||||
"\n",
|
||||
"<a> Data flow is converted into a Batch in Tianshou </a>\n",
|
||||
"</div>\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "_Xenx64M9HhV"
|
||||
},
|
||||
"source": [
|
||||
"## Basic Usages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "4YGX_f1Z9Uil"
|
||||
},
|
||||
"source": [
|
||||
"### Initialization\n",
|
||||
"Batch can be converted directly from a python dictionary, and all data structure will be converted to numpy array if possible."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Jl3-4BRbp3MM",
|
||||
"outputId": "a8b225f6-2893-4716-c694-3c2ff558b7f0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# converted from a python library\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch1 = Batch({\"a\": [4, 4], \"b\": (5, 5)})\n",
|
||||
"print(batch1)\n",
|
||||
"\n",
|
||||
"# initialization of batch2 is equivalent to batch1\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch2 = Batch(a=[4, 4], b=(5, 5))\n",
|
||||
"print(batch2)\n",
|
||||
"\n",
|
||||
"# the dictionary can be nested, and it will be turned into a nested Batch\n",
|
||||
"print(\"========================================\")\n",
|
||||
"data = {\n",
|
||||
" \"action\": np.array([1.0, 2.0, 3.0]),\n",
|
||||
" \"reward\": 3.66,\n",
|
||||
" \"obs\": {\n",
|
||||
" \"rgb_obs\": np.zeros((3, 3)),\n",
|
||||
" \"flatten_obs\": np.ones(5),\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"batch3 = Batch(data, extra=\"extra_string\")\n",
|
||||
"print(batch3)\n",
|
||||
"# batch3.obs is also a Batch\n",
|
||||
"print(type(batch3.obs))\n",
|
||||
"print(batch3.obs.rgb_obs)\n",
|
||||
"\n",
|
||||
"# a list of dictionary/Batch will automatically be concatenated/stacked, providing convenience if you\n",
|
||||
"# want to use parallelized environments to collect data.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch4 = Batch([data] * 3)\n",
|
||||
"print(batch4)\n",
|
||||
"print(batch4.obs.rgb_obs.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "JCf6bqY3uf5L"
|
||||
},
|
||||
"source": [
|
||||
"### Getting access to data\n",
|
||||
"You can conveniently search or change the key-value pair in the Batch just as if it is a python dictionary."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "2TNIY90-vU9b",
|
||||
"outputId": "de52ffe9-03c2-45f2-d95a-4071132daa4a"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch1 = Batch({\"a\": [4, 4], \"b\": (5, 5)})\n",
|
||||
"print(batch1)\n",
|
||||
"# add or delete key-value pair in batch1\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch1.c = Batch(c1=np.arange(3), c2=False)\n",
|
||||
"del batch1.a\n",
|
||||
"print(batch1)\n",
|
||||
"\n",
|
||||
"# access value by key\n",
|
||||
"print(\"========================================\")\n",
|
||||
"assert batch1[\"c\"] is batch1.c\n",
|
||||
"print(\"c\" in batch1)\n",
|
||||
"\n",
|
||||
"# traverse the Batch\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for key, value in batch1.items():\n",
|
||||
" print(str(key) + \": \" + str(value))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bVywStbV9jD2"
|
||||
},
|
||||
"source": [
|
||||
"### Indexing and Slicing\n",
|
||||
"If all values in Batch share the same shape in certain dimensions, Batch can support advanced indexing and slicing just like a normal numpy array."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "gKza3OJnzc_D",
|
||||
"outputId": "4f240bfe-4a69-4c1b-b40e-983c5c4d0cbc"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let us suppose we've got 4 environments, each returns a step of data\n",
|
||||
"step_datas = [\n",
|
||||
" {\n",
|
||||
" \"act\": np.random.randint(10),\n",
|
||||
" \"rew\": 0.0,\n",
|
||||
" \"obs\": np.ones((3, 3)),\n",
|
||||
" \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n",
|
||||
" }\n",
|
||||
" for _ in range(4)\n",
|
||||
"]\n",
|
||||
"batch = Batch(step_datas)\n",
|
||||
"print(batch)\n",
|
||||
"print(batch.shape)\n",
|
||||
"\n",
|
||||
"# advanced indexing is supported, if we only want to select data in a given set of environments\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(batch[0])\n",
|
||||
"print(batch[[0, 3]])\n",
|
||||
"\n",
|
||||
"# slicing is also supported\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(batch[-2:])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Aggregation and Splitting\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1vUwQ-Hw9jtu"
|
||||
},
|
||||
"source": [
|
||||
"Again, just like a numpy array. Play the example code below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "f5UkReyn3_kb",
|
||||
"outputId": "e7bb3324-7f20-4810-a328-479117efca55"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# concat batches with compatible keys\n",
|
||||
"# try incompatible keys yourself if you feel curious\n",
|
||||
"print(\"========================================\")\n",
|
||||
"b1 = Batch(a=[{\"b\": np.float64(1.0), \"d\": Batch(e=np.array(3.0))}])\n",
|
||||
"b2 = Batch(a=[{\"b\": np.float64(4.0), \"d\": {\"e\": np.array(6.0)}}])\n",
|
||||
"b12_cat_out = Batch.cat([b1, b2])\n",
|
||||
"print(b1)\n",
|
||||
"print(b2)\n",
|
||||
"print(b12_cat_out)\n",
|
||||
"\n",
|
||||
"# stack batches with compatible keys\n",
|
||||
"# try incompatible keys yourself if you feel curious\n",
|
||||
"print(\"========================================\")\n",
|
||||
"b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n",
|
||||
"b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n",
|
||||
"b34_stack = Batch.stack((b3, b4), axis=1)\n",
|
||||
"print(b3)\n",
|
||||
"print(b4)\n",
|
||||
"print(b34_stack)\n",
|
||||
"\n",
|
||||
"# split the batch into small batches of size 1, breaking the order of the data\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(type(b34_stack.split(1)))\n",
|
||||
"print(list(b34_stack.split(1, shuffle=True)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Smc_W1Cx6zRS"
|
||||
},
|
||||
"source": [
|
||||
"### Data type converting\n",
|
||||
"Besides numpy array, Batch actually also supports Torch Tensor. The usages are exactly the same. Cool, isn't it?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Y6im_Mtb7Ody",
|
||||
"outputId": "898e82c4-b940-4c35-a0f9-dedc4a9bc500"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))\n",
|
||||
"batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))\n",
|
||||
"batch_cat = Batch.cat([batch1, batch2, batch1])\n",
|
||||
"print(batch_cat)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1wfTUVKb6xki"
|
||||
},
|
||||
"source": [
|
||||
"You can convert the data type easily, if you no longer want to use hybrid data type anymore."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "F7WknVs98DHD",
|
||||
"outputId": "cfd0712a-1df3-4208-e6cc-9149840bdc40"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_cat.to_numpy()\n",
|
||||
"print(batch_cat)\n",
|
||||
"batch_cat.to_torch()\n",
|
||||
"print(batch_cat)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "NTFVle1-9Biz"
|
||||
},
|
||||
"source": [
|
||||
"Batch is even serializable, just in case you may need to save it to disk or restore it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Lnf17OXv9YRb",
|
||||
"outputId": "753753f2-3f66-4d4b-b4ff-d57f9c40d1da"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n",
|
||||
"batch_pk = pickle.loads(pickle.dumps(batch))\n",
|
||||
"print(batch_pk)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-vPMiPZ-9kJN"
|
||||
},
|
||||
"source": [
|
||||
"## Further Reading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8Oc1p8ud9kcu"
|
||||
},
|
||||
"source": [
|
||||
"Would like to learn more advanced usages of Batch? Feel curious about how data is organized inside the Batch? Check the [documentation](https://tianshou.readthedocs.io/en/master/03_api/tianshou.data.html) and other [tutorials](https://tianshou.readthedocs.io/en/master/tutorials/batch.html#) for more details."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
457
docs/02_notebooks/L2_Buffer.ipynb
Normal file
@ -0,0 +1,457 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "xoPiGVD8LNma"
|
||||
},
|
||||
"source": [
|
||||
"# Buffer\n",
|
||||
"Replay Buffer is a very common module in DRL implementations. In Tianshou, you can consider Buffer module as as a specialized form of Batch, which helps you track all data trajectories and provide utilities such as sampling method besides the basic storage.\n",
|
||||
"\n",
|
||||
"There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. The later one is specially designed for parallelized environments (will introduce in tutorial L3)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OdesCAxANehZ"
|
||||
},
|
||||
"source": [
|
||||
"## Usages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "fUbLl9T_SrTR"
|
||||
},
|
||||
"source": [
|
||||
"### Basic usages as a batch\n",
|
||||
"Usually a buffer stores all the data in a batch with circular-queue style."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"hide-cell",
|
||||
"remove-output"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tianshou.data import Batch, ReplayBuffer\n",
|
||||
"from numpy import False_\n",
|
||||
"\n",
|
||||
"import pickle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mocZ6IqZTH62",
|
||||
"outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n",
|
||||
"print(\"========================================\")\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add 3 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add another 10 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 13):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "H8B85Y5yUfTy"
|
||||
},
|
||||
"source": [
|
||||
"Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "cOX-ADOPNeEK",
|
||||
"outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(buf[-1])\n",
|
||||
"print(buf[-3:])\n",
|
||||
"# Try more methods you find useful in Batch yourself."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "vqldap-2WQBh"
|
||||
},
|
||||
"source": [
|
||||
"ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Ppx0L3niNT5K"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"_buf = pickle.loads(pickle.dumps(buf))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"remove-cell"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ToDo: update link to gymnasium"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Eqezp0OyXn6J"
|
||||
},
|
||||
"source": [
|
||||
"### Understanding reserved keys for buffer\n",
|
||||
"As I have explained, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. So, for convenience, we reserve certain nine reserved keys in Batch.\n",
|
||||
"\n",
|
||||
"* `obs`\n",
|
||||
"* `act`\n",
|
||||
"* `rew`\n",
|
||||
"* `terminated`\n",
|
||||
"* `truncated`\n",
|
||||
"* `done`\n",
|
||||
"* `obs_next`\n",
|
||||
"* `info`\n",
|
||||
"* `policy`\n",
|
||||
"\n",
|
||||
"The meaning of these nine reserved keys are consistent with the meaning in [OPENAI Gym](https://gym.openai.com/). We would recommend you simply use these nine keys when adding batched data into ReplayBuffer, because\n",
|
||||
"some of them are tracked in ReplayBuffer (e.g. \"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n",
|
||||
"buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n",
|
||||
"```\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ueAbTspsc6jo"
|
||||
},
|
||||
"source": [
|
||||
"### Data sampling\n",
|
||||
"We keep a replay buffer in DRL for one purpose:\"sample data from it for training\". `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fulfill this need."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "P5xnYOhrchDl",
|
||||
"outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"buf.sample(batch_size=5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "IWyaOSKOcgK4"
|
||||
},
|
||||
"source": [
|
||||
"## Trajectory tracking\n",
|
||||
"Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n",
|
||||
"\n",
|
||||
"First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"editable": true,
|
||||
"id": "H0qRb6HLfhLB",
|
||||
"outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"# Add the first trajectory (length is 3) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" result = buf.add(\n",
|
||||
" Batch(\n",
|
||||
" obs=i,\n",
|
||||
" act=i,\n",
|
||||
" rew=i,\n",
|
||||
" terminated=1 if i == 2 else 0,\n",
|
||||
" truncated=0,\n",
|
||||
" done=True if i == 2 else False,\n",
|
||||
" obs_next=i + 1,\n",
|
||||
" info={},\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the second trajectory (length is 5) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 8):\n",
|
||||
" result = buf.add(\n",
|
||||
" Batch(\n",
|
||||
" obs=i,\n",
|
||||
" act=i,\n",
|
||||
" rew=i,\n",
|
||||
" terminated=1 if i == 7 else 0,\n",
|
||||
" truncated=0,\n",
|
||||
" done=True if i == 7 else False,\n",
|
||||
" obs_next=i + 1,\n",
|
||||
" info={},\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(8, 13):\n",
|
||||
" result = buf.add(\n",
|
||||
" Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=False, obs_next=i + 1, info={})\n",
|
||||
" )\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dO7PWdb_hkXA"
|
||||
},
|
||||
"source": [
|
||||
"### episode length and rewards tracking\n",
|
||||
"Notice that `ReplayBuffer.add()` returns a tuple of 4 numbers every time it returns, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "xbVc90z8itH0"
|
||||
},
|
||||
"source": [
|
||||
"### Episode index management\n",
|
||||
"In the ReplayBuffer above, we can get access to any data step by indexing.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4mKwo54MjupY",
|
||||
"outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(buf)\n",
|
||||
"data = buf[6]\n",
|
||||
"print(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "p5Co_Fmzj8Sw"
|
||||
},
|
||||
"source": [
|
||||
"Now we know that step \"6\" is not the start of an episode (it should be step 4, 4-7 is the second trajectory we add into the ReplayBuffer), we wonder what is the earliest index of that episode.\n",
|
||||
"\n",
|
||||
"This may seem easy but actually it is not. We cannot simply look at the \"done\" flag now, because we can see that since the third-added trajectory is not finished yet, step \"4\" is surrounded by flag \"False\". There are many things to consider. Things could get more nasty if you are using more advanced ReplayBuffer like VectorReplayBuffer, because now the data is not stored in a simple circular-queue.\n",
|
||||
"\n",
|
||||
"Luckily, all ReplayBuffer instances help you identify step indexes through a unified API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "DcJ0LEX6mxHg",
|
||||
"outputId": "7830f5fb-96d9-4298-d09b-24e64b2f633c"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Search for the previous index of index \"6\"\n",
|
||||
"now_index = 6\n",
|
||||
"while True:\n",
|
||||
" prev_index = buf.prev(now_index)\n",
|
||||
" print(prev_index)\n",
|
||||
" if prev_index == now_index:\n",
|
||||
" break\n",
|
||||
" else:\n",
|
||||
" now_index = prev_index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "4Wlb57V4lQyQ"
|
||||
},
|
||||
"source": [
|
||||
"Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us identify the last index of an episode regardless of which kind of ReplayBuffer we are using."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "zl5TRMo7oOy5",
|
||||
"outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# next step of indexes [4,5,6,7,8,9] are:\n",
|
||||
"print(buf.next([4, 5, 6, 7, 8, 9]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YJ9CcWZXoOXw"
|
||||
},
|
||||
"source": [
|
||||
"We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Xkawk97NpItg",
|
||||
"outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(buf.unfinished_index())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8_lMr0j3pOmn"
|
||||
},
|
||||
"source": [
|
||||
"Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "FEyE0c7tNfwa"
|
||||
},
|
||||
"source": [
|
||||
"## Further Reading\n",
|
||||
"### Other Buffer Module\n",
|
||||
"\n",
|
||||
"* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n",
|
||||
"* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n",
|
||||
"* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n",
|
||||
"\n",
|
||||
"Check the documentation and the source code for more details.\n",
|
||||
"\n",
|
||||
"### Support for steps stacking to use RNN in DRL.\n",
|
||||
"There is an option called `stack_num` (default to 1) when initializing the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
263
docs/02_notebooks/L3_Vectorized__Environment.ipynb
Normal file
@ -0,0 +1,263 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "W5V7z3fVX7_b"
|
||||
},
|
||||
"source": [
|
||||
"# Vectorized Environment\n",
|
||||
"In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments or their libraries in DRL research, Tianshou chooses to keep a consistent API with [OPENAI Gym](https://gym.openai.com/).\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/rl-loop.jpg\", title=\"The agents interacting with the environment\">\n",
|
||||
"\n",
|
||||
"<a> The agents interacting with the environment </a>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"In Gym, an environment receives an action and returns next observation and reward. This process is slow and sometimes can be the throughput bottleneck in a DRL experiment.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "A0NGWZ8adBwt"
|
||||
},
|
||||
"source": [
|
||||
"Tianshou provides vectorized environment wrapper for a Gym environment. This wrapper allows you to make use of multiple cpu cores in your server to accelerate the data sampling."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"remove-cell"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ToDo: Clarify if it has to be done, or truncated. Also in the function description of ´SubprocVectorEnv.step()´, output is not clear"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"editable": true,
|
||||
"id": "67wKtkiNi3lb",
|
||||
"outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"remove-output",
|
||||
"hide-cell"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tianshou.env import SubprocVectorEnv\n",
|
||||
"import numpy as np\n",
|
||||
"import gymnasium as gym\n",
|
||||
"import time\n",
|
||||
"from tianshou.env import DummyVectorEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_cpus = [1, 2, 5]\n",
|
||||
"for num_cpu in num_cpus:\n",
|
||||
" env = SubprocVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(num_cpu)])\n",
|
||||
" env.reset()\n",
|
||||
" sampled_steps = 0\n",
|
||||
" time_start = time.time()\n",
|
||||
" while sampled_steps < 1000:\n",
|
||||
" act = np.random.choice(2, size=num_cpu)\n",
|
||||
" obs, rew, terminated, truncated, info = env.step(act)\n",
|
||||
" if np.sum(terminated):\n",
|
||||
" env.reset(np.where(terminated)[0])\n",
|
||||
" sampled_steps += num_cpu\n",
|
||||
" time_used = time.time() - time_start\n",
|
||||
" print(\"{}s used to sample 1000 steps if using {} cpus.\".format(time_used, num_cpu))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "S1b6vxp9nEUS"
|
||||
},
|
||||
"source": [
|
||||
"You may notice that the speed doesn't increase linearly when we add subprocess numbers. There are multiple reasons behind this. One reason is that synchronize exception causes straggler effect. One way to solve this would be to use asynchronous mode. We leave this for further reading if you feel interested.\n",
|
||||
"\n",
|
||||
"Note that SubprocVectorEnv should only be used when the environment execution is slow. In practice, DummyVectorEnv (or raw Gym environment) is actually more efficient for a simple environment like CartPole because now you avoid both straggler effect and the overhead of communication between subprocesses."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Z6yPxdqFp18j"
|
||||
},
|
||||
"source": [
|
||||
"## Usages\n",
|
||||
"### Initialization\n",
|
||||
"Just pass in a list of functions which return the initialized environment upon called."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ssLcrL_pq24-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# In Gym\n",
|
||||
"env = gym.make(\"CartPole-v1\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# In Tianshou\n",
|
||||
"def helper_function():\n",
|
||||
" env = gym.make(\"CartPole-v1\")\n",
|
||||
" # other operations such as env.seed(np.random.choice(10))\n",
|
||||
" return env\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"envs = DummyVectorEnv([helper_function for _ in range(5)])\n",
|
||||
"print(envs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "X7p8csjdrwIN"
|
||||
},
|
||||
"source": [
|
||||
"### EnvPool supporting\n",
|
||||
"Besides integrated environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its Github page yourself."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kvIfqh0vqAR5"
|
||||
},
|
||||
"source": [
|
||||
"### Environment execution and resetting\n",
|
||||
"The only difference between Vectorized environments and standard Gym environments is that passed in actions and returned rewards/observations are also vectorized."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"remove-cell"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ToDo: num_cpu is defined for a for loop, not as a parameter..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "BH1ZnPG6tkdD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# In Gym, env.reset() returns a single observation.\n",
|
||||
"print(\"In Gym, env.reset() returns a single observation.\")\n",
|
||||
"print(env.reset())\n",
|
||||
"\n",
|
||||
"# In Tianshou, envs.reset() returns stacked observations.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(\"In Tianshou, envs.reset() returns stacked observations.\")\n",
|
||||
"print(envs.reset())\n",
|
||||
"\n",
|
||||
"obs, rew, done, truncated, info = envs.step(np.random.choice(2, size=num_cpu))\n",
|
||||
"print(info)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qXroB7KluvP9"
|
||||
},
|
||||
"source": [
|
||||
"If we only want to execute several environments. The `id` argument can be used."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ufvFViKTu8d_"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(envs.step(np.random.choice(2, size=3), id=[0, 3, 1]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "fekHR1a6X_HB"
|
||||
},
|
||||
"source": [
|
||||
"## Further Reading\n",
|
||||
"### Other environment wrappers in Tianshou\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n",
|
||||
"* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n",
|
||||
"\n",
|
||||
"Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n",
|
||||
"\n",
|
||||
"### Difference between synchronous and asynchronous mode (How to choose?)\n",
|
||||
"Explanation can be found at the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
860
docs/02_notebooks/L4_Policy.ipynb
Normal file
@ -0,0 +1,860 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PNM9wqstBSY_"
|
||||
},
|
||||
"source": [
|
||||
"# Policy\n",
|
||||
"In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the agent part. In Tianshou, both the agent and the core DRL algorithm are implemented in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. See supported algorithms [here](https://github.com/thu-ml/tianshou).\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/rl-loop.jpg\", title=\"The agents interacting with the environment\">\n",
|
||||
"\n",
|
||||
"<a> The agents interacting with the environment </a>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"All Policy modules inherit from a BasePolicy Class and share the same interface."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ZqdHYdoJJS51"
|
||||
},
|
||||
"source": [
|
||||
"## Creating your own Policy\n",
|
||||
"We will use the simple REINFORCE algorithm Policy to show the implementation of a Policy Module. The Policy we implement here will be a highly scaled-down version of [PGPolicy](https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/modelfree/pg.py) in Tianshou."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "PWFBgZ4TJkfz"
|
||||
},
|
||||
"source": [
|
||||
"### Initialization\n",
|
||||
"Firstly we create the `REINFORCEPolicy` by inheriting from `BasePolicy` in Tianshou."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"id": "cDlSjASbJmy-",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"hide-cell",
|
||||
"remove-output"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Dict, List\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import gymnasium as gym\n",
|
||||
"\n",
|
||||
"from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n",
|
||||
"from tianshou.policy import BasePolicy\n",
|
||||
"from tianshou.utils.net.common import Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class REINFORCEPolicy(BasePolicy):\n",
|
||||
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__(action_space=action_space)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qc1RnIBbLCDN"
|
||||
},
|
||||
"source": [
|
||||
"As we have mentioned, the Policy Module mainly does two things:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"1. `policy.forward()` receives observation and other information (stored in a Batch) from the environment and returns a new Batch containing the action.\n",
|
||||
"2. `policy.update()` receives training data sampled from the replay buffer and updates itself, and then returns logging details.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
|
||||
"\n",
|
||||
"<a> policy.forward() and policy.update() </a>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"We also need to take care of the following things:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n",
|
||||
"2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n",
|
||||
"\n",
|
||||
"Then we get the implementation below.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "6j32PSKUQ23w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class REINFORCEPolicy(BasePolicy):\n",
|
||||
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self, model: torch.nn.Module, optim: torch.optim.Optimizer, action_space: gym.Space\n",
|
||||
" ):\n",
|
||||
" super().__init__(action_space=action_space)\n",
|
||||
" self.actor = model\n",
|
||||
" self.optim = optim\n",
|
||||
"\n",
|
||||
" def forward(self, batch: Batch) -> Batch:\n",
|
||||
" \"\"\"Compute action over the given batch data.\"\"\"\n",
|
||||
" act = None\n",
|
||||
" return Batch(act=act)\n",
|
||||
"\n",
|
||||
" def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
|
||||
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
|
||||
" \"\"\"Perform the back-propagation.\"\"\"\n",
|
||||
" return"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "tjtqjt8WRY5e"
|
||||
},
|
||||
"source": [
|
||||
"### Policy.forward()\n",
|
||||
"According to the equation of REINFORCE algorithm in Spinning Up's [documentation](https://spinningup.openai.com/en/latest/algorithms/vpg.html), we need to map the observation to an action distribution in action space using neural network (`self.actor`).\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://spinningup.openai.com/en/latest/_images/math/3d29a18c0f98b1cdb656ecdf261ee37ffe8bb74b.svg\">\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"Let us suppose the action space is discrete, and the distribution is a simple categorical distribution.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "uE4YDE-_RwgN"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def forward(self, batch: Batch) -> Batch:\n",
|
||||
" \"\"\"Compute action over the given batch data.\"\"\"\n",
|
||||
" self.dist_fn = torch.distributions.Categorical\n",
|
||||
" logits = self.actor(batch.obs)\n",
|
||||
" dist = self.dist_fn(logits)\n",
|
||||
" act = dist.sample()\n",
|
||||
" return Batch(act=act, dist=dist)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "CultfOeuTx2V"
|
||||
},
|
||||
"source": [
|
||||
"### Policy.process_fn()\n",
|
||||
"Now that we have defined our actor, if given training data we can set up a loss function and optimize our neural network. However, before that, we must first calculate episodic returns for every step in our training data to construct the REINFORCE loss function.\n",
|
||||
"\n",
|
||||
"Calculating episodic return is not hard, given `ReplayBuffer.next()` allows us to access every reward to go in an episode. A more convenient way would be to simply use the built-in method `BasePolicy.compute_episodic_return()` inherited from BasePolicy.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "wPAmOD7zV7n2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
|
||||
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
|
||||
" returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)\n",
|
||||
" batch.returns = returns\n",
|
||||
" return batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "XA8OF4GnWWr5"
|
||||
},
|
||||
"source": [
|
||||
"`BasePolicy.compute_episodic_return()` could also be used to compute [GAE](https://arxiv.org/abs/1506.02438). Another similar method is `BasePolicy.compute_nstep_return()`. Check the [source code](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L304) for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7UsdzNaOXPpC"
|
||||
},
|
||||
"source": [
|
||||
"### Policy.learn()\n",
|
||||
"Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Final we can construct our loss function and perform the back-propagation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "aCO-dLXWXtz9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
|
||||
" \"\"\"Perform the back-propagation.\"\"\"\n",
|
||||
" logging_losses = []\n",
|
||||
" for _ in range(repeat):\n",
|
||||
" for minibatch in batch.split(batch_size, merge_last=True):\n",
|
||||
" self.optim.zero_grad()\n",
|
||||
" result = self(minibatch)\n",
|
||||
" dist = result.dist\n",
|
||||
" act = to_torch_as(minibatch.act, result.act)\n",
|
||||
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
|
||||
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
|
||||
" loss = -(log_prob * ret).mean()\n",
|
||||
" loss.backward()\n",
|
||||
" self.optim.step()\n",
|
||||
" logging_losses.append(loss.item())\n",
|
||||
" return {\"loss\": logging_losses}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1BtuV2W0YJTi"
|
||||
},
|
||||
"source": [
|
||||
"## Implementation\n",
|
||||
"Finally we can assemble the implemented methods and form a REINFORCE Policy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Ab0KNQHTOlGo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class REINFORCEPolicy(BasePolicy):\n",
|
||||
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" *,\n",
|
||||
" model: torch.nn.Module,\n",
|
||||
" optim: torch.optim.Optimizer,\n",
|
||||
" action_space: gym.Space,\n",
|
||||
" ):\n",
|
||||
" super().__init__(action_space=action_space)\n",
|
||||
" self.actor = model\n",
|
||||
" self.optim = optim\n",
|
||||
" # action distribution\n",
|
||||
" self.dist_fn = torch.distributions.Categorical\n",
|
||||
"\n",
|
||||
" def forward(self, batch: Batch) -> Batch:\n",
|
||||
" \"\"\"Compute action over the given batch data.\"\"\"\n",
|
||||
" logits, _ = self.actor(batch.obs)\n",
|
||||
" dist = self.dist_fn(logits)\n",
|
||||
" act = dist.sample()\n",
|
||||
" return Batch(act=act, dist=dist)\n",
|
||||
"\n",
|
||||
" def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
|
||||
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
|
||||
" returns, _ = self.compute_episodic_return(\n",
|
||||
" batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n",
|
||||
" )\n",
|
||||
" batch.returns = returns\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
|
||||
" \"\"\"Perform the back-propagation.\"\"\"\n",
|
||||
" logging_losses = []\n",
|
||||
" for _ in range(repeat):\n",
|
||||
" for minibatch in batch.split(batch_size, merge_last=True):\n",
|
||||
" self.optim.zero_grad()\n",
|
||||
" result = self(minibatch)\n",
|
||||
" dist = result.dist\n",
|
||||
" act = to_torch_as(minibatch.act, result.act)\n",
|
||||
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
|
||||
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
|
||||
" loss = -(log_prob * ret).mean()\n",
|
||||
" loss.backward()\n",
|
||||
" self.optim.step()\n",
|
||||
" logging_losses.append(loss.item())\n",
|
||||
" return {\"loss\": logging_losses}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "xlPAbh0lKti8"
|
||||
},
|
||||
"source": [
|
||||
"## Use the policy\n",
|
||||
"Note that `BasePolicy` itself inherits from `torch.nn.Module`. As a result, you can consider all Policy modules as a Torch Module. They share similar APIs.\n",
|
||||
"\n",
|
||||
"Firstly we will initialize a new REINFORCE policy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "JkLFA9Z1KjuX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"state_shape = 4\n",
|
||||
"action_shape = 2\n",
|
||||
"# Usually taken from an env by using env.action_space\n",
|
||||
"action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))\n",
|
||||
"net = Net(state_shape, hidden_sizes=[16, 16], device=\"cpu\")\n",
|
||||
"actor = Actor(net, action_shape, device=\"cpu\").to(\"cpu\")\n",
|
||||
"optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n",
|
||||
"\n",
|
||||
"policy = REINFORCEPolicy(model=actor, optim=optim, action_space=action_space)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "LAo_0t2fekUD"
|
||||
},
|
||||
"source": [
|
||||
"REINFORCE policy shares same APIs with the Torch Module."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "UiuTc8RhJiEi",
|
||||
"outputId": "9b5bc54c-6303-45f3-ba81-2216a44931e8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(policy)\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for para in policy.parameters():\n",
|
||||
" print(para.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-RCrsttYgAG-"
|
||||
},
|
||||
"source": [
|
||||
"### Making decision\n",
|
||||
"Given a batch of observations, the policy can return a batch of actions and other data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "0jkBb6AAgUla",
|
||||
"outputId": "37948844-cdd8-4567-9481-89453c80a157"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs_batch = Batch(obs=np.ones(shape=(256, 4)))\n",
|
||||
"action = policy(obs_batch) # forward() method is called\n",
|
||||
"print(action)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "swikhnuDfKep"
|
||||
},
|
||||
"source": [
|
||||
"### Save and Load models\n",
|
||||
"Naturally, Tianshou Policy can be saved and loaded like a normal Torch Network."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "tYOoWM_OJRnA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(policy.state_dict(), \"policy.pth\")\n",
|
||||
"assert policy.load_state_dict(torch.load(\"policy.pth\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gp8PzOYsg5z-"
|
||||
},
|
||||
"source": [
|
||||
"### Algorithm Updating\n",
|
||||
"We have to collect some data and save them in the ReplayBuffer before updating our agent(policy). Typically we use collector to collect data, but we leave this part till later when we have learned the Collector in Tianshou. For now we generate some **fake** data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "XrrPxOUAYShR"
|
||||
},
|
||||
"source": [
|
||||
"#### Generating fake data\n",
|
||||
"Firstly, we need to \"pretend\" that we are using the \"Policy\" to collect data. We plan to collect 10 data so that we can update our algorithm."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "a14CmzSfYh5C",
|
||||
"outputId": "aaf45a1f-5e21-4bc8-cbe3-8ce798258af0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# a buffer is initialised with its maxsize set to 20.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"buf = ReplayBuffer(size=12)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"env = gym.make(\"CartPole-v1\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8S94cV7yZITR"
|
||||
},
|
||||
"source": [
|
||||
"Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps because we are performing too badly)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "a_mtvbmBZbfs"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, info = env.reset()\n",
|
||||
"for i in range(3):\n",
|
||||
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
|
||||
" obs_next, rew, _, truncated, info = env.step(act)\n",
|
||||
" # pretend ending at step 3\n",
|
||||
" terminated = True if i == 2 else False\n",
|
||||
" info[\"id\"] = i\n",
|
||||
" buf.add(\n",
|
||||
" Batch(\n",
|
||||
" obs=obs,\n",
|
||||
" act=act,\n",
|
||||
" rew=rew,\n",
|
||||
" terminated=terminated,\n",
|
||||
" truncated=truncated,\n",
|
||||
" obs_next=obs_next,\n",
|
||||
" info=info,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" obs = obs_next"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(buf)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "pkxq4gu9bGkt"
|
||||
},
|
||||
"source": [
|
||||
"Now we are pretending to collect the second episode. At step 7 the second episode still doesn't end, but we are unwilling to wait, so we stop collecting to update the algorithm."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "pAoKe02ybG68"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, info = env.reset()\n",
|
||||
"for i in range(3, 10):\n",
|
||||
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
|
||||
" obs_next, rew, _, truncated, info = env.step(act)\n",
|
||||
" # pretend this episode never end\n",
|
||||
" terminated = False\n",
|
||||
" info[\"id\"] = i\n",
|
||||
" buf.add(\n",
|
||||
" Batch(\n",
|
||||
" obs=obs,\n",
|
||||
" act=act,\n",
|
||||
" rew=rew,\n",
|
||||
" terminated=terminated,\n",
|
||||
" truncated=truncated,\n",
|
||||
" obs_next=obs_next,\n",
|
||||
" info=info,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" obs = obs_next"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "MKM6aWMucv-M"
|
||||
},
|
||||
"source": [
|
||||
"Our replay buffer looks like this now."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "CSJEEWOqXdTU",
|
||||
"outputId": "2b3bb75c-f219-4e82-ca78-0ea6173a91f9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "55VWhWpkdfEb"
|
||||
},
|
||||
"source": [
|
||||
"#### Updates\n",
|
||||
"Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "i_O1lJDWdeoc",
|
||||
"outputId": "b154741a-d6dc-46cb-898f-6e84fa14e5a7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 0 means sample all data from the buffer\n",
|
||||
"# batch_size=10 defines the training batch size\n",
|
||||
"# repeat=6 means repeat the training for 6 times\n",
|
||||
"policy.update(0, buf, batch_size=10, repeat=6)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "enqlFQLSJrQl"
|
||||
},
|
||||
"source": [
|
||||
"Not that difficult, right?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "QJ5krjrcbuiA"
|
||||
},
|
||||
"source": [
|
||||
"## Further Reading\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "pmWi3HuXWcV8"
|
||||
},
|
||||
"source": [
|
||||
"### Pre-defined Networks\n",
|
||||
"Tianshou provides numerous pre-defined networks usually used in DRL so that you don't have to bother yourself. Check this [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#pre-defined-networks) for details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "UPVl5LBEWJ0t"
|
||||
},
|
||||
"source": [
|
||||
"### How to compute GAE on your own?\n",
|
||||
"(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n",
|
||||
"\n",
|
||||
"In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. However, it is still important that we know the details behind this.\n",
|
||||
"\n",
|
||||
"To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "D34GlVvPNz08",
|
||||
"outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch, indices = buf.sample(0) # 0 means sampling all the data from the buffer\n",
|
||||
"returns, advantage = BasePolicy.compute_episodic_return(\n",
|
||||
" batch, buf, indices, v_s_=np.zeros(10), v_s=np.zeros(10), gamma=1.0, gae_lambda=1.0\n",
|
||||
")\n",
|
||||
"print(returns)\n",
|
||||
"print(advantage)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "h_5Dt6XwQLXV"
|
||||
},
|
||||
"source": [
|
||||
"In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. As we know, we need to estimate the value function of every observation to compute GAE advantage. so the passed in `v_s` is the value of batch.obs, `v_s_` is the value of batch.obs_next this is usually computed by:\n",
|
||||
"\n",
|
||||
"`v_s = critic(batch.obs)`,\n",
|
||||
"\n",
|
||||
"`v_s_ = critic(batch.obs_next)`,\n",
|
||||
"\n",
|
||||
"where both `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n",
|
||||
"\n",
|
||||
"After we've got all those values, GAE can be computed following the equation below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ooHNIICGUO19"
|
||||
},
|
||||
"source": [
|
||||
"\\begin{aligned}\n",
|
||||
"\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n",
|
||||
"\\end{aligned}\n",
|
||||
"\n",
|
||||
"while\n",
|
||||
"\n",
|
||||
"\\begin{equation}\n",
|
||||
"\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n",
|
||||
"\\end{equation}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eV6XZaouU7EV"
|
||||
},
|
||||
"source": [
|
||||
"But, if you do follow this equation I referred from the paper. You probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "FCxD9gNNVYbd"
|
||||
},
|
||||
"source": [
|
||||
"**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "rNZNUNgQVvRJ",
|
||||
"outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Assume v_s_ is got by calling critic(batch.obs_next)\n",
|
||||
"v_s_ = np.ones(10)\n",
|
||||
"v_s_ *= ~batch.done\n",
|
||||
"print(v_s_)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2EtMi18QWXTN"
|
||||
},
|
||||
"source": [
|
||||
"After the fix above, we will perhaps get a more accurate estimate.\n",
|
||||
"\n",
|
||||
"**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. We must know all those unfinished steps so that we know when to stop bootstrapping.\n",
|
||||
"\n",
|
||||
"Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "saluvX4JU6bC",
|
||||
"outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"unfinished_indexes = buf.unfinished_index()\n",
|
||||
"print(unfinished_indexes)\n",
|
||||
"done_indexes = np.where(batch.done)[0]\n",
|
||||
"print(done_indexes)\n",
|
||||
"stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n",
|
||||
"print(stop_bootstrap_ids)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qp6vVE4dYWv1"
|
||||
},
|
||||
"source": [
|
||||
"**Thirdly**, there are some special indexes which are marked by done flag. However, its value for obs_next should not be zero. This is because these steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kind of steps are always marked with `info['TimeLimit.truncated']=True` in Gym."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "tWkqXRJfZTvV"
|
||||
},
|
||||
"source": [
|
||||
"As a result, we need to rewrite the equation above\n",
|
||||
"\n",
|
||||
"`v_s_ *= ~batch.done`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kms-QtxKZe-M"
|
||||
},
|
||||
"source": [
|
||||
"to\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n",
|
||||
"v_s_ *= mask\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "u_aPPoKraBu6"
|
||||
},
|
||||
"source": [
|
||||
"### Summary\n",
|
||||
"If you already felt bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `BasePolicy.compute_episodic_return()`.\n",
|
||||
"\n",
|
||||
"If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and implementation of `BasePolicy.value_mask()` and `BasePolicy.compute_episodic_return()` for details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2cPnUXRBWKD9"
|
||||
},
|
||||
"source": [
|
||||
"<center>\n",
|
||||
"<img src=../_static/images/timelimit.svg></img>\n",
|
||||
"</center>\n",
|
||||
"<center>\n",
|
||||
"<img src=../_static/images/policy_table.svg></img>\n",
|
||||
"</center>"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
287
docs/02_notebooks/L5_Collector.ipynb
Normal file
@ -0,0 +1,287 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "M98bqxdMsTXK"
|
||||
},
|
||||
"source": [
|
||||
"# Collector\n",
|
||||
"From its literal meaning, we can easily know that the Collector in Tianshou is used to collect training data. More specifically, the Collector controls the interaction between Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n",
|
||||
"\n",
|
||||
"<center>\n",
|
||||
"<img src=../_static/images/structure.svg></img>\n",
|
||||
"</center>\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OX5cayLv4Ziu"
|
||||
},
|
||||
"source": [
|
||||
"## Usages\n",
|
||||
"Collector can be used both for training (data collecting) and evaluation in Tianshou."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Z6XKbj28u8Ze"
|
||||
},
|
||||
"source": [
|
||||
"### Policy evaluation\n",
|
||||
"We need to evaluate our trained policy from time to time in DRL experiments. Collector can help us with this.\n",
|
||||
"\n",
|
||||
"First we have to initialize a Collector with an (vectorized) environment and a given policy (agent)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"id": "w8t9ubO7u69J",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"hide-cell",
|
||||
"remove-output"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gymnasium as gym\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PGPolicy\n",
|
||||
"from tianshou.utils.net.common import Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor\n",
|
||||
"from tianshou.data import VectorReplayBuffer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make(\"CartPole-v1\")\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
|
||||
"\n",
|
||||
"# model\n",
|
||||
"net = Net(\n",
|
||||
" env.observation_space.shape,\n",
|
||||
" hidden_sizes=[\n",
|
||||
" 16,\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"actor = Actor(net, env.action_space.shape)\n",
|
||||
"optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n",
|
||||
"\n",
|
||||
"policy = PGPolicy(\n",
|
||||
" actor=actor,\n",
|
||||
" optim=optim,\n",
|
||||
" dist_fn=torch.distributions.Categorical,\n",
|
||||
" action_space=env.action_space,\n",
|
||||
" action_scaling=False,\n",
|
||||
")\n",
|
||||
"test_collector = Collector(policy, test_envs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "wmt8vuwpzQdR"
|
||||
},
|
||||
"source": [
|
||||
"Now we would like to collect 9 episodes of data to test how our initialized Policy performs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "9SuT6MClyjyH",
|
||||
"outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"collect_result = test_collector.collect(n_episode=9)\n",
|
||||
"print(collect_result)\n",
|
||||
"print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n",
|
||||
"print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n",
|
||||
"print(\"Average episode length is {}.\".format(collect_result[\"len\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "zX9AQY0M0R3C"
|
||||
},
|
||||
"source": [
|
||||
"Now we wonder what is the performance of a random policy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "UEcs8P8P0RLt",
|
||||
"outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reset the collector\n",
|
||||
"test_collector.reset()\n",
|
||||
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
|
||||
"print(collect_result)\n",
|
||||
"print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n",
|
||||
"print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n",
|
||||
"print(\"Average episode length is {}.\".format(collect_result[\"len\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "sKQRTiG10ljU"
|
||||
},
|
||||
"source": [
|
||||
"Seems that an initialized policy performs even worse than a random policy without any training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8RKmHIoG1A1k"
|
||||
},
|
||||
"source": [
|
||||
"### Data Collecting\n",
|
||||
"Data collecting is mostly used during training, when we need to store the collected data in a ReplayBuffer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"id": "CB9XB9bF1YPC",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_env_num = 4\n",
|
||||
"buffer_size = 100\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
|
||||
"replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
|
||||
"\n",
|
||||
"train_collector = Collector(policy, train_envs, replaybuffer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "rWKDazA42IUQ"
|
||||
},
|
||||
"source": [
|
||||
"Now we can collect 50 steps of data, which will be automatically saved in the replay buffer. You can still choose to collect a certain number of episodes rather than steps. Try it yourself."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "-fUtQOnM2Yi1",
|
||||
"outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(len(replaybuffer))\n",
|
||||
"collect_result = train_collector.collect(n_step=50)\n",
|
||||
"print(len(replaybuffer))\n",
|
||||
"print(collect_result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "EWO4A7plefwM",
|
||||
"outputId": "9a6f36d1-2b84-49b0-a03d-a8ebe8acadbf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(13):\n",
|
||||
" print(i, replaybuffer.next(i))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "HW8PpWH9fLCo",
|
||||
"outputId": "7ca70c50-23b9-4405-9e42-2e5771cd9c78"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"replaybuffer.sample(10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8NP7lOBU3-VS"
|
||||
},
|
||||
"source": [
|
||||
"## Further Reading\n",
|
||||
"The above collector actually collects 52 data at a time because 52 % 4 = 0. There is one asynchronous collector which allows you collect exactly 50 steps. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#asynccollector) for details."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
252
docs/02_notebooks/L6_Trainer.ipynb
Normal file
@ -0,0 +1,252 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "S3-tJZy35Ck_"
|
||||
},
|
||||
"source": [
|
||||
"# Trainer\n",
|
||||
"Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n",
|
||||
"\n",
|
||||
"<center>\n",
|
||||
"<img src=../_static/images/structure.svg></img>\n",
|
||||
"</center>\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ifsEQMzZ6mmz"
|
||||
},
|
||||
"source": [
|
||||
"## Usages\n",
|
||||
"In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "XfsuU2AAE52C"
|
||||
},
|
||||
"source": [
|
||||
"### Pseudocode\n",
|
||||
"<center>\n",
|
||||
"<img src=../_static/images/pseudocode_off_policy.svg></img>\n",
|
||||
"</center>\n",
|
||||
"\n",
|
||||
"For the on-policy trainer, the main difference is that we clear the buffer after Line 10."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Hcp_o0CCFz12"
|
||||
},
|
||||
"source": [
|
||||
"### Training without trainer\n",
|
||||
"As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n",
|
||||
"\n",
|
||||
"First, let us create the instances of Environment, ReplayBuffer, Policy and Collector."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"id": "do-xZ-8B7nVH",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"hide-cell",
|
||||
"remove-output"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gymnasium as gym\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PGPolicy\n",
|
||||
"from tianshou.utils.net.common import Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor\n",
|
||||
"from tianshou.trainer import OnpolicyTrainer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_env_num = 4\n",
|
||||
"buffer_size = (\n",
|
||||
" 2000 # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create the environments, used for training and evaluation\n",
|
||||
"env = gym.make(\"CartPole-v1\")\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
|
||||
"\n",
|
||||
"# Create the Policy instance\n",
|
||||
"net = Net(\n",
|
||||
" env.observation_space.shape,\n",
|
||||
" hidden_sizes=[\n",
|
||||
" 16,\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"actor = Actor(net, env.action_space.shape)\n",
|
||||
"optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
|
||||
"policy = PGPolicy(\n",
|
||||
" actor=actor,\n",
|
||||
" optim=optim,\n",
|
||||
" dist_fn=torch.distributions.Categorical,\n",
|
||||
" action_space=env.action_space,\n",
|
||||
" action_scaling=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create the replay buffer and the collector\n",
|
||||
"replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
|
||||
"test_collector = Collector(policy, test_envs)\n",
|
||||
"train_collector = Collector(policy, train_envs, replaybuffer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "wiEGiBgQIiFM"
|
||||
},
|
||||
"source": [
|
||||
"Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "JMUNPN5SI_kd",
|
||||
"outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_collector.reset()\n",
|
||||
"train_envs.reset()\n",
|
||||
"test_collector.reset()\n",
|
||||
"test_envs.reset()\n",
|
||||
"replaybuffer.reset()\n",
|
||||
"for i in range(10):\n",
|
||||
" evaluation_result = test_collector.collect(n_episode=10)\n",
|
||||
" print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n",
|
||||
" train_collector.collect(n_step=2000)\n",
|
||||
" # 0 means taking all data stored in train_collector.buffer\n",
|
||||
" policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n",
|
||||
" train_collector.reset_buffer(keep_statistics=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "QXBHIBckMs_2"
|
||||
},
|
||||
"source": [
|
||||
"The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "p-7U_cwgF5Ej"
|
||||
},
|
||||
"source": [
|
||||
"### Training with trainer\n",
|
||||
"The trainer does almost the same thing. The only difference is that it has considered many details and is more modular."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "vcvw9J8RNtFE",
|
||||
"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_collector.reset()\n",
|
||||
"train_envs.reset()\n",
|
||||
"test_collector.reset()\n",
|
||||
"test_envs.reset()\n",
|
||||
"replaybuffer.reset()\n",
|
||||
"\n",
|
||||
"result = OnpolicyTrainer(\n",
|
||||
" policy=policy,\n",
|
||||
" train_collector=train_collector,\n",
|
||||
" test_collector=test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=1,\n",
|
||||
" repeat_per_collect=1,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" batch_size=512,\n",
|
||||
")\n",
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "_j3aUJZQ7nml"
|
||||
},
|
||||
"source": [
|
||||
"## Further Reading\n",
|
||||
"### Logger usages\n",
|
||||
"Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n",
|
||||
"\n",
|
||||
"### Learn more about the APIs of Trainers\n",
|
||||
"[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [
|
||||
"S3-tJZy35Ck_",
|
||||
"XfsuU2AAE52C",
|
||||
"p-7U_cwgF5Ej",
|
||||
"_j3aUJZQ7nml"
|
||||
],
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
333
docs/02_notebooks/L7_Experiment.ipynb
Normal file
@ -0,0 +1,333 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "_UaXOSRjDUF9"
|
||||
},
|
||||
"source": [
|
||||
"# Experiment\n",
|
||||
"Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2QRbCJvDHNAd"
|
||||
},
|
||||
"source": [
|
||||
"## Experiment\n",
|
||||
"To conduct this experiment, we need the following building blocks.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* Two vectorized environments, one for training and one for evaluation\n",
|
||||
"* A PPO agent\n",
|
||||
"* A replay buffer to store transition data\n",
|
||||
"* Two collectors to manage the data collecting process, one for training and one for evaluation\n",
|
||||
"* A trainer to manage the training loop\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
|
||||
"\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"Let us do this step by step."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-Hh4E6i0Hj0I"
|
||||
},
|
||||
"source": [
|
||||
"## Preparation\n",
|
||||
"Firstly, install Tianshou if you haven't installed it before."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7E4EhiBeHxD5"
|
||||
},
|
||||
"source": [
|
||||
"Import libraries we might need later."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"id": "ao9gWJDiHgG-",
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": [
|
||||
"hide-cell",
|
||||
"remove-output"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gymnasium as gym\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PPOPolicy\n",
|
||||
"from tianshou.trainer import OnpolicyTrainer\n",
|
||||
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
||||
"\n",
|
||||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "QnRg5y7THRYw"
|
||||
},
|
||||
"source": [
|
||||
"## Environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YZERKCGtH8W1"
|
||||
},
|
||||
"source": [
|
||||
"We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Mpuj5PFnDKVS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make(\"CartPole-v1\")\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "BJtt_Ya8DTAh"
|
||||
},
|
||||
"source": [
|
||||
"## Policy\n",
|
||||
"Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n",
|
||||
"\n",
|
||||
"The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n",
|
||||
"\n",
|
||||
"Luckily, Tianshou already provides basic network modules that we can use in this experiment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "_Vy8uPWXP4m_"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# net is the shared head of the actor and the critic\n",
|
||||
"net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
||||
"actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n",
|
||||
"critic = Critic(preprocess_net=net, device=device).to(device)\n",
|
||||
"actor_critic = ActorCritic(actor=actor, critic=critic)\n",
|
||||
"\n",
|
||||
"# optimizer of the actor and the critic\n",
|
||||
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Lh2-hwE5Dn9I"
|
||||
},
|
||||
"source": [
|
||||
"Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OiJ2GkT0Qnbr"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dist = torch.distributions.Categorical\n",
|
||||
"policy = PPOPolicy(\n",
|
||||
" actor=actor,\n",
|
||||
" critic=critic,\n",
|
||||
" optim=optim,\n",
|
||||
" dist_fn=dist,\n",
|
||||
" action_space=env.action_space,\n",
|
||||
" deterministic_eval=True,\n",
|
||||
" action_scaling=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "okxfj6IEQ-r8"
|
||||
},
|
||||
"source": [
|
||||
"`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "n5XAAbuBZarO"
|
||||
},
|
||||
"source": [
|
||||
"## Collector\n",
|
||||
"We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ezwz0qerZhQM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_collector = Collector(\n",
|
||||
" policy=policy, env=train_envs, buffer=VectorReplayBuffer(20000, len(train_envs))\n",
|
||||
")\n",
|
||||
"test_collector = Collector(policy=policy, env=test_envs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ZaoPxOd2hm0b"
|
||||
},
|
||||
"source": [
|
||||
"We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qBoE9pLUiC-8"
|
||||
},
|
||||
"source": [
|
||||
"## Trainer\n",
|
||||
"Finally, we can use the trainer to help us set up the training loop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "i45EDnpxQ8gu",
|
||||
"outputId": "b1666b88-0bfa-4340-868e-58611872d988"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = OnpolicyTrainer(\n",
|
||||
" policy=policy,\n",
|
||||
" train_collector=train_collector,\n",
|
||||
" test_collector=test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=50000,\n",
|
||||
" repeat_per_collect=10,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" batch_size=256,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ckgINHE2iTFR"
|
||||
},
|
||||
"source": [
|
||||
"## Results\n",
|
||||
"Print the training result."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "tJCPgmiyiaaX",
|
||||
"outputId": "40123ae3-3365-4782-9563-46c43812f10f"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "A-MJ9avMibxN"
|
||||
},
|
||||
"source": [
|
||||
"We can also test our trained agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mnMANFcciiAQ",
|
||||
"outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's watch its performance!\n",
|
||||
"policy.eval()\n",
|
||||
"result = test_collector.collect(n_episode=1, render=False)\n",
|
||||
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@ -61,19 +61,19 @@ Test by GitHub Actions
|
||||
|
||||
1. Click the ``Actions`` button in your own repo:
|
||||
|
||||
.. image:: _static/images/action1.jpg
|
||||
.. image:: ../_static/images/action1.jpg
|
||||
:align: center
|
||||
|
||||
2. Click the green button:
|
||||
|
||||
.. image:: _static/images/action2.jpg
|
||||
.. image:: ../_static/images/action2.jpg
|
||||
:align: center
|
||||
|
||||
3. You will see ``Actions Enabled.`` on the top of html page.
|
||||
|
||||
4. When you push a new commit to your own repo (e.g. ``git push``), it will automatically run the test in this page:
|
||||
|
||||
.. image:: _static/images/action3.png
|
||||
.. image:: ../_static/images/action3.png
|
||||
:align: center
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
Contributor
|
||||
===========
|
||||
Contributors
|
||||
============
|
||||
|
||||
We always welcome contributions to help make Tianshou better. Below are an incomplete list of our contributors (find more on `this page <https://github.com/thu-ml/tianshou/graphs/contributors>`_).
|
||||
|
||||
142
docs/_config.yml
Normal file
@ -0,0 +1,142 @@
|
||||
# Book settings
|
||||
# Learn more at https://jupyterbook.org/customize/config.html
|
||||
|
||||
#######################################################################################
|
||||
# A default configuration that will be loaded for all jupyter books
|
||||
# Users are expected to override these values in their own `_config.yml` file.
|
||||
# This is also the "master list" of all allowed keys and values.
|
||||
|
||||
#######################################################################################
|
||||
# Book settings
|
||||
title : Tianshou Documentation # The title of the book. Will be placed in the left navbar.
|
||||
author : Tianshou contributors # The author of the book
|
||||
copyright : "2020, Tianshou contributors." # Copyright year to be placed in the footer
|
||||
logo : _static/images/tianshou-logo.png # A path to the book logo
|
||||
# Patterns to skip when building the book. Can be glob-style (e.g. "*skip.ipynb")
|
||||
exclude_patterns : ['**.ipynb_checkpoints', '.DS_Store', 'Thumbs.db', '_build', 'jupyter_execute', '.jupyter_cache', '.pytest_cache', 'docs/autogen_rst.py', 'docs/create_toc.py']
|
||||
# Auto-exclude files not in the toc
|
||||
only_build_toc_files : false
|
||||
|
||||
#######################################################################################
|
||||
# Execution settings
|
||||
execute:
|
||||
execute_notebooks : cache # Whether to execute notebooks at build time. Must be one of ("auto", "force", "cache", "off")
|
||||
cache : "" # A path to the jupyter cache that will be used to store execution artifacts. Defaults to `_build/.jupyter_cache/`
|
||||
exclude_patterns : [] # A list of patterns to *skip* in execution (e.g. a notebook that takes a really long time)
|
||||
timeout : -1 # The maximum time (in seconds) each notebook cell is allowed to run.
|
||||
run_in_temp : false # If `True`, then a temporary directory will be created and used as the command working directory (cwd),
|
||||
# otherwise the notebook's parent directory will be the cwd.
|
||||
allow_errors : false # If `False`, when a code cell raises an error the execution is stopped, otherwise all cells are always run.
|
||||
stderr_output : show # One of 'show', 'remove', 'remove-warn', 'warn', 'error', 'severe'
|
||||
|
||||
#######################################################################################
|
||||
# Parse and render settings
|
||||
parse:
|
||||
myst_enable_extensions: # default extensions to enable in the myst parser. See https://myst-parser.readthedocs.io/en/latest/using/syntax-optional.html
|
||||
- amsmath
|
||||
- colon_fence
|
||||
# - deflist
|
||||
- dollarmath
|
||||
# - html_admonition
|
||||
# - html_image
|
||||
- linkify
|
||||
# - replacements
|
||||
# - smartquotes
|
||||
- substitution
|
||||
- tasklist
|
||||
myst_url_schemes: [ mailto, http, https ] # URI schemes that will be recognised as external URLs in Markdown links
|
||||
myst_dmath_double_inline: true # Allow display math ($$) within an inline context
|
||||
|
||||
#######################################################################################
|
||||
# HTML-specific settings
|
||||
html:
|
||||
favicon : "" # A path to a favicon image
|
||||
use_edit_page_button : false # Whether to add an "edit this page" button to pages. If `true`, repository information in repository: must be filled in
|
||||
use_repository_button : false # Whether to add a link to your repository button
|
||||
use_issues_button : false # Whether to add an "open an issue" button
|
||||
use_multitoc_numbering : true # Continuous numbering across parts/chapters
|
||||
extra_footer : "" # Will be displayed underneath the footer.
|
||||
google_analytics_id : "" # A GA id that can be used to track book views.
|
||||
home_page_in_navbar : true # Whether to include your home page in the left Navigation Bar
|
||||
baseurl : "" # The base URL where your book will be hosted. Used for creating image previews and social links. e.g.: https://mypage.com/mybook/
|
||||
analytics:
|
||||
|
||||
comments:
|
||||
hypothesis : false
|
||||
utterances : false
|
||||
announcement : "" # A banner announcement at the top of the site.
|
||||
|
||||
#######################################################################################
|
||||
# LaTeX-specific settings
|
||||
latex:
|
||||
latex_engine : pdflatex # one of 'pdflatex', 'xelatex' (recommended for unicode), 'luatex', 'platex', 'uplatex'
|
||||
use_jupyterbook_latex : true # use sphinx-jupyterbook-latex for pdf builds as default
|
||||
targetname : book.tex
|
||||
# Add a bibtex file so that we can create citations
|
||||
bibtex_bibfiles:
|
||||
- refs.bib
|
||||
|
||||
#######################################################################################
|
||||
# Launch button settings
|
||||
launch_buttons:
|
||||
notebook_interface : classic # The interface interactive links will activate ["classic", "jupyterlab"]
|
||||
binderhub_url : "" # The URL of the BinderHub (e.g., https://mybinder.org)
|
||||
jupyterhub_url : "" # The URL of the JupyterHub (e.g., https://datahub.berkeley.edu)
|
||||
thebe : false # Add a thebe button to pages (requires the repository to run on Binder)
|
||||
colab_url : "https://colab.research.google.com"
|
||||
|
||||
repository:
|
||||
url : https://github.com/carlocagnetta/tianshou.git # The URL to your book's repository
|
||||
path_to_book : ./docs/ # A path to your book's folder, relative to the repository root.
|
||||
branch : master # Which branch of the repository should be used when creating links
|
||||
|
||||
#######################################################################################
|
||||
# Advanced and power-user settings
|
||||
sphinx:
|
||||
extra_extensions :
|
||||
- sphinx.ext.autodoc
|
||||
- sphinx.ext.viewcode
|
||||
- sphinx_toolbox.more_autodoc.sourcelink
|
||||
- sphinxcontrib.spelling
|
||||
local_extensions : # A list of local extensions to load by sphinx specified by "name: path" items
|
||||
recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false)
|
||||
config : # key-value pairs to directly over-ride the Sphinx configuration
|
||||
autodoc_typehints_format: "short"
|
||||
autodoc_show_sourcelink: True
|
||||
add_module_names: False
|
||||
github_username: thu-ml
|
||||
github_repository: tianshou
|
||||
python_use_unqualified_type_names: True
|
||||
nb_mime_priority_overrides: [
|
||||
[ 'html', 'application/vnd.jupyter.widget-view+json', 10 ],
|
||||
[ 'html', 'application/javascript', 20 ],
|
||||
[ 'html', 'text/html', 30 ],
|
||||
[ 'html', 'text/latex', 40 ],
|
||||
[ 'html', 'image/svg+xml', 50 ],
|
||||
[ 'html', 'image/png', 60 ],
|
||||
[ 'html', 'image/jpeg', 70 ],
|
||||
[ 'html', 'text/markdown', 80 ],
|
||||
[ 'html', 'text/plain', 90 ],
|
||||
[ 'spelling', 'application/vnd.jupyter.widget-view+json', 10 ],
|
||||
[ 'spelling', 'application/javascript', 20 ],
|
||||
[ 'spelling', 'text/html', 30 ],
|
||||
[ 'spelling', 'text/latex', 40 ],
|
||||
[ 'spelling', 'image/svg+xml', 50 ],
|
||||
[ 'spelling', 'image/png', 60 ],
|
||||
[ 'spelling', 'image/jpeg', 70 ],
|
||||
[ 'spelling', 'text/markdown', 80 ],
|
||||
[ 'spelling', 'text/plain', 90 ],
|
||||
]
|
||||
mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
|
||||
mathjax3_config:
|
||||
loader: { load: [ '[tex]/configmacros' ] }
|
||||
tex:
|
||||
packages: { '[+]': [ 'configmacros' ] }
|
||||
macros:
|
||||
vect: ["{\\mathbf{\\boldsymbol{#1}} }", 1]
|
||||
E: "{\\mathbb{E}}"
|
||||
P: "{\\mathbb{P}}"
|
||||
R: "{\\mathbb{R}}"
|
||||
abs: ["{\\left| #1 \\right|}", 1]
|
||||
simpl: ["{\\Delta^{#1} }", 1]
|
||||
amax: "{\\text{argmax}}"
|
||||
BIN
docs/_static/images/action1.jpg
vendored
|
Before Width: | Height: | Size: 62 KiB After Width: | Height: | Size: 56 KiB |
BIN
docs/_static/images/action2.jpg
vendored
|
Before Width: | Height: | Size: 42 KiB After Width: | Height: | Size: 40 KiB |
BIN
docs/_static/images/action3.png
vendored
|
Before Width: | Height: | Size: 30 KiB After Width: | Height: | Size: 11 KiB |
BIN
docs/_static/images/aggregation.png
vendored
|
Before Width: | Height: | Size: 208 KiB After Width: | Height: | Size: 65 KiB |
BIN
docs/_static/images/async.png
vendored
|
Before Width: | Height: | Size: 55 KiB After Width: | Height: | Size: 20 KiB |
BIN
docs/_static/images/batch_reserve.png
vendored
|
Before Width: | Height: | Size: 74 KiB After Width: | Height: | Size: 23 KiB |
BIN
docs/_static/images/batch_tree.png
vendored
|
Before Width: | Height: | Size: 77 KiB After Width: | Height: | Size: 24 KiB |
BIN
docs/_static/images/concepts_arch.png
vendored
|
Before Width: | Height: | Size: 18 KiB After Width: | Height: | Size: 7.1 KiB |
BIN
docs/_static/images/concepts_arch2.png
vendored
|
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 10 KiB |
BIN
docs/_static/images/marl.png
vendored
|
Before Width: | Height: | Size: 73 KiB After Width: | Height: | Size: 25 KiB |
BIN
docs/_static/images/pipeline.png
vendored
|
Before Width: | Height: | Size: 107 KiB After Width: | Height: | Size: 37 KiB |
BIN
docs/_static/images/policy_table.svg
vendored
Normal file
|
After Width: | Height: | Size: 34 KiB |
BIN
docs/_static/images/pseudocode_off_policy.svg
vendored
Normal file
|
After Width: | Height: | Size: 95 KiB |
3
docs/_static/images/structure.svg
vendored
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
docs/_static/images/tianshou-logo.png
vendored
|
Before Width: | Height: | Size: 44 KiB After Width: | Height: | Size: 12 KiB |
BIN
docs/_static/images/tic-tac-toe.png
vendored
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 12 KiB |
3
docs/_static/images/timelimit.svg
vendored
Normal file
|
After Width: | Height: | Size: 47 KiB |
@ -1,142 +0,0 @@
|
||||
tianshou.data
|
||||
=============
|
||||
|
||||
|
||||
Batch
|
||||
-----
|
||||
|
||||
.. autoclass:: tianshou.data.Batch
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Buffer
|
||||
------
|
||||
|
||||
ReplayBuffer
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.ReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
PrioritizedReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.PrioritizedReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
HERReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.HERReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
ReplayBufferManager
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.ReplayBufferManager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
PrioritizedReplayBufferManager
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.PrioritizedReplayBufferManager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
HERReplayBufferManager
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.HERReplayBufferManager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
VectorReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.VectorReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
PrioritizedVectorReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.PrioritizedVectorReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
HERVectorReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.HERVectorReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
CachedReplayBuffer
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.CachedReplayBuffer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Collector
|
||||
---------
|
||||
|
||||
Collector
|
||||
~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.Collector
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
AsyncCollector
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.AsyncCollector
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Utils
|
||||
-----
|
||||
|
||||
to_numpy
|
||||
~~~~~~~~
|
||||
|
||||
.. autofunction:: tianshou.data.to_numpy
|
||||
|
||||
to_torch
|
||||
~~~~~~~~
|
||||
|
||||
.. autofunction:: tianshou.data.to_torch
|
||||
|
||||
to_torch_as
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: tianshou.data.to_torch_as
|
||||
|
||||
SegmentTree
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.data.SegmentTree
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,122 +0,0 @@
|
||||
tianshou.env
|
||||
============
|
||||
|
||||
|
||||
VectorEnv
|
||||
---------
|
||||
|
||||
BaseVectorEnv
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.BaseVectorEnv
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
DummyVectorEnv
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.DummyVectorEnv
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
SubprocVectorEnv
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.SubprocVectorEnv
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
ShmemVectorEnv
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.ShmemVectorEnv
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
RayVectorEnv
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.RayVectorEnv
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Wrapper
|
||||
-------
|
||||
|
||||
ContinuousToDiscrete
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.ContinuousToDiscrete
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
VectorEnvWrapper
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.VectorEnvWrapper
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
VectorEnvNormObs
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.VectorEnvNormObs
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Worker
|
||||
------
|
||||
|
||||
EnvWorker
|
||||
~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.worker.EnvWorker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
DummyEnvWorker
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.worker.DummyEnvWorker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
SubprocEnvWorker
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.worker.SubprocEnvWorker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
RayEnvWorker
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.worker.RayEnvWorker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Utils
|
||||
-----
|
||||
|
||||
PettingZooEnv
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.env.PettingZooEnv
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,7 +0,0 @@
|
||||
tianshou.exploration
|
||||
====================
|
||||
|
||||
.. automodule:: tianshou.exploration
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,176 +0,0 @@
|
||||
tianshou.policy
|
||||
===============
|
||||
|
||||
Base
|
||||
----
|
||||
|
||||
.. autoclass:: tianshou.policy.BasePolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.RandomPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Model-free
|
||||
----------
|
||||
|
||||
DQN Family
|
||||
~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.policy.DQNPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.BranchingDQNPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.C51Policy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.RainbowPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.QRDQNPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.IQNPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.FQFPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
On-policy
|
||||
~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.policy.PGPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.NPGPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.A2CPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.TRPOPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.PPOPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Off-policy
|
||||
~~~~~~~~~~
|
||||
|
||||
.. autoclass:: tianshou.policy.DDPGPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.TD3Policy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.SACPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.REDQPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteSACPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Imitation
|
||||
---------
|
||||
|
||||
.. autoclass:: tianshou.policy.ImitationPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.BCQPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.CQLPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.TD3BCPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteCQLPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteCRRPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.GAILPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Model-based
|
||||
-----------
|
||||
|
||||
.. autoclass:: tianshou.policy.PSRLPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.ICMPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Multi-agent
|
||||
-----------
|
||||
|
||||
.. autoclass:: tianshou.policy.MultiAgentPolicyManager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -1,36 +0,0 @@
|
||||
tianshou.trainer
|
||||
================
|
||||
|
||||
|
||||
On-policy
|
||||
---------
|
||||
|
||||
.. autoclass:: tianshou.trainer.OnpolicyTrainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Off-policy
|
||||
----------
|
||||
|
||||
.. autoclass:: tianshou.trainer.OffpolicyTrainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Offline
|
||||
-------
|
||||
|
||||
.. autoclass:: tianshou.trainer.OfflineTrainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
utils
|
||||
-----
|
||||
|
||||
.. autofunction:: tianshou.trainer.test_episode
|
||||
|
||||
.. autofunction:: tianshou.trainer.gather_info
|
||||
@ -1,35 +0,0 @@
|
||||
tianshou.utils
|
||||
==============
|
||||
|
||||
.. automodule:: tianshou.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Pre-defined Networks
|
||||
--------------------
|
||||
|
||||
Common
|
||||
~~~~~~
|
||||
|
||||
.. automodule:: tianshou.utils.net.common
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Discrete
|
||||
~~~~~~~~
|
||||
|
||||
.. automodule:: tianshou.utils.net.discrete
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Continuous
|
||||
~~~~~~~~~~
|
||||
|
||||
.. automodule:: tianshou.utils.net.continuous
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -10,7 +10,7 @@ def module_template(module_qualname: str):
|
||||
module_name = module_qualname.split(".")[-1]
|
||||
title = module_name.replace("_", r"\_")
|
||||
return f"""{title}
|
||||
{"="*len(title)}
|
||||
{"=" * len(title)}
|
||||
|
||||
.. automodule:: {module_qualname}
|
||||
:members:
|
||||
@ -18,37 +18,16 @@ def module_template(module_qualname: str):
|
||||
"""
|
||||
|
||||
|
||||
def package_template(package_qualname: str):
|
||||
package_name = package_qualname.split(".")[-1]
|
||||
title = package_name.replace("_", r"\_")
|
||||
return f"""{title}
|
||||
{"="*len(title)}
|
||||
def index_template(package_name: str, doc_references: list[str] | None = None, text_prefix=""):
|
||||
doc_references = doc_references or ""
|
||||
if doc_references:
|
||||
doc_references = "\n" + "\n".join(f"* :doc:`{ref}`" for ref in doc_references) + "\n"
|
||||
|
||||
.. automodule:: {package_qualname}
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
.. toctree::
|
||||
:glob:
|
||||
|
||||
{package_name}/*
|
||||
"""
|
||||
|
||||
|
||||
def indexTemplate(package_name):
|
||||
title = package_name
|
||||
return f"""{title}
|
||||
{"="*len(title)}
|
||||
|
||||
.. automodule:: {package_name}
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
.. toctree::
|
||||
:glob:
|
||||
|
||||
*
|
||||
"""
|
||||
dirname = package_name.split(".")[-1]
|
||||
title = dirname.replace("_", r"\_")
|
||||
if title == "tianshou":
|
||||
title = "Tianshou API Reference"
|
||||
return f"{title}\n{'=' * len(title)}" + text_prefix + doc_references
|
||||
|
||||
|
||||
def write_to_file(content: str, path: str):
|
||||
@ -58,8 +37,17 @@ def write_to_file(content: str, path: str):
|
||||
os.chmod(path, 0o666)
|
||||
|
||||
|
||||
_SUBTITLE = (
|
||||
"\n Here is the autogenerated documentation of the Tianshou API. \n \n "
|
||||
"The Table of Contents to the left has the same structure as the "
|
||||
"repository's package code. The links at each page point to the submodules and subpackages. \n\n "
|
||||
"Enjoy scrolling through! \n"
|
||||
)
|
||||
|
||||
|
||||
def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""):
|
||||
"""Creates/updates documentation in form of rst files for modules and packages.
|
||||
|
||||
Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed
|
||||
should be deleted by hand.
|
||||
|
||||
@ -79,7 +67,28 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
|
||||
shutil.rmtree(rst_root)
|
||||
|
||||
base_package_name = package_prefix + os.path.basename(src_root)
|
||||
write_to_file(indexTemplate(base_package_name), os.path.join(rst_root, "index.rst"))
|
||||
|
||||
# TODO: reduce duplication with same logic for subpackages below
|
||||
files_in_dir = os.listdir(src_root)
|
||||
module_names = [f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")]
|
||||
subdir_refs = [
|
||||
os.path.join(f, "index")
|
||||
for f in files_in_dir
|
||||
if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_")
|
||||
]
|
||||
package_index_rst_path = os.path.join(
|
||||
rst_root,
|
||||
"index.rst",
|
||||
)
|
||||
log.info(f"Writing {package_index_rst_path}")
|
||||
write_to_file(
|
||||
index_template(
|
||||
base_package_name,
|
||||
doc_references=module_names + subdir_refs,
|
||||
text_prefix=_SUBTITLE,
|
||||
),
|
||||
package_index_rst_path,
|
||||
)
|
||||
|
||||
for root, dirnames, filenames in os.walk(src_root):
|
||||
if os.path.basename(root).startswith("_"):
|
||||
@ -91,11 +100,33 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
|
||||
).replace(os.path.sep, ".")
|
||||
|
||||
for dirname in dirnames:
|
||||
if not dirname.startswith("_"):
|
||||
package_qualname = f"{base_package_qualname}.{dirname}"
|
||||
package_rst_path = os.path.join(rst_root, base_package_relpath, f"{dirname}.rst")
|
||||
log.info(f"Writing package documentation to {package_rst_path}")
|
||||
write_to_file(package_template(package_qualname), package_rst_path)
|
||||
if dirname.startswith("_"):
|
||||
log.debug(f"Skipping {dirname}")
|
||||
continue
|
||||
files_in_dir = os.listdir(os.path.join(root, dirname))
|
||||
module_names = [
|
||||
f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")
|
||||
]
|
||||
subdir_refs = [
|
||||
os.path.join(f, "index")
|
||||
for f in files_in_dir
|
||||
if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_")
|
||||
]
|
||||
if not module_names:
|
||||
log.debug(f"Skipping {dirname} as it does not contain any .py files")
|
||||
continue
|
||||
package_qualname = f"{base_package_qualname}.{dirname}"
|
||||
package_index_rst_path = os.path.join(
|
||||
rst_root,
|
||||
base_package_relpath,
|
||||
dirname,
|
||||
"index.rst",
|
||||
)
|
||||
log.info(f"Writing {package_index_rst_path}")
|
||||
write_to_file(
|
||||
index_template(package_qualname, doc_references=module_names + subdir_refs),
|
||||
package_index_rst_path,
|
||||
)
|
||||
|
||||
for filename in filenames:
|
||||
base_name, ext = os.path.splitext(filename)
|
||||
@ -114,8 +145,7 @@ if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
docs_root = Path(__file__).parent
|
||||
make_rst(
|
||||
docs_root / ".." / "tianshou" / "highlevel",
|
||||
docs_root / "api" / "tianshou.highlevel",
|
||||
docs_root / ".." / "tianshou",
|
||||
docs_root / "03_api",
|
||||
clean=True,
|
||||
package_prefix="tianshou.",
|
||||
)
|
||||
|
||||
100
docs/conf.py
@ -1,100 +0,0 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file only contains a selection of the most common options. For a full
|
||||
# list see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
# import os
|
||||
# import sys
|
||||
# sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
import sphinx_rtd_theme
|
||||
|
||||
import tianshou
|
||||
|
||||
# Get the version string
|
||||
version = tianshou.__version__
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "Tianshou"
|
||||
copyright = "2020, Tianshou contributors."
|
||||
author = "Tianshou contributors"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = version
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.doctest",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.coverage",
|
||||
# 'sphinx.ext.imgmath',
|
||||
"sphinx.ext.mathjax",
|
||||
"sphinx.ext.ifconfig",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.githubpages",
|
||||
"sphinxcontrib.bibtex",
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ["_templates"]
|
||||
source_suffix = [".rst"]
|
||||
master_doc = "index"
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
autodoc_default_options = {"special-members": "__len__, __call__, __getitem__, __setitem__"}
|
||||
autodoc_member_order = "bysource"
|
||||
bibtex_bibfiles = ["refs.bib"]
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ["_static"]
|
||||
|
||||
html_logo = "_static/images/tianshou-logo.png"
|
||||
|
||||
|
||||
def setup(app):
|
||||
app.add_js_file("https://cdn.jsdelivr.net/npm/vega@5.20.2")
|
||||
app.add_js_file("https://cdn.jsdelivr.net/npm/vega-lite@5.1.0")
|
||||
app.add_js_file("https://cdn.jsdelivr.net/npm/vega-embed@6.17.0")
|
||||
|
||||
app.add_js_file("js/copybutton.js")
|
||||
app.add_js_file("js/benchmark.js")
|
||||
app.add_css_file("css/style.css")
|
||||
|
||||
|
||||
# -- Extension configuration -------------------------------------------------
|
||||
|
||||
# -- Options for intersphinx extension ---------------------------------------
|
||||
|
||||
# Example configuration for intersphinx: refer to the Python standard library.
|
||||
# intersphinx_mapping = {'https://docs.python.org/3/': None}
|
||||
|
||||
# -- Options for todo extension ----------------------------------------------
|
||||
|
||||
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
||||
# todo_include_todos = False
|
||||
8
docs/create_toc.py
Normal file
@ -0,0 +1,8 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# This script provides a platform-independent way of making the jupyter-book call (used in pyproject.toml)
|
||||
toc_file = Path(__file__).parent / "_toc.yml"
|
||||
cmd = f"jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >{toc_file}"
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
@ -52,7 +52,7 @@ Here is Tianshou's other features:
|
||||
* Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
|
||||
* Support :ref:`customize_training`
|
||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
||||
* Support :doc:`/tutorials/tictactoe`
|
||||
* Support :doc:`/01_tutorials/04_tictactoe`
|
||||
* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
|
||||
* Support multi-GPU training :ref:`multi_gpu`
|
||||
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
|
||||
@ -63,7 +63,8 @@ Here is Tianshou's other features:
|
||||
Installation
|
||||
------------
|
||||
|
||||
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_ and `conda-forge <https://github.com/conda-forge/tianshou-feedstock>`_. It requires Python >= 3.11.
|
||||
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_ and `conda-forge <https://github.com/conda-forge/tianshou-feedstock>`_. New releases
|
||||
(and the current state of the master branch) will require Python >= 3.11.
|
||||
|
||||
You can simply install Tianshou from PyPI with the following command:
|
||||
|
||||
@ -93,42 +94,6 @@ If no error occurs, you have successfully installed Tianshou.
|
||||
|
||||
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Tutorials
|
||||
|
||||
tutorials/get_started
|
||||
tutorials/dqn
|
||||
tutorials/concepts
|
||||
tutorials/batch
|
||||
tutorials/tictactoe
|
||||
tutorials/logger
|
||||
tutorials/benchmark
|
||||
tutorials/cheatsheet
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: API Docs
|
||||
|
||||
api/tianshou.data
|
||||
api/tianshou.env
|
||||
api/tianshou.policy
|
||||
api/tianshou.trainer
|
||||
api/tianshou.exploration
|
||||
api/tianshou.utils
|
||||
api/tianshou.highlevel/index
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Community
|
||||
|
||||
contributing
|
||||
contributor
|
||||
|
||||
|
||||
Indices and tables
|
||||
------------------
|
||||
|
||||
|
||||
10
docs/nbstripout.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""Implements a platform-independent way of calling nbstripout (used in pyproject.toml)."""
|
||||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
if __name__ == "__main__":
|
||||
docs_dir = Path(__file__).parent
|
||||
for path in glob.glob(str(docs_dir / "02_notebooks" / "*.ipynb")):
|
||||
cmd = f"nbstripout {path}"
|
||||
os.system(cmd)
|
||||
@ -1,10 +0,0 @@
|
||||
numba
|
||||
numpy>=1.20
|
||||
sphinx<7
|
||||
sphinxcontrib-bibtex
|
||||
sphinx_rtd_theme>=0.5.1
|
||||
tensorboard
|
||||
torch
|
||||
tqdm
|
||||
protobuf
|
||||
pettingzoo
|
||||
@ -188,4 +188,54 @@ MLP
|
||||
backpropagation
|
||||
dataclass
|
||||
superset
|
||||
subtype
|
||||
subdirectory
|
||||
picklable
|
||||
ShmemVectorEnv
|
||||
Github
|
||||
wandb
|
||||
jupyter
|
||||
img
|
||||
src
|
||||
parallelized
|
||||
infty
|
||||
venv
|
||||
venvs
|
||||
subproc
|
||||
bcq
|
||||
highlevel
|
||||
icm
|
||||
modelbased
|
||||
td
|
||||
psrl
|
||||
ddpg
|
||||
npg
|
||||
tf
|
||||
trpo
|
||||
crr
|
||||
pettingzoo
|
||||
multidiscrete
|
||||
vecbuf
|
||||
prio
|
||||
colab
|
||||
segtree
|
||||
multiagent
|
||||
mapolicy
|
||||
sensai
|
||||
sensAI
|
||||
docstrings
|
||||
superclass
|
||||
iterable
|
||||
functools
|
||||
str
|
||||
sklearn
|
||||
attr
|
||||
bc
|
||||
redq
|
||||
modelfree
|
||||
bdq
|
||||
util
|
||||
logp
|
||||
autogenerated
|
||||
subpackage
|
||||
subpackages
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
Get Started with Jupyter Notebook
|
||||
=================================
|
||||
|
||||
In this tutorial, we will use Google Colaboratory to show you the most basic usages of common building blocks in Tianshou. You will be guided step by step to see how different modules in Tianshou collaborate with each other to conduct a classic DRL experiment (PPO algorithm for CartPole-v0 environment).
|
||||
|
||||
- L0: `Overview <https://colab.research.google.com/drive/1yavOkfSTbyBD24-dyQzdETFN9YA7ioor?usp=sharing>`_
|
||||
- L1: `Batch <https://colab.research.google.com/drive/1uklagjDxYjJERS9gJvgbPnV1BtMuXvOR?usp=sharing>`_
|
||||
- L2: `Replay Buffer <https://colab.research.google.com/drive/1sfw-dDy02Gado-WuYlHAQsyWhZ33D1bd?usp=sharing>`_
|
||||
- L3: `Vectorized Environment <https://colab.research.google.com/drive/1ABk2BgjzvC4DZu1rDxGzd2Uqjo3FRLEy?usp=sharing>`_
|
||||
- L4: `Policy <https://colab.research.google.com/drive/1MhzYXtUEfnRrlAVSB3SR83r0HA5wds2i?usp=sharing>`_
|
||||
- L5: `Collector <https://colab.research.google.com/drive/1CvOTPiNXdSST04I75Wuyvy_hZ949zKHZ?usp=sharing>`_
|
||||
- L6: `Trainer <https://colab.research.google.com/drive/1qMsEiZZ8mh60ycbfoX-nYy6qMCnLkmZE?usp=sharing>`_
|
||||
- L7: `Experiment <https://colab.research.google.com/drive/1CieGncgbGCt2grx8Mzwb7YTmFB0AGJ0f?usp=sharing>`_
|
||||
4066
poetry.lock
generated
@ -39,18 +39,18 @@ tensorboard = "^2.5.0"
|
||||
torch = "^2.0.0, !=2.0.1, !=2.1.0"
|
||||
tqdm = "*"
|
||||
virtualenv = [
|
||||
# special sauce b/c of a flaky bug in poetry on windows
|
||||
# see https://github.com/python-poetry/poetry/issues/7611#issuecomment-1466478926
|
||||
{ version = "^20.4.3,!=20.4.5,!=20.4.6" },
|
||||
{ version = "<20.16.4", markers = "sys_platform == 'win32'" },
|
||||
# special sauce b/c of a flaky bug in poetry on windows
|
||||
# see https://github.com/python-poetry/poetry/issues/7611#issuecomment-1466478926
|
||||
{ version = "^20.4.3,!=20.4.5,!=20.4.6" },
|
||||
{ version = "<20.16.4", markers = "sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
# TODO: add versions
|
||||
atari_py = {version = "*", optional = true}
|
||||
envpool = {version = "^0.8.2", optional = true}
|
||||
mujoco_py = {version = "*", optional = true}
|
||||
opencv_python = {version = "*", optional = true}
|
||||
pybullet = {version = "*", optional = true}
|
||||
atari_py = { version = "*", optional = true }
|
||||
envpool = { version = "^0.8.2", optional = true }
|
||||
mujoco_py = { version = "*", optional = true }
|
||||
opencv_python = { version = "*", optional = true }
|
||||
pybullet = { version = "*", optional = true }
|
||||
|
||||
[tool.poetry.extras]
|
||||
atari = ["atari_py", "opencv-python"]
|
||||
@ -62,11 +62,14 @@ envpool = ["envpool"]
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.7.0"
|
||||
black = { version = "^23.7.0", extras = ["jupyter"] }
|
||||
docstring-parser = "^0.15"
|
||||
jinja2 = "*"
|
||||
jsonargparse = "^4.24.1"
|
||||
jupyter = "^1.0.0"
|
||||
jupyter-book = "^0.15.1"
|
||||
mypy = "^1.4.1"
|
||||
nbstripout = "^0.6.1"
|
||||
# networkx is used in a test
|
||||
networkx = "*"
|
||||
poethepoet = "^0.20.0"
|
||||
@ -77,11 +80,16 @@ pytest = "*"
|
||||
pytest-cov = "*"
|
||||
# Ray currently causes issues when installed on windows server 2022 in CI
|
||||
# If users want to use ray, they should install it manually.
|
||||
ray = {version = "^2", markers = "sys_platform != 'win32'"}
|
||||
ray = { version = "^2", markers = "sys_platform != 'win32'" }
|
||||
ruff = "^0.0.285"
|
||||
scipy = "*"
|
||||
sphinx = "<7"
|
||||
sphinx_rtd_theme = "*"
|
||||
sphinx-book-theme = "^1.0.1"
|
||||
sphinx-comments = "^0.0.3"
|
||||
sphinx-copybutton = "^0.5.2"
|
||||
sphinx-jupyterbook-latex = "^0.5.2"
|
||||
sphinx-togglebutton = "^0.3.2"
|
||||
sphinx-toolbox = "^3.5.0"
|
||||
sphinxcontrib-bibtex = "*"
|
||||
sphinxcontrib-spelling = "^8.0.0"
|
||||
wandb = "^0.12.0"
|
||||
@ -104,7 +112,7 @@ warn_redundant_casts = true
|
||||
warn_unreachable = true
|
||||
warn_unused_configs = true
|
||||
warn_unused_ignores = true
|
||||
exclude = "^build/|^docs/|^tianshou/utils/(string|logging).py|^temp*.py"
|
||||
exclude = "^build/|^docs/"
|
||||
|
||||
[tool.doc8]
|
||||
max-line-length = 1000
|
||||
@ -119,31 +127,31 @@ select = [
|
||||
]
|
||||
ignore = [
|
||||
"SIM118", # Needed b/c iter(batch) != iter(batch.keys()). See https://github.com/thu-ml/tianshou/issues/922
|
||||
"E501", # line too long. black does a good enough job
|
||||
"E741", # variable names like "l". this isn't a huge problem
|
||||
"B008", # do not perform function calls in argument defaults. we do this sometimes
|
||||
"B011", # assert false. we don't use python -O
|
||||
"B028", # we don't need explicit stacklevel for warnings
|
||||
"D100", "D101", "D102", "D104", "D105", "D107", "D203", "D213", "D401", "D402", "D106", "D205", # docstring stuff
|
||||
"G004", # logging (no f-strings)
|
||||
"RUF012", # disallows mutable class variables unless annotated
|
||||
"DTZ005", # we don't need that
|
||||
"RET505", # sacrifices visual discernability of control flow paths for brevity (regarding return statements)
|
||||
"E501", # line too long. black does a good enough job
|
||||
"E741", # variable names like "l". this isn't a huge problem
|
||||
"B008", # do not perform function calls in argument defaults. we do this sometimes
|
||||
"B011", # assert false. we don't use python -O
|
||||
"B028", # we don't need explicit stacklevel for warnings
|
||||
"D100", "D101", "D102", "D104", "D105", "D107", "D203", "D213", "D401", "D402", # docstring stuff
|
||||
"DTZ005", # we don't need that
|
||||
# remaining rules from https://github.com/psf/black/blob/main/.flake8 (except W503)
|
||||
# this is a simplified version of config, making vscode plugin happy
|
||||
"E402", "E501", "E701", "E731", "C408", "E203"
|
||||
"E402", "E501", "E701", "E731", "C408", "E203",
|
||||
# Logging statement uses f-string warning
|
||||
"G004",
|
||||
# Unnecessary `elif` after `return` statement
|
||||
"RET505"
|
||||
]
|
||||
unfixable = [
|
||||
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
|
||||
"F601", # automatic fix might obscure issue
|
||||
"F602", # automatic fix might obscure issue
|
||||
"B018", # automatic fix might obscure issue
|
||||
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
|
||||
"F601", # automatic fix might obscure issue
|
||||
"F602", # automatic fix might obscure issue
|
||||
"B018", # automatic fix might obscure issue
|
||||
]
|
||||
extend-fixable = [
|
||||
"F401", # unused import
|
||||
"B905" , # bugbear
|
||||
"F401", # unused import
|
||||
"B905", # bugbear
|
||||
]
|
||||
ignore-init-module-imports = true # without this, "unused" imports in __init__ will be auto-removed, breaking imports
|
||||
|
||||
target-version = "py311"
|
||||
|
||||
@ -155,12 +163,13 @@ max-complexity = 20
|
||||
"docs/**" = ["D103"]
|
||||
"examples/**" = ["D103"]
|
||||
|
||||
|
||||
[tool.poetry_bumpversion.file."tianshou/__init__.py"]
|
||||
|
||||
[tool.poetry-sort]
|
||||
move-optionals-to-bottom = true
|
||||
|
||||
[tool.poe.env]
|
||||
PYDEVD_DISABLE_FILE_VALIDATION="1"
|
||||
# keep relevant parts in sync with pre-commit
|
||||
[tool.poe.tasks] # https://github.com/nat-n/poethepoet
|
||||
test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v --color=yes"
|
||||
@ -171,11 +180,15 @@ _black_format = "black ."
|
||||
_ruff_format = "ruff --fix ."
|
||||
lint = ["_black_check", "_ruff_check"]
|
||||
_poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
|
||||
_poery_sort = "poetry sort"
|
||||
format = ["_black_format", "_ruff_format", "_poetry_install_sort_plugin", "_poery_sort"]
|
||||
_poetry_sort = "poetry sort"
|
||||
clean-nbs = "python docs/nbstripout.py"
|
||||
format = ["_black_format", "_ruff_format", "_poetry_install_sort_plugin", "_poetry_sort"]
|
||||
_autogen_rst = "python docs/autogen_rst.py"
|
||||
_spellcheck = "sphinx-build -W -b spelling docs docs/_build"
|
||||
_doc_build = "sphinx-build -W -b html docs docs/_build"
|
||||
_sphinx_build = "sphinx-build -W -b html docs docs/_build"
|
||||
_jb_generate_toc = "python docs/create_toc.py"
|
||||
_jb_generate_config = "jupyter-book config sphinx docs/"
|
||||
doc-clean = "rm -rf docs/_build"
|
||||
doc-build = ["_autogen_rst", "_spellcheck", "_doc_build"]
|
||||
doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
|
||||
doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
|
||||
doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
|
||||
type-check = "mypy tianshou"
|
||||
|
||||
@ -95,10 +95,10 @@ def create_value(
|
||||
size: int,
|
||||
stack: bool = True,
|
||||
) -> Union["Batch", np.ndarray, torch.Tensor]:
|
||||
"""Create empty place-holders accroding to inst's shape.
|
||||
"""Create empty place-holders according to inst's shape.
|
||||
|
||||
:param stack: whether to stack or to concatenate. E.g. if inst has shape of
|
||||
(3, 5), size = 10, stack=True returns an np.ndarry with shape of (10, 3, 5),
|
||||
(3, 5), size = 10, stack=True returns an np.array with shape of (10, 3, 5),
|
||||
otherwise (10, 5)
|
||||
"""
|
||||
has_shape = isinstance(inst, np.ndarray | torch.Tensor)
|
||||
|
||||
@ -18,7 +18,7 @@ class ReplayBuffer:
|
||||
stores all the data in a batch with circular-queue style.
|
||||
|
||||
For the example usage of ReplayBuffer, please check out Section Buffer in
|
||||
:doc:`/tutorials/concepts`.
|
||||
:doc:`/01_tutorials/01_concepts`.
|
||||
|
||||
:param size: the maximum size of replay buffer.
|
||||
:param stack_num: the frame-stack sampling argument, should be greater than or
|
||||
|
||||
@ -68,7 +68,7 @@ class Collector:
|
||||
super().__init__()
|
||||
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
|
||||
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
|
||||
self.env = DummyVectorEnv([lambda: env]) # type: ignore
|
||||
self.env = DummyVectorEnv([lambda: env])
|
||||
else:
|
||||
self.env = env # type: ignore
|
||||
self.env_num = len(self.env)
|
||||
|
||||
2
tianshou/env/worker/subproc.py
vendored
@ -203,7 +203,7 @@ class SubprocEnvWorker(EnvWorker):
|
||||
obs = result[0]
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return (obs, *result[1:]) # type: ignore
|
||||
return (obs, *result[1:])
|
||||
obs = result
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
|
||||
@ -81,8 +81,7 @@ class Environments(ToStringMixin, ABC):
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
) -> "Environments":
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and
|
||||
the type of environment (continuous/discrete).
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||
|
||||
:param factory_fn: the factory for a single environment instance
|
||||
:param env_type: the type of environments created by `factory_fn`
|
||||
@ -115,8 +114,7 @@ class Environments(ToStringMixin, ABC):
|
||||
}
|
||||
|
||||
def set_persistence(self, *p: Persistence) -> None:
|
||||
"""Associates the given persistence handlers which may persist and restore
|
||||
environment-specific information.
|
||||
"""Associates the given persistence handlers which may persist and restore environment-specific information.
|
||||
|
||||
:param p: persistence handlers
|
||||
"""
|
||||
|
||||
@ -188,7 +188,9 @@ class Experiment(ToStringMixin):
|
||||
experiment_name: str | None = None,
|
||||
logger_run_id: str | None = None,
|
||||
) -> ExperimentResult:
|
||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
"""Run the experiment and return the results.
|
||||
|
||||
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
directory) where all results associated with the experiment will be saved.
|
||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||
a nested directory structure will be created.
|
||||
@ -327,6 +329,7 @@ class ExperimentBuilder:
|
||||
|
||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||
"""Allows to customize the logger factory to use.
|
||||
|
||||
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||
|
||||
:param logger_factory: the factory to use
|
||||
@ -346,6 +349,7 @@ class ExperimentBuilder:
|
||||
|
||||
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
||||
"""Allows to customize the gradient-based optimizer to use.
|
||||
|
||||
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
|
||||
|
||||
:param optim_factory: the optimizer factory
|
||||
@ -390,6 +394,7 @@ class ExperimentBuilder:
|
||||
|
||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||
"""Allows to define a callback that decides whether training shall stop early.
|
||||
|
||||
The callback receives the undiscounted returns of the testing result.
|
||||
|
||||
:param callback: the callback
|
||||
@ -435,6 +440,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
|
||||
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
||||
"""Allows to customize the actor component via the specification of a factory.
|
||||
|
||||
If this function is not called, a default actor factory (with default parameters) will be used.
|
||||
|
||||
:param actor_factory: the factory to use for the creation of the actor network
|
||||
@ -450,7 +456,9 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
continuous_unbounded: bool = False,
|
||||
continuous_conditioned_sigma: bool = False,
|
||||
) -> Self:
|
||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
"""Adds a default actor factory with the given parameters.
|
||||
|
||||
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||
shall be computed from the input; if False, sigma is an independent parameter.
|
||||
@ -479,9 +487,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
|
||||
|
||||
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs
|
||||
Gaussian distribution parameters.
|
||||
"""
|
||||
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ContinuousActorType.GAUSSIAN)
|
||||
@ -494,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||
continuous_conditioned_sigma: bool = False,
|
||||
) -> Self:
|
||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||
|
||||
The default actor factory uses an MLP-style architecture.
|
||||
|
||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||
@ -523,6 +530,7 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
||||
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||
) -> Self:
|
||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||
|
||||
The default actor factory uses an MLP-style architecture.
|
||||
|
||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||
@ -700,6 +708,7 @@ class _BuilderMixinCriticEnsembleFactory:
|
||||
|
||||
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for the critic ensemble.
|
||||
|
||||
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
||||
|
||||
:param factory: the critic ensemble factory
|
||||
|
||||
@ -19,7 +19,9 @@ class LoggerFactory(ToStringMixin, ABC):
|
||||
run_id: str | None,
|
||||
config_dict: dict,
|
||||
) -> TLogger:
|
||||
""":param log_dir: path to the directory in which log data is to be stored
|
||||
"""Creates the logger.
|
||||
|
||||
:param log_dir: path to the directory in which log data is to be stored
|
||||
:param experiment_name: the name of the job, which may contain `os.path.sep`
|
||||
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
|
||||
:param config_dict: a dictionary with data that is to be logged
|
||||
|
||||
@ -78,7 +78,7 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
|
||||
# do last policy layer scaling, this will make initial actions have (close to)
|
||||
# 0 mean and std, and will help boost performances,
|
||||
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
|
||||
for m in actor.mu.modules(): # type: ignore
|
||||
for m in actor.mu.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
@ -168,7 +168,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
||||
conditioned_sigma: bool = False,
|
||||
activation: ModuleType = nn.ReLU,
|
||||
):
|
||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
"""For actors with Gaussian policies.
|
||||
|
||||
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
:param unbounded: whether to apply tanh activation on final logits
|
||||
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
|
||||
input; if False, sigma is an independent parameter
|
||||
@ -229,9 +231,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
||||
|
||||
|
||||
class ActorFactoryTransientStorageDecorator(ActorFactory):
|
||||
"""Wraps an actor factory, storing the most recently created actor instance such that it
|
||||
can be retrieved.
|
||||
"""
|
||||
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
|
||||
|
||||
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
|
||||
self.actor_factory = actor_factory
|
||||
|
||||
@ -20,7 +20,9 @@ class OptimizerFactory(ABC, ToStringMixin):
|
||||
|
||||
class OptimizerFactoryTorch(OptimizerFactory):
|
||||
def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
|
||||
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
||||
"""Factory for torch optimizers.
|
||||
|
||||
:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
||||
which will be passed the module parameters, the learning rate as `lr` and the
|
||||
kwargs provided.
|
||||
:param kwargs: keyword arguments to provide at optimizer construction
|
||||
|
||||
@ -12,13 +12,12 @@ class NoiseFactory(ToStringMixin, ABC):
|
||||
|
||||
|
||||
class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
||||
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||
|
||||
This factory can only be applied to continuous action spaces.
|
||||
"""
|
||||
|
||||
def __init__(self, std_fraction: float):
|
||||
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||
|
||||
This factory can only be applied to continuous action spaces.
|
||||
|
||||
:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||
be used as the standard deviation
|
||||
"""
|
||||
self.std_fraction = std_fraction
|
||||
|
||||
@ -43,7 +43,9 @@ class ParamTransformerData:
|
||||
|
||||
|
||||
class ParamTransformer(ABC):
|
||||
"""Transforms one or more parameters from the representation used by the high-level API
|
||||
"""Base class for parameter transformations from high to low-level API.
|
||||
|
||||
Transforms one or more parameters from the representation used by the high-level API
|
||||
to the representation required by the (low-level) policy implementation.
|
||||
It operates directly on a dictionary of keyword arguments, which is initially
|
||||
generated from the parameter dataclass (subclass of `Params`).
|
||||
@ -83,7 +85,9 @@ class ParamTransformerChangeValue(ParamTransformer):
|
||||
|
||||
|
||||
class ParamTransformerLRScheduler(ParamTransformer):
|
||||
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||
"""Transformer for learning rate scheduler params.
|
||||
|
||||
Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||
a learning rate scheduler (added) for the data member `optim`.
|
||||
"""
|
||||
|
||||
@ -100,12 +104,12 @@ class ParamTransformerLRScheduler(ParamTransformer):
|
||||
|
||||
|
||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||
"""Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
|
||||
if more than one factory is indeed given.
|
||||
"""
|
||||
|
||||
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
|
||||
""":param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||
"""Transforms several scheduler factories into a single scheduler.
|
||||
|
||||
The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given.
|
||||
|
||||
:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||
:param key_scheduler: the key under which to store the resulting learning rate scheduler
|
||||
"""
|
||||
self.optim_key_list = optim_key_list
|
||||
|
||||
@ -57,9 +57,9 @@ class PersistenceGroup(Persistence):
|
||||
|
||||
|
||||
class PolicyPersistence:
|
||||
"""Handles persistence of the policy."""
|
||||
|
||||
class Mode(Enum):
|
||||
"""Mode of persistence."""
|
||||
|
||||
POLICY_STATE_DICT = "policy_state_dict"
|
||||
"""Persist only the policy's state dictionary. Note that for a policy to be restored from
|
||||
such a dictionary, it is necessary to first create a structurally equivalent object which can
|
||||
@ -81,7 +81,9 @@ class PolicyPersistence:
|
||||
enabled: bool = True,
|
||||
mode: Mode = Mode.POLICY,
|
||||
):
|
||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||
"""Handles persistence of the policy.
|
||||
|
||||
:param additional_persistence: a persistence instance which is to be invoked whenever
|
||||
this object is used to persist/restore data
|
||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||
:param mode: the persistence mode
|
||||
|
||||
@ -51,7 +51,9 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
""":param mean_rewards: the average undiscounted returns of the testing result
|
||||
"""Determines whether training should stop.
|
||||
|
||||
:param mean_rewards: the average undiscounted returns of the testing result
|
||||
:param context: the training context
|
||||
:return: True if the goal has been reached and training should stop, False otherwise
|
||||
"""
|
||||
|
||||
@ -159,8 +159,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
info: dict[str, Any] | None = None,
|
||||
state: dict | BatchProtocol | np.ndarray | None = None,
|
||||
) -> np.ndarray | int:
|
||||
"""Get action as int (for discrete env's) or array (for continuous ones) from
|
||||
an env's observation and info.
|
||||
"""Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info.
|
||||
|
||||
:param obs: observation from the gym's env.
|
||||
:param info: information given by the gym's env.
|
||||
|
||||
@ -95,7 +95,7 @@ class ICMPolicy(BasePolicy):
|
||||
def set_eps(self, eps: float) -> None:
|
||||
"""Set the eps for epsilon-greedy exploration."""
|
||||
if hasattr(self.policy, "set_eps"):
|
||||
self.policy.set_eps(eps) # type: ignore
|
||||
self.policy.set_eps(eps)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -78,11 +78,11 @@ class BranchingDQNPolicy(DQNPolicy):
|
||||
# but it collides with an attr of the same name in base class
|
||||
@property
|
||||
def _action_per_branch(self) -> int:
|
||||
return self.model.action_per_branch # type: ignore
|
||||
return self.model.action_per_branch
|
||||
|
||||
@property
|
||||
def num_branches(self) -> int:
|
||||
return self.model.num_branches # type: ignore
|
||||
return self.model.num_branches
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
obs_next_batch = Batch(
|
||||
|
||||
@ -78,7 +78,7 @@ class DDPGPolicy(BasePolicy):
|
||||
action_bound_method=action_bound_method,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0):
|
||||
warnings.warn(
|
||||
"action_scaling and action_bound_method are only intended to deal"
|
||||
"with unbounded model action space, but find actor model bound"
|
||||
|
||||
@ -77,7 +77,7 @@ class PGPolicy(BasePolicy):
|
||||
action_bound_method=action_bound_method,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0):
|
||||
warnings.warn(
|
||||
"action_scaling and action_bound_method are only intended"
|
||||
"to deal with unbounded model action space, but find actor model"
|
||||
|
||||
@ -9,9 +9,9 @@ from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from logging import *
|
||||
from typing import Any
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
log = getLogger(__name__)
|
||||
log = getLogger(__name__) # type: ignore
|
||||
|
||||
LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s"
|
||||
|
||||
@ -20,18 +20,18 @@ LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(m
|
||||
_logFormat = LOG_DEFAULT_FORMAT
|
||||
|
||||
|
||||
def remove_log_handlers():
|
||||
def remove_log_handlers() -> None:
|
||||
"""Removes all current log handlers."""
|
||||
logger = getLogger()
|
||||
while logger.hasHandlers():
|
||||
logger.removeHandler(logger.handlers[0])
|
||||
|
||||
|
||||
def remove_log_handler(handler):
|
||||
def remove_log_handler(handler: Handler) -> None:
|
||||
getLogger().removeHandler(handler)
|
||||
|
||||
|
||||
def is_log_handler_active(handler):
|
||||
def is_log_handler_active(handler: Handler) -> bool:
|
||||
"""Checks whether the given handler is active.
|
||||
|
||||
:param handler: a log handler
|
||||
@ -41,7 +41,7 @@ def is_log_handler_active(handler):
|
||||
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None:
|
||||
"""Configures logging to stdout with the given format and log level,
|
||||
also configuring the default log levels of some overly verbose libraries as well as some pandas output options.
|
||||
|
||||
@ -56,8 +56,13 @@ def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
getLogger("numba").setLevel(INFO)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
def run_main(
|
||||
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
|
||||
) -> T | None:
|
||||
"""Configures logging with the given parameters, ensuring that any exceptions that occur during
|
||||
the execution of the given function are logged.
|
||||
Logs two additional messages, one before the execution of the function, and one upon its completion.
|
||||
@ -68,16 +73,19 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB
|
||||
:return: the result of `main_fn`
|
||||
"""
|
||||
configure(format=format, level=level)
|
||||
log.info("Starting")
|
||||
log.info("Starting") # type: ignore
|
||||
try:
|
||||
result = main_fn()
|
||||
log.info("Done")
|
||||
log.info("Done") # type: ignore
|
||||
return result
|
||||
except Exception as e:
|
||||
log.error("Exception during script execution", exc_info=e)
|
||||
log.error("Exception during script execution", exc_info=e) # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
def run_cli(
|
||||
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
|
||||
) -> T | None:
|
||||
"""
|
||||
Configures logging with the given parameters and runs the given main function as a
|
||||
CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such
|
||||
@ -107,14 +115,14 @@ _isAtExitReportFileLoggerRegistered = False
|
||||
_memoryLogStream: StringIO | None = None
|
||||
|
||||
|
||||
def _at_exit_report_file_logger():
|
||||
def _at_exit_report_file_logger() -> None:
|
||||
for path in _fileLoggerPaths:
|
||||
print(f"A log file was saved to {path}")
|
||||
|
||||
|
||||
def add_file_logger(path, register_atexit=True):
|
||||
def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler:
|
||||
global _isAtExitReportFileLoggerRegistered
|
||||
log.info(f"Logging to {path} ...")
|
||||
log.info(f"Logging to {path} ...") # type: ignore
|
||||
handler = FileHandler(path)
|
||||
handler.setFormatter(Formatter(_logFormat))
|
||||
Logger.root.addHandler(handler)
|
||||
@ -138,21 +146,22 @@ def add_memory_logger() -> None:
|
||||
Logger.root.addHandler(handler)
|
||||
|
||||
|
||||
def get_memory_log():
|
||||
def get_memory_log() -> Any:
|
||||
""":return: the in-memory log (provided that `add_memory_logger` was called beforehand)"""
|
||||
assert _memoryLogStream is not None, "This should not have happened and might be a bug."
|
||||
return _memoryLogStream.getvalue()
|
||||
|
||||
|
||||
class FileLoggerContext:
|
||||
def __init__(self, path: str, enabled=True):
|
||||
def __init__(self, path: str, enabled: bool = True):
|
||||
self.enabled = enabled
|
||||
self.path = path
|
||||
self._log_handler = None
|
||||
self._log_handler: Handler | None = None
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
if self.enabled:
|
||||
self._log_handler = add_file_logger(self.path, register_atexit=False)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
if self._log_handler is not None:
|
||||
remove_log_handler(self._log_handler)
|
||||
|
||||
@ -370,13 +370,13 @@ class NoisyLinear(nn.Module):
|
||||
|
||||
# TODO: rename or change functionality? Usually sample is not an inplace operation...
|
||||
def sample(self) -> None:
|
||||
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
|
||||
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
|
||||
self.eps_p.copy_(self.f(self.eps_p))
|
||||
self.eps_q.copy_(self.f(self.eps_q))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.training:
|
||||
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) # type: ignore
|
||||
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
|
||||
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p))
|
||||
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone()
|
||||
else:
|
||||
weight = self.mu_W
|
||||
bias = self.mu_bias
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Copy of sensai.util.string from sensAI commit d7b4afcc89b4d2e922a816cb07dffde27f297354."""
|
||||
"""Copy of sensai.util.string from sensAI """
|
||||
# From commit commit d7b4afcc89b4d2e922a816cb07dffde27f297354
|
||||
|
||||
|
||||
import functools
|
||||
@ -10,22 +11,27 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Self,
|
||||
)
|
||||
|
||||
reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
|
||||
class StringConverter(ABC):
|
||||
"""Abstraction for a string conversion mechanism."""
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self, x) -> str:
|
||||
def to_string(self, x: Any) -> str:
|
||||
pass
|
||||
|
||||
|
||||
def dict_string(d: Mapping, brackets: str | None = None, converter: StringConverter = None):
|
||||
def dict_string(
|
||||
d: Mapping, brackets: str | None = None, converter: StringConverter | None = None
|
||||
) -> str:
|
||||
"""Converts a dictionary to a string of the form "<key>=<value>, <key>=<value>, ...", optionally enclosed
|
||||
by brackets.
|
||||
|
||||
@ -44,10 +50,10 @@ def dict_string(d: Mapping, brackets: str | None = None, converter: StringConver
|
||||
|
||||
def list_string(
|
||||
l: Iterable[Any],
|
||||
brackets="[]",
|
||||
brackets: str | None = "[]",
|
||||
quote: str | None = None,
|
||||
converter: StringConverter = None,
|
||||
):
|
||||
converter: StringConverter | None = None,
|
||||
) -> str:
|
||||
"""Converts a list or any other iterable to a string of the form "[<value>, <value>, ...]", optionally enclosed
|
||||
by different brackets or with the values quoted.
|
||||
|
||||
@ -59,7 +65,7 @@ def list_string(
|
||||
:return: the string representation
|
||||
"""
|
||||
|
||||
def item(x):
|
||||
def item(x: Any) -> str:
|
||||
x = to_string(x, converter=converter, context="list")
|
||||
if quote is not None:
|
||||
return quote + x + quote
|
||||
@ -74,11 +80,11 @@ def list_string(
|
||||
|
||||
|
||||
def to_string(
|
||||
x,
|
||||
converter: StringConverter = None,
|
||||
apply_converter_to_non_complex_objects=True,
|
||||
context=None,
|
||||
):
|
||||
x: Any,
|
||||
converter: StringConverter | None = None,
|
||||
apply_converter_to_non_complex_objects: bool = True,
|
||||
context: Any = None,
|
||||
) -> str:
|
||||
"""Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter.
|
||||
The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations).
|
||||
|
||||
@ -116,7 +122,7 @@ def to_string(
|
||||
raise
|
||||
|
||||
|
||||
def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
|
||||
def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str:
|
||||
"""Creates a string representation for the given object based on the given members.
|
||||
|
||||
The string takes the form "ClassName[attr1=value1, attr2=value2, ...]"
|
||||
@ -128,9 +134,9 @@ def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
|
||||
return f"{obj.__class__.__name__}[{dict_string(members_dict)}]"
|
||||
|
||||
|
||||
def or_regex_group(allowed_names: Sequence[str]):
|
||||
def or_regex_group(allowed_names: Sequence[str]) -> str:
|
||||
""":param allowed_names: strings to include as literals in the regex
|
||||
:return: a regular expression string of the form (<name1>| ...|<nameN>), which any of the given names
|
||||
:return: a regular expression string of the form `(<name_1>| ...|<name_N>)`, which any of the given names
|
||||
"""
|
||||
allowed_names = [re.escape(name) for name in allowed_names]
|
||||
return r"(%s)" % "|".join(allowed_names)
|
||||
@ -179,20 +185,11 @@ class ToStringMixin:
|
||||
In such cases, override :meth:`_toStringIncludesForced` to add inclusions regardless of the semantics otherwise used along
|
||||
the class hierarchy.
|
||||
|
||||
.. document private functions
|
||||
.. automethod:: _tostring_class_name
|
||||
.. automethod:: _tostring_object_info
|
||||
.. automethod:: _tostring_excludes
|
||||
.. automethod:: _tostring_exclude_exceptions
|
||||
.. automethod:: _tostring_includes
|
||||
.. automethod:: _tostring_includes_forced
|
||||
.. automethod:: _tostring_additional_entries
|
||||
.. automethod:: _tostring_exclude_private
|
||||
"""
|
||||
|
||||
_TOSTRING_INCLUDE_ALL = "__all__"
|
||||
|
||||
def _tostring_class_name(self):
|
||||
def _tostring_class_name(self) -> str:
|
||||
""":return: the string use for <class name> in the string representation ``"<class name>[<object info]"``"""
|
||||
return type(self).__qualname__
|
||||
|
||||
@ -203,7 +200,7 @@ class ToStringMixin:
|
||||
exclude_exceptions: list[str] | None = None,
|
||||
include_forced: list[str] | None = None,
|
||||
additional_entries: dict[str, Any] | None = None,
|
||||
converter: StringConverter = None,
|
||||
converter: StringConverter | None = None,
|
||||
) -> str:
|
||||
"""Creates a string of the class attributes, with optional exclusions/inclusions/additions.
|
||||
Exclusions take precedence over inclusions.
|
||||
@ -217,7 +214,7 @@ class ToStringMixin:
|
||||
:return: a string containing entry/property names and values
|
||||
"""
|
||||
|
||||
def mklist(x):
|
||||
def mklist(x: Any) -> list[str]:
|
||||
if x is None:
|
||||
return []
|
||||
if isinstance(x, str):
|
||||
@ -229,7 +226,7 @@ class ToStringMixin:
|
||||
include_forced = mklist(include_forced)
|
||||
exclude_exceptions = mklist(exclude_exceptions)
|
||||
|
||||
def is_excluded(k):
|
||||
def is_excluded(k: Any) -> bool:
|
||||
if k in include_forced or k in exclude_exceptions:
|
||||
return False
|
||||
if k in exclude:
|
||||
@ -336,17 +333,17 @@ class ToStringMixin:
|
||||
"""
|
||||
return []
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self._tostring_class_name()}[{self._tostring_object_info()}]"
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
info = f"id={id(self)}"
|
||||
property_info = self._tostring_object_info()
|
||||
if len(property_info) > 0:
|
||||
info += ", " + property_info
|
||||
return f"{self._tostring_class_name()}[{info}]"
|
||||
|
||||
def pprint(self, file=sys.stdout):
|
||||
def pprint(self, file: Any = sys.stdout) -> None:
|
||||
"""Prints a prettily formatted string representation of the object (with line breaks and indentations)
|
||||
to ``stdout`` or the given file.
|
||||
|
||||
@ -371,7 +368,7 @@ class ToStringMixin:
|
||||
""":param handled_objects: objects which are initially assumed to have been handled already"""
|
||||
self._handled_to_string_mixin_ids = {id(o) for o in handled_objects}
|
||||
|
||||
def to_string(self, x) -> str:
|
||||
def to_string(self, x: Any) -> str:
|
||||
if isinstance(x, ToStringMixin):
|
||||
oid = id(x)
|
||||
if oid in self._handled_to_string_mixin_ids:
|
||||
@ -396,17 +393,17 @@ class ToStringMixin:
|
||||
# methods where we assume that they could transitively call _toStringProperties (others are assumed not to)
|
||||
TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"}
|
||||
|
||||
def __init__(self, x: "ToStringMixin", converter):
|
||||
def __init__(self, x: "ToStringMixin", converter: Any) -> None:
|
||||
self.x = x
|
||||
self.converter = converter
|
||||
|
||||
def _tostring_properties(self, *args, **kwargs):
|
||||
return self.x._tostring_properties(*args, **kwargs, converter=self.converter)
|
||||
def _tostring_properties(self, *args: Any, **kwargs: Any) -> str:
|
||||
return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc]
|
||||
|
||||
def _tostring_class_name(self):
|
||||
def _tostring_class_name(self) -> str:
|
||||
return self.x._tostring_class_name()
|
||||
|
||||
def __getattr__(self, attr: str):
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr.startswith(
|
||||
"_tostring",
|
||||
): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call
|
||||
@ -420,11 +417,13 @@ class ToStringMixin:
|
||||
else:
|
||||
return getattr(self.x, attr)
|
||||
|
||||
def __str__(self: "ToStringMixin"):
|
||||
return ToStringMixin.__str__(self)
|
||||
def __str__(self) -> str:
|
||||
return ToStringMixin.__str__(self) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" "):
|
||||
def pretty_string_repr(
|
||||
s: Any, initial_indentation_level: int = 0, indentation_string: str = " "
|
||||
) -> str:
|
||||
"""Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via
|
||||
ToStringMixin). An indentation level is added for every opening bracket.
|
||||
|
||||
@ -439,16 +438,16 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
|
||||
result = indentation_string * indent
|
||||
i = 0
|
||||
|
||||
def nl():
|
||||
def nl() -> None:
|
||||
nonlocal result
|
||||
result += "\n" + (indentation_string * indent)
|
||||
|
||||
def take(cnt=1):
|
||||
def take(cnt: int = 1) -> None:
|
||||
nonlocal result, i
|
||||
result += s[i : i + cnt]
|
||||
i += cnt
|
||||
|
||||
def find_matching(j):
|
||||
def find_matching(j: int) -> int | None:
|
||||
start = j
|
||||
op = s[j]
|
||||
cl = {"[": "]", "(": ")", "'": "'"}[s[j]]
|
||||
@ -499,17 +498,18 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
|
||||
class TagBuilder:
|
||||
"""Assists in building strings made up of components that are joined via a glue string."""
|
||||
|
||||
def __init__(self, *initial_components: str, glue="_"):
|
||||
def __init__(self, *initial_components: str, glue: str = "_"):
|
||||
""":param initial_components: initial components to always include at the beginning
|
||||
:param glue: the glue string which joins components
|
||||
"""
|
||||
self.glue = glue
|
||||
self.components = list(initial_components)
|
||||
|
||||
def with_component(self, component: str):
|
||||
def with_component(self, component: str) -> Self:
|
||||
self.components.append(component)
|
||||
return self
|
||||
|
||||
def with_conditional(self, cond: bool, component: str):
|
||||
def with_conditional(self, cond: bool, component: str) -> Self:
|
||||
"""Conditionally adds the given component.
|
||||
|
||||
:param cond: the condition
|
||||
@ -520,7 +520,7 @@ class TagBuilder:
|
||||
self.components.append(component)
|
||||
return self
|
||||
|
||||
def with_alternative(self, cond: bool, true_component: str, false_component: str):
|
||||
def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self:
|
||||
"""Adds a component depending on a condition.
|
||||
|
||||
:param cond: the condition
|
||||
@ -531,6 +531,6 @@ class TagBuilder:
|
||||
self.components.append(true_component if cond else false_component)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
def build(self) -> str:
|
||||
""":return: the string (with all components joined)"""
|
||||
return self.glue.join(self.components)
|
||||
|
||||