Add notebooks from ./docs/tutorials/get_started.rst to ./notebooks
This commit is contained in:
parent
8d3d1f164b
commit
6b6ce0fdf1
359
notebooks/L0_overview.ipynb
Normal file
359
notebooks/L0_overview.ipynb
Normal file
@ -0,0 +1,359 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"language": "python",
|
||||
"display_name": "Python 3 (ipykernel)"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"In this toturial, we use guide you step by step to show you how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "r7aE6Rq3cAEE"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Run the code\n",
|
||||
"Before we get started, we must first install Tianshou's library and Gym environment by running the commands below. Here I choose a specific version of Tianshou(0.4.8) which is the latest as of the time writing this toturial. APIs in differet versions may vary a little bit but most are the same. Feel free to use other versions in your own project."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "1_mLTSEIcY2c"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"id": "qvplhjduVDs6",
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-10-12T15:51:01.680688825Z",
|
||||
"start_time": "2023-10-12T15:48:15.090023052Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Collecting tianshou==0.4.8\r\n",
|
||||
" Downloading tianshou-0.4.8-py3-none-any.whl (150 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m150.4/150.4 kB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting gym>=0.15.4 (from tianshou==0.4.8)\r\n",
|
||||
" Downloading gym-0.26.2.tar.gz (721 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m721.7/721.7 kB\u001B[0m \u001B[31m11.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n",
|
||||
"\u001B[?25h Installing build dependencies ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25h Getting requirements to build wheel ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25h Preparing metadata (pyproject.toml) ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25hRequirement already satisfied: tqdm in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (4.66.1)\r\n",
|
||||
"Requirement already satisfied: numpy>1.16.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (1.24.4)\r\n",
|
||||
"Requirement already satisfied: tensorboard>=2.5.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.14.1)\r\n",
|
||||
"Requirement already satisfied: torch>=1.4.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.1.0)\r\n",
|
||||
"Requirement already satisfied: numba>=0.51.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (0.57.1)\r\n",
|
||||
"Requirement already satisfied: h5py>=2.10.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (3.10.0)\r\n",
|
||||
"Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym>=0.15.4->tianshou==0.4.8) (2.2.1)\r\n",
|
||||
"Collecting gym-notices>=0.0.4 (from gym>=0.15.4->tianshou==0.4.8)\r\n",
|
||||
" Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n",
|
||||
"Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from numba>=0.51.0->tianshou==0.4.8) (0.40.1)\r\n",
|
||||
"Requirement already satisfied: absl-py>=0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.0.0)\r\n",
|
||||
"Requirement already satisfied: grpcio>=1.48.2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.59.0)\r\n",
|
||||
"Requirement already satisfied: google-auth<3,>=1.6.3 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.23.3)\r\n",
|
||||
"Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.0.0)\r\n",
|
||||
"Requirement already satisfied: markdown>=2.6.8 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.5)\r\n",
|
||||
"Requirement already satisfied: protobuf>=3.19.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.20.3)\r\n",
|
||||
"Requirement already satisfied: requests<3,>=2.21.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.31.0)\r\n",
|
||||
"Requirement already satisfied: setuptools>=41.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (68.2.2)\r\n",
|
||||
"Requirement already satisfied: six>1.9 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.16.0)\r\n",
|
||||
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (0.7.1)\r\n",
|
||||
"Requirement already satisfied: werkzeug>=1.0.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.0.0)\r\n",
|
||||
"Requirement already satisfied: filelock in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.12.4)\r\n",
|
||||
"Requirement already satisfied: typing-extensions in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (4.8.0)\r\n",
|
||||
"Requirement already satisfied: sympy in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (1.12)\r\n",
|
||||
"Requirement already satisfied: networkx in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1)\r\n",
|
||||
"Requirement already satisfied: jinja2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1.2)\r\n",
|
||||
"Requirement already satisfied: fsspec in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (2023.9.2)\r\n",
|
||||
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m23.7/23.7 MB\u001B[0m \u001B[31m17.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m823.6/823.6 kB\u001B[0m \u001B[31m20.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m14.1/14.1 MB\u001B[0m \u001B[31m14.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Obtaining dependency information for nvidia-cudnn-cu12==8.9.2.26 from https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata\r\n",
|
||||
" Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n",
|
||||
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m410.6/410.6 MB\u001B[0m \u001B[31m5.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m121.6/121.6 MB\u001B[0m \u001B[31m9.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m56.5/56.5 MB\u001B[0m \u001B[31m13.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m124.2/124.2 MB\u001B[0m \u001B[31m10.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m196.0/196.0 MB\u001B[0m \u001B[31m8.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m209.8/209.8 MB\u001B[0m \u001B[31m8.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.1/99.1 kB\u001B[0m \u001B[31m28.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting triton==2.1.0 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Obtaining dependency information for triton==2.1.0 from https://files.pythonhosted.org/packages/5c/c1/54fffb2eb13d293d9a429fead3646752ea190de0229bcf3d591ba2481263/triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata\r\n",
|
||||
" Downloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\r\n",
|
||||
"Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Obtaining dependency information for nvidia-nvjitlink-cu12 from https://files.pythonhosted.org/packages/0a/f8/5193b57555cbeecfdb6ade643df0d4218cc6385485492b6e2f64ceae53bb/nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata\r\n",
|
||||
" Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n",
|
||||
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (5.3.1)\r\n",
|
||||
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.3.0)\r\n",
|
||||
"Requirement already satisfied: rsa<5,>=3.1.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (4.9)\r\n",
|
||||
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (1.3.1)\r\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.3.0)\r\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.4)\r\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2.0.6)\r\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2023.7.22)\r\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.1.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from werkzeug>=1.0.1->tensorboard>=2.5.0->tianshou==0.4.8) (2.1.3)\r\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from sympy->torch>=1.4.0->tianshou==0.4.8) (1.3.0)\r\n",
|
||||
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.5.0)\r\n",
|
||||
"Requirement already satisfied: oauthlib>=3.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (3.2.2)\r\n",
|
||||
"Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m731.7/731.7 MB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m:00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hDownloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m89.2/89.2 MB\u001B[0m \u001B[31m13.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hDownloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m20.2/20.2 MB\u001B[0m \u001B[31m15.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hBuilding wheels for collected packages: gym\r\n",
|
||||
" Building wheel for gym (pyproject.toml) ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25h Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827621 sha256=612698033ee83c54db52d872001a111f5f0adf14dd996065edff561305ac2266\r\n",
|
||||
" Stored in directory: /home/ccagnetta/.cache/pip/wheels/1c/77/9e/9af5470201a0b0543937933ee99ba884cd237d2faefe8f4d37\r\n",
|
||||
"Successfully built gym\r\n",
|
||||
"Installing collected packages: gym-notices, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, gym, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, tianshou\r\n",
|
||||
" Attempting uninstall: triton\r\n",
|
||||
" Found existing installation: triton 2.0.0\r\n",
|
||||
" Uninstalling triton-2.0.0:\r\n",
|
||||
" Successfully uninstalled triton-2.0.0\r\n",
|
||||
" Attempting uninstall: tianshou\r\n",
|
||||
" Found existing installation: tianshou 0.5.1\r\n",
|
||||
" Uninstalling tianshou-0.5.1:\r\n",
|
||||
" Successfully uninstalled tianshou-0.5.1\r\n",
|
||||
"Successfully installed gym-0.26.2 gym-notices-0.0.8 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 tianshou-0.4.8 triton-2.1.0\r\n",
|
||||
"Requirement already satisfied: gym in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (0.26.2)\r\n",
|
||||
"Requirement already satisfied: numpy>=1.18.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (1.24.4)\r\n",
|
||||
"Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (2.2.1)\r\n",
|
||||
"Requirement already satisfied: gym-notices>=0.0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (0.0.8)\r\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v0\n",
|
||||
"problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n",
|
||||
"exactly what this tutorial is for.\n",
|
||||
"\n",
|
||||
"If the script ends normally, you will see the evaluation result printed out before the first\n",
|
||||
"epoch is done."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "IcFNmCjYeIIU"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PPOPolicy\n",
|
||||
"from tianshou.trainer import onpolicy_trainer\n",
|
||||
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
||||
"\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"\n",
|
||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||
"\n",
|
||||
"# environments\n",
|
||||
"env = gym.make('CartPole-v0')\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])\n",
|
||||
"\n",
|
||||
"# model & optimizer\n",
|
||||
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
||||
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
|
||||
"critic = Critic(net, device=device).to(device)\n",
|
||||
"actor_critic = ActorCritic(actor, critic)\n",
|
||||
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n",
|
||||
"\n",
|
||||
"# PPO policy\n",
|
||||
"dist = torch.distributions.Categorical\n",
|
||||
"policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# collector\n",
|
||||
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
|
||||
"test_collector = Collector(policy, test_envs)\n",
|
||||
"\n",
|
||||
"# trainer\n",
|
||||
"result = onpolicy_trainer(\n",
|
||||
" policy,\n",
|
||||
" train_collector,\n",
|
||||
" test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=50000,\n",
|
||||
" repeat_per_collect=10,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" batch_size=256,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
||||
")\n",
|
||||
"print(result)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "pxY_ZbGmkr6_",
|
||||
"outputId": "b792fc24-f42c-426a-9d83-fe1a4f3f91f1"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Epoch #1: 50001it [00:19, 2529.50it/s, env_step=50000, len=87, loss=80.895, loss/clip=-0.009, loss/ent=0.566, loss/vf=161.818, n/ep=15, n/st=2000, rew=87.27] \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n",
|
||||
"{'duration': '20.26s', 'train_time/model': '13.75s', 'test_step': 2159, 'test_episode': 20, 'test_time': '0.48s', 'test_speed': '4496.33 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 944, 'train_time/collector': '6.03s', 'train_speed': '2527.97 step/s'}\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let's watch its performance!\n",
|
||||
"policy.eval()\n",
|
||||
"result = test_collector.collect(n_episode=1, render=False)\n",
|
||||
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "G9YEQptYvCgx",
|
||||
"outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Final reward: 200.0, length: 200.0\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Tutorial Introduction\n",
|
||||
"\n",
|
||||
"A common DRL experiment as is shown above may require many components to work together. The agent, the\n",
|
||||
"environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n",
|
||||
"training task.\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\", width=500>\n",
|
||||
"\n",
|
||||
"</div>\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xFYlcPo8fpPU"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"In Tianshou, all of these main components are factored out as different building blocks, which you\n",
|
||||
"can use to create your own algorithm and finish your own experiment.\n",
|
||||
"\n",
|
||||
"Buiding blocks may include:\n",
|
||||
"- Batch\n",
|
||||
"- Replay Buffer\n",
|
||||
"- Vectorized Environment Wrapper\n",
|
||||
"- Policy (the agent and the training algorithm)\n",
|
||||
"- Data Collector\n",
|
||||
"- Trainer\n",
|
||||
"- Logger\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Check this [webpage](https://tianshou.readthedocs.io/en/master/tutorials/dqn.html) to find jupter-notebook-style tutorials that will guide you through all these\n",
|
||||
"modules one by one. You can also read the [documentation](https://tianshou.readthedocs.io/en/master/) of Tianshou for more detailed explanation and\n",
|
||||
"advanced usages."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "kV_uOyimj-bk"
|
||||
}
|
||||
},
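{
"cell_type": "markdown",
"source": [
"As a rough orientation (our own sketch, not an exhaustive mapping), the building blocks above correspond to the following classes used in the overview script. The `TensorboardLogger` import is included only for illustration; the script above does not use a logger."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# A hedged sketch mapping the building blocks to Tianshou classes (assuming the 0.4.8 API).\n",
"from tianshou.data import Batch, Collector, VectorReplayBuffer  # Batch, Data Collector, Replay Buffer\n",
"from tianshou.env import DummyVectorEnv                         # Vectorized Environment Wrapper\n",
"from tianshou.policy import PPOPolicy                           # Policy (agent + training algorithm)\n",
"from tianshou.trainer import onpolicy_trainer                   # Trainer\n",
"from tianshou.utils import TensorboardLogger                    # Logger (optional, not used above)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},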
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further reading"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "S0mNKwH9i6Ek"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## What if I am not familar with the PPO algorithm itself?\n",
|
||||
"As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n",
|
||||
"plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's toturials, we will\n",
|
||||
"focus on the usages of different modules, but not the algorithms themselves."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "M3NPSUnAov4L"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
716
notebooks/L1_Batch.ipynb
Normal file
716
notebooks/L1_Batch.ipynb
Normal file
@ -0,0 +1,716 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "sHZTxH6m2FpG"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"In this tutorial, we will introduce the **Batch** to you, which is the most basic data structure in Tianshou. You can simply considered Batch as a numpy version of python dictionary."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "69y6AHvq1S3f"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from tianshou.data import Batch\n",
|
||||
"data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))\n",
|
||||
"print(data)\n",
|
||||
"print(data.b)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "NkfiIe_y2FI-",
|
||||
"outputId": "5008275f-8f77-489a-af64-b35af4448589"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" a: array(4),\n",
|
||||
" b: array([5, 5]),\n",
|
||||
" c: '2312312',\n",
|
||||
" d: array(['a', '-2', '-3'], dtype=object),\n",
|
||||
")\n",
|
||||
"[5 5]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"A batch is simply a dictionary which stores all passed in data as key-value pairs, and automatically turns the value into a numpy array if possible.\n",
|
||||
"\n",
|
||||
"## Why we need Batch in Tianshou?\n",
|
||||
"The motivation behind the implementation of Batch module is simple. In DRL, you need to handle a lot of dictionary-format data. For instance most algorithms would reuqire you to store state, action, and reward data for every step when interacting with the environment. All these data can be organised as a dictionary and a Batch module helps Tianshou unify the interface of a diverse set of algorithms. Plus, Batch supports advanced indexing, concantenation and splitting, formatting print just like any other numpy array, which may be very helpful for developers.\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/concepts_arch.png\", title=\"Data flow is converted into a Batch in Tianshou\">\n",
|
||||
"\n",
|
||||
"<a> Data flow is converted into a Batch in Tianshou </a>\n",
|
||||
"</div>\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "S6e6OuXe3UT-"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Basic Usages"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_Xenx64M9HhV"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Initialisation\n",
|
||||
"Batch can be converted directly from a python dictionary, and all data structure will be converted to numpy array if possible."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "4YGX_f1Z9Uil"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# converted from a python library\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch1 = Batch({'a': [4, 4], 'b': (5, 5)})\n",
|
||||
"print(batch1)\n",
|
||||
"\n",
|
||||
"# initialisation of batch2 is equivalent to batch1\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch2 = Batch(a=[4, 4], b=(5, 5))\n",
|
||||
"print(batch2)\n",
|
||||
"\n",
|
||||
"# the dictionary can be nested, and it will be turned into a nested Batch\n",
|
||||
"print(\"========================================\")\n",
|
||||
"data = {\n",
|
||||
" 'action': np.array([1.0, 2.0, 3.0]),\n",
|
||||
" 'reward': 3.66,\n",
|
||||
" 'obs': {\n",
|
||||
" \"rgb_obs\": np.zeros((3, 3)),\n",
|
||||
" \"flatten_obs\": np.ones(5),\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"batch3 = Batch(data, extra=\"extra_string\")\n",
|
||||
"print(batch3)\n",
|
||||
"# batch3.obs is also a Batch\n",
|
||||
"print(type(batch3.obs))\n",
|
||||
"print(batch3.obs.rgb_obs)\n",
|
||||
"\n",
|
||||
"# a list of dictionary/Batch will automatically be concatenated/stacked, providing convenience if you\n",
|
||||
"# want to use parallelized environments to collect data.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch4 = Batch([data] * 3)\n",
|
||||
"print(batch4)\n",
|
||||
"print(batch4.obs.rgb_obs.shape)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Jl3-4BRbp3MM",
|
||||
"outputId": "a8b225f6-2893-4716-c694-3c2ff558b7f0"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" a: array([4, 4]),\n",
|
||||
" b: array([5, 5]),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" a: array([4, 4]),\n",
|
||||
" b: array([5, 5]),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" action: array([1., 2., 3.]),\n",
|
||||
" reward: array(3.66),\n",
|
||||
" obs: Batch(\n",
|
||||
" rgb_obs: array([[0., 0., 0.],\n",
|
||||
" [0., 0., 0.],\n",
|
||||
" [0., 0., 0.]]),\n",
|
||||
" flatten_obs: array([1., 1., 1., 1., 1.]),\n",
|
||||
" ),\n",
|
||||
" extra: 'extra_string',\n",
|
||||
")\n",
|
||||
"<class 'tianshou.data.batch.Batch'>\n",
|
||||
"[[0. 0. 0.]\n",
|
||||
" [0. 0. 0.]\n",
|
||||
" [0. 0. 0.]]\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" obs: Batch(\n",
|
||||
" rgb_obs: array([[[0., 0., 0.],\n",
|
||||
" [0., 0., 0.],\n",
|
||||
" [0., 0., 0.]],\n",
|
||||
" \n",
|
||||
" [[0., 0., 0.],\n",
|
||||
" [0., 0., 0.],\n",
|
||||
" [0., 0., 0.]],\n",
|
||||
" \n",
|
||||
" [[0., 0., 0.],\n",
|
||||
" [0., 0., 0.],\n",
|
||||
" [0., 0., 0.]]]),\n",
|
||||
" flatten_obs: array([[1., 1., 1., 1., 1.],\n",
|
||||
" [1., 1., 1., 1., 1.],\n",
|
||||
" [1., 1., 1., 1., 1.]]),\n",
|
||||
" ),\n",
|
||||
" reward: array([3.66, 3.66, 3.66]),\n",
|
||||
" action: array([[1., 2., 3.],\n",
|
||||
" [1., 2., 3.],\n",
|
||||
" [1., 2., 3.]]),\n",
|
||||
")\n",
|
||||
"(3, 3, 3)\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Getting access to data\n",
|
||||
"You can conveniently search or change the key-value pair in the Batch just as if it is a python dictionary."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "JCf6bqY3uf5L"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"batch1 = Batch({'a': [4, 4], 'b': (5, 5)})\n",
|
||||
"print(batch1)\n",
|
||||
"# add or delete key-value pair in batch1\n",
|
||||
"print(\"========================================\")\n",
|
||||
"batch1.c = Batch(c1=np.arange(3), c2=False)\n",
|
||||
"del batch1.a\n",
|
||||
"print(batch1)\n",
|
||||
"\n",
|
||||
"# access value by key\n",
|
||||
"print(\"========================================\")\n",
|
||||
"assert batch1[\"c\"] is batch1.c\n",
|
||||
"print(\"c\" in batch1)\n",
|
||||
"\n",
|
||||
"# traverse the Batch\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for key, value in batch1.items():\n",
|
||||
" print(str(key) + \": \" + str(value))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "2TNIY90-vU9b",
|
||||
"outputId": "de52ffe9-03c2-45f2-d95a-4071132daa4a"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" a: array([4, 4]),\n",
|
||||
" b: array([5, 5]),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" b: array([5, 5]),\n",
|
||||
" c: Batch(\n",
|
||||
" c1: array([0, 1, 2]),\n",
|
||||
" c2: array(False),\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"True\n",
|
||||
"========================================\n",
|
||||
"b: [5 5]\n",
|
||||
"c: Batch(\n",
|
||||
" c1: array([0, 1, 2]),\n",
|
||||
" c2: array(False),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Indexing and Slicing\n",
|
||||
"If all values in Batch share the same shape in certain dimensions. Batch can support advanced indexing and slicing just like a normal numpy array."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "bVywStbV9jD2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let us suppose we've got 4 environments, each returns a step of data\n",
|
||||
"step_datas = [\n",
|
||||
" {\n",
|
||||
" \"act\": np.random.randint(10),\n",
|
||||
" \"rew\": 0.0,\n",
|
||||
" \"obs\": np.ones((3, 3)),\n",
|
||||
" \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n",
|
||||
" } for _ in range(4)\n",
|
||||
" ]\n",
|
||||
"batch = Batch(step_datas)\n",
|
||||
"print(batch)\n",
|
||||
"print(batch.shape)\n",
|
||||
"\n",
|
||||
"# advanced indexing is supported, if we only want to select data in a given set of environments\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(batch[0])\n",
|
||||
"print(batch[[0,3]])\n",
|
||||
"\n",
|
||||
"# slicing is also supported\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(batch[-2:])\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "gKza3OJnzc_D",
|
||||
"outputId": "4f240bfe-4a69-4c1b-b40e-983c5c4d0cbc"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" obs: array([[[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]],\n",
|
||||
" \n",
|
||||
" [[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]],\n",
|
||||
" \n",
|
||||
" [[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]],\n",
|
||||
" \n",
|
||||
" [[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]]),\n",
|
||||
" rew: array([0., 0., 0., 0.]),\n",
|
||||
" info: Batch(\n",
|
||||
" done: array([0, 1, 1, 0]),\n",
|
||||
" failed: array([False, False, False, False]),\n",
|
||||
" ),\n",
|
||||
" act: array([0, 5, 1, 8]),\n",
|
||||
")\n",
|
||||
"[4]\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" obs: array([[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]),\n",
|
||||
" rew: 0.0,\n",
|
||||
" info: Batch(\n",
|
||||
" done: 0,\n",
|
||||
" failed: False,\n",
|
||||
" ),\n",
|
||||
" act: 0,\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" obs: array([[[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]],\n",
|
||||
" \n",
|
||||
" [[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]]),\n",
|
||||
" rew: array([0., 0.]),\n",
|
||||
" info: Batch(\n",
|
||||
" done: array([0, 0]),\n",
|
||||
" failed: array([False, False]),\n",
|
||||
" ),\n",
|
||||
" act: array([0, 8]),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" obs: array([[[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]],\n",
|
||||
" \n",
|
||||
" [[1., 1., 1.],\n",
|
||||
" [1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]]),\n",
|
||||
" rew: array([0., 0.]),\n",
|
||||
" info: Batch(\n",
|
||||
" done: array([1, 0]),\n",
|
||||
" failed: array([False, False]),\n",
|
||||
" ),\n",
|
||||
" act: array([1, 8]),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Aggregation and Splitting\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "C7N9kU_Q9jXm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Again, just like a numpy array. Play the example code below."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "1vUwQ-Hw9jtu"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# concat batches with compatible keys\n",
|
||||
"# try incompatible keys yourself if you feel curious\n",
|
||||
"print(\"========================================\")\n",
|
||||
"b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])\n",
|
||||
"b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])\n",
|
||||
"b12_cat_out = Batch.cat([b1, b2])\n",
|
||||
"print(b1)\n",
|
||||
"print(b2)\n",
|
||||
"print(b12_cat_out)\n",
|
||||
"\n",
|
||||
"# stack batches with compatible keys\n",
|
||||
"# try incompatible keys yourself if you feel curious\n",
|
||||
"print(\"========================================\")\n",
|
||||
"b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n",
|
||||
"b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n",
|
||||
"b34_stack = Batch.stack((b3, b4), axis=1)\n",
|
||||
"print(b3)\n",
|
||||
"print(b4)\n",
|
||||
"print(b34_stack)\n",
|
||||
"\n",
|
||||
"# split the batch into small batches of size 1, breaking the order of the data\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(type(b34_stack.split(1)))\n",
|
||||
"print(list(b34_stack.split(1, shuffle=True)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "f5UkReyn3_kb",
|
||||
"outputId": "e7bb3324-7f20-4810-a328-479117efca55"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" a: Batch(\n",
|
||||
" d: Batch(\n",
|
||||
" e: array([3.]),\n",
|
||||
" ),\n",
|
||||
" b: array([1.]),\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" a: Batch(\n",
|
||||
" d: Batch(\n",
|
||||
" e: array([6.]),\n",
|
||||
" ),\n",
|
||||
" b: array([4.]),\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" a: Batch(\n",
|
||||
" d: Batch(\n",
|
||||
" e: array([3., 6.]),\n",
|
||||
" ),\n",
|
||||
" b: array([1., 4.]),\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"Batch(\n",
|
||||
" a: array([[0., 0.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [0., 0.]]),\n",
|
||||
" b: array([[1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]),\n",
|
||||
" c: Batch(\n",
|
||||
" d: array([[1],\n",
|
||||
" [2]]),\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" a: array([[1., 1.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [1., 1.]]),\n",
|
||||
" b: array([[1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]),\n",
|
||||
" c: Batch(\n",
|
||||
" d: array([[0],\n",
|
||||
" [3]]),\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" c: Batch(\n",
|
||||
" d: array([[[1],\n",
|
||||
" [0]],\n",
|
||||
" \n",
|
||||
" [[2],\n",
|
||||
" [3]]]),\n",
|
||||
" ),\n",
|
||||
" a: array([[[0., 0.],\n",
|
||||
" [1., 1.]],\n",
|
||||
" \n",
|
||||
" [[0., 0.],\n",
|
||||
" [1., 1.]],\n",
|
||||
" \n",
|
||||
" [[0., 0.],\n",
|
||||
" [1., 1.]]]),\n",
|
||||
" b: array([[[1., 1., 1.],\n",
|
||||
" [1., 1., 1.]],\n",
|
||||
" \n",
|
||||
" [[1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]]),\n",
|
||||
")\n",
|
||||
"========================================\n",
|
||||
"<class 'generator'>\n",
|
||||
"[Batch(\n",
|
||||
" c: Batch(\n",
|
||||
" d: array([[[1],\n",
|
||||
" [0]]]),\n",
|
||||
" ),\n",
|
||||
" a: array([[[0., 0.],\n",
|
||||
" [1., 1.]]]),\n",
|
||||
" b: array([[[1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]]),\n",
|
||||
"), Batch(\n",
|
||||
" c: Batch(\n",
|
||||
" d: array([[[2],\n",
|
||||
" [3]]]),\n",
|
||||
" ),\n",
|
||||
" a: array([[[0., 0.],\n",
|
||||
" [1., 1.]]]),\n",
|
||||
" b: array([[[1., 1., 1.],\n",
|
||||
" [1., 1., 1.]]]),\n",
|
||||
")]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Data type converting\n",
|
||||
"Besides numpy array, Batch actually also supports Torch Tensor. The usages are exactly the same. Cool, isn't it?"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Smc_W1Cx6zRS"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"batch1 = Batch(a=np.arange(2), b=torch.zeros((2,2)))\n",
|
||||
"batch2 = Batch(a=np.arange(2), b=torch.ones((2,2)))\n",
|
||||
"batch_cat = Batch.cat([batch1, batch2, batch1])\n",
|
||||
"print(batch_cat)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Y6im_Mtb7Ody",
|
||||
"outputId": "898e82c4-b940-4c35-a0f9-dedc4a9bc500"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" a: array([0, 1, 0, 1, 0, 1]),\n",
|
||||
" b: tensor([[0., 0.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [0., 0.]]),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"You can convert the data type easily, if you no longer want to use hybrid data type anymore."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "1wfTUVKb6xki"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"batch_cat.to_numpy()\n",
|
||||
"print(batch_cat)\n",
|
||||
"batch_cat.to_torch()\n",
|
||||
"print(batch_cat)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "F7WknVs98DHD",
|
||||
"outputId": "cfd0712a-1df3-4208-e6cc-9149840bdc40"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" a: array([0, 1, 0, 1, 0, 1]),\n",
|
||||
" b: array([[0., 0.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [0., 0.]], dtype=float32),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" a: tensor([0, 1, 0, 1, 0, 1]),\n",
|
||||
" b: tensor([[0., 0.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [1., 1.],\n",
|
||||
" [0., 0.],\n",
|
||||
" [0., 0.]]),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Batch is even serializable, just in case you may need to save it to disk or restore it."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "NTFVle1-9Biz"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n",
|
||||
"batch_pk = pickle.loads(pickle.dumps(batch))\n",
|
||||
"print(batch_pk)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Lnf17OXv9YRb",
|
||||
"outputId": "753753f2-3f66-4d4b-b4ff-d57f9c40d1da"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" obs: Batch(\n",
|
||||
" a: array(0.),\n",
|
||||
" c: tensor([1., 2.]),\n",
|
||||
" ),\n",
|
||||
" np: array([[0., 0., 0., 0.],\n",
|
||||
" [0., 0., 0., 0.],\n",
|
||||
" [0., 0., 0., 0.]]),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "-vPMiPZ-9kJN"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Would like to learn more advanced usages of Batch? Feel curious about how data is organised inside the Batch? Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html) and other [tutorials](https://tianshou.readthedocs.io/en/master/tutorials/batch.html#) for more details."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "8Oc1p8ud9kcu"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
575
notebooks/L2_Buffer.ipynb
Normal file
575
notebooks/L2_Buffer.ipynb
Normal file
@ -0,0 +1,575 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4TCEkXj7LFe2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"Replay Buffer is a very common module in DRL implementations. In Tianshou, you can consider Buffer module as as a specialized form of Batch, which helps you track all data trajectories and provide utilities such as sampling method besides the basic storage.\n",
|
||||
"\n",
|
||||
"There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. The later one is specially designed for parallelized environments (will introduce in tutorial L3)."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xoPiGVD8LNma"
|
||||
}
|
||||
},
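{
"cell_type": "markdown",
"source": [
"The cell below is a minimal sketch of our own (not part of the original tutorial), assuming the tianshou 0.4.8 API: it constructs a plain `ReplayBuffer` for a single environment and a `VectorReplayBuffer` whose total capacity is shared across several parallel environments."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from tianshou.data import ReplayBuffer, VectorReplayBuffer\n",
"\n",
"single_buf = ReplayBuffer(size=10)                            # one circular queue of capacity 10\n",
"vector_buf = VectorReplayBuffer(total_size=12, buffer_num=4)  # 4 sub-buffers of capacity 3 each\n",
"print(single_buf.maxsize, vector_buf.maxsize)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},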
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Usages"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OdesCAxANehZ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Basic usages as a batch\n",
|
||||
"Usually a buffer stores all the data in a batch with circular-queue style."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fUbLl9T_SrTR"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.data import Batch, ReplayBuffer\n",
|
||||
"# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n",
|
||||
"print(\"========================================\")\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add 3 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add another 10 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 13):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mocZ6IqZTH62",
|
||||
"outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"========================================\n",
|
||||
"ReplayBuffer()\n",
|
||||
"maxsize: 10, data length: 0\n",
|
||||
"========================================\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, False, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),\n",
|
||||
" obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 3\n",
|
||||
"========================================\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, False, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n",
|
||||
" obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 10\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "H8B85Y5yUfTy"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf[-1])\n",
|
||||
"print(buf[-3:])\n",
|
||||
"# Try more methods you find useful in Batch yourself."
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "cOX-ADOPNeEK",
|
||||
"outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" obs: array(9),\n",
|
||||
" act: array(9),\n",
|
||||
" rew: array(9.),\n",
|
||||
" done: array(False),\n",
|
||||
" obs_next: array(10),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" obs: array([7, 8, 9]),\n",
|
||||
" act: array([7, 8, 9]),\n",
|
||||
" rew: array([7., 8., 9.]),\n",
|
||||
" done: array([False, False, False]),\n",
|
||||
" obs_next: array([ 8, 9, 10]),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "vqldap-2WQBh"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"_buf = pickle.loads(pickle.dumps(buf))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Ppx0L3niNT5K"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
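{
"cell_type": "markdown",
"source": [
"As a small illustration of the \"save to disk\" use case (our own sketch; the file name `buffer.pkl` is just an example), the same pickle mechanism works with an actual file:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pickle\n",
"\n",
"# write the buffer to disk and read it back; the trajectory bookkeeping is preserved\n",
"with open(\"buffer.pkl\", \"wb\") as f:\n",
"    pickle.dump(buf, f)\n",
"with open(\"buffer.pkl\", \"rb\") as f:\n",
"    buf_from_disk = pickle.load(f)\n",
"print(len(buf_from_disk))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},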
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Understanding reserved keys for buffer\n",
|
||||
"As I have explained, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. So, for convenience, we reserve certain seven reserved keys in Batch.\n",
|
||||
"\n",
|
||||
"* `obs`\n",
|
||||
"* `act`\n",
|
||||
"* `rew`\n",
|
||||
"* `done`\n",
|
||||
"* `obs_next`\n",
|
||||
"* `info`\n",
|
||||
"* `policy`\n",
|
||||
"\n",
|
||||
"The meaning of these seven reserved keys are consistent with the meaning in [OPENAI Gym](https://gym.openai.com/). We would recommend you simply use these seven keys when adding batched data into ReplayBuffer, because\n",
|
||||
"some of them are tracked in ReplayBuffer (e.g. \"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n",
|
||||
"buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n",
|
||||
"```\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Eqezp0OyXn6J"
|
||||
}
|
||||
},
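{
"cell_type": "markdown",
"source": [
"A tiny runnable check (our own illustration, not from the original text): extra information placed under `info` is stored alongside the seven reserved keys and can be read back from the buffer."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from tianshou.data import Batch, ReplayBuffer\n",
"\n",
"demo_buf = ReplayBuffer(size=3)\n",
"demo_buf.add(Batch(obs=0, act=0, rew=0.0, done=False, obs_next=1, info={\"extra_info\": 7}))\n",
"# the custom field lives inside the reserved `info` key\n",
"print(demo_buf.info.extra_info)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},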
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Data sampling\n",
|
||||
"We keep a replay buffer in DRL for one purpose:\"sample data from it for training\". `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fullfill this need."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ueAbTspsc6jo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"buf.sample(5)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "P5xnYOhrchDl",
|
||||
"outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(Batch(\n",
|
||||
" obs: array([10, 11, 4, 3, 8]),\n",
|
||||
" act: array([10, 11, 4, 3, 8]),\n",
|
||||
" rew: array([10., 11., 4., 3., 8.]),\n",
|
||||
" done: array([False, False, False, False, False]),\n",
|
||||
" obs_next: array([11, 12, 5, 4, 9]),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
" ), array([0, 1, 4, 3, 8]))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 5
|
||||
}
|
||||
]
|
||||
},
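{
"cell_type": "markdown",
"source": [
"Note that `sample()` returns both the sampled batch and the corresponding indices into the buffer. The small follow-up sketch below (our own, assuming the 0.4.8 API) shows that those indices can be used to index the buffer again:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"sampled_batch, sampled_indices = buf.sample(5)\n",
"# indexing the buffer with the returned indices yields the same transitions\n",
"print(sampled_indices)\n",
"print(buf[sampled_indices].obs)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},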
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Trajectory tracking\n",
|
||||
"Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n",
|
||||
"\n",
|
||||
"First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "IWyaOSKOcgK4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from numpy import False_\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"# Add the first trajectory (length is 3) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==2 else False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the second trajectory (length is 5) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 8):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==7 else False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(8, 13):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "H0qRb6HLfhLB",
|
||||
"outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"========================================\n",
|
||||
"(array([0]), array([0.]), array([0]), array([0]))\n",
|
||||
"(array([1]), array([0.]), array([0]), array([0]))\n",
|
||||
"(array([2]), array([3.]), array([3]), array([0]))\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" done: array([False, False, True, False, False, False, False, False, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),\n",
|
||||
" obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 3\n",
|
||||
"========================================\n",
|
||||
"(array([3]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([4]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([5]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([6]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([7]), array([25.]), array([5]), array([3]))\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),\n",
|
||||
" obs: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),\n",
|
||||
" done: array([False, False, True, False, False, False, False, True, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([0., 1., 2., 3., 4., 5., 6., 7., 0., 0.]),\n",
|
||||
" obs_next: array([1, 2, 3, 4, 5, 6, 7, 8, 0, 0]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 8\n",
|
||||
"========================================\n",
|
||||
"(array([8]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([9]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([0]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([1]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([2]), array([0.]), array([0]), array([8]))\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, True, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n",
|
||||
" obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 10\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### episode length and rewards tracking\n",
|
||||
"Notice that `ReplayBuffer.add()` returns a tuple of 4 numbers every time it returns, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "dO7PWdb_hkXA"
|
||||
}
|
||||
},
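A small sketch of unpacking that 4-tuple, using a throwaway buffer so the `buf` from above stays untouched (the values added are arbitrary):

```
from tianshou.data import Batch, ReplayBuffer

demo_buf = ReplayBuffer(size=5)
for i in range(3):
    idx, ep_rew, ep_len, ep_start = demo_buf.add(
        Batch(obs=i, act=i, rew=i, done=(i == 2), obs_next=i + 1, info={}))
# ep_rew and ep_len are non-zero only on the step that finishes the episode.
print(idx, ep_rew, ep_len, ep_start)
```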
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Episode index management\n",
|
||||
"In the ReplayBuffer above, we can get access to any data step by indexing.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xbVc90z8itH0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf)\n",
|
||||
"data = buf[6]\n",
|
||||
"print(data)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4mKwo54MjupY",
|
||||
"outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, True, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n",
|
||||
" obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" obs: array(6),\n",
|
||||
" act: array(6),\n",
|
||||
" rew: array(6.),\n",
|
||||
" done: array(False),\n",
|
||||
" obs_next: array(7),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Now we know that step \"6\" is not the start of an episode (it should be step 4, 4-7 is the second trajectory we add into the ReplayBuffer), we wonder what is the earliest index of the that episode.\n",
|
||||
"\n",
|
||||
"This may seem easy but actually it is not. We cannot simply look at the \"done\" flag now, because we can see that since the third-added trajectory is not finished yet, step \"4\" is surrounded by flag \"False\". There are many things to consider. Things could get more nasty if you are using more advanced ReplayBuffer like VectorReplayBuffer, because now the data is not stored in a simple circular-queue.\n",
|
||||
"\n",
|
||||
"Luckily, all ReplayBuffer instances help you identify step indexes through a unified API."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "p5Co_Fmzj8Sw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Search for the previous index of index \"6\"\n",
|
||||
"now_index = 6\n",
|
||||
"while True:\n",
|
||||
" prev_index = buf.prev(now_index)\n",
|
||||
" print(prev_index)\n",
|
||||
" if prev_index == now_index:\n",
|
||||
" break\n",
|
||||
" else:\n",
|
||||
" now_index = prev_index"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "DcJ0LEX6mxHg",
|
||||
"outputId": "7830f5fb-96d9-4298-d09b-24e64b2f633c"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"5\n",
|
||||
"4\n",
|
||||
"3\n",
|
||||
"3\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
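The loop above can be wrapped into a small helper. This is a hedged sketch; `episode_start` is not part of Tianshou's API:

```
def episode_start(buffer, index):
    """Walk back with buffer.prev() until the index stops changing."""
    while True:
        prev_index = buffer.prev(index)
        if prev_index == index:
            return index
        index = prev_index

print(episode_start(buf, 6))  # 3 for the buffer above
```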
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us indentify the last index of an episode regardless of which kind of ReplayBuffer we are using."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "4Wlb57V4lQyQ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# next step of indexes [4,5,6,7,8,9] are:\n",
|
||||
"print(buf.next([4,5,6,7,8,9]))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "zl5TRMo7oOy5",
|
||||
"outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"[5 6 7 7 9 0]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YJ9CcWZXoOXw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf.unfinished_index())"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Xkawk97NpItg",
|
||||
"outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"[2]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "8_lMr0j3pOmn"
|
||||
}
|
||||
},
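For instance, advancing sampled indices by n steps for an n-step return target can be done with repeated `next()` calls, which never cross an episode boundary. This is only an illustrative sketch, not Tianshou's actual implementation:

```
import numpy as np

indices = np.array([3, 4, 5])
n_step = 2
for _ in range(n_step):
    # next() maps the last step of an episode to itself, so the loop
    # automatically stops advancing at episode boundaries.
    indices = buf.next(indices)
print(indices)  # indices of the steps used for bootstrapping
```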
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading\n",
|
||||
"## Other Buffer Module\n",
|
||||
"\n",
|
||||
"* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n",
|
||||
"* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n",
|
||||
"* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n",
|
||||
"\n",
|
||||
"Check the documentation and the source code for more details.\n",
|
||||
"\n",
|
||||
"## Support for steps stacking to use RNN in DRL.\n",
|
||||
"There is an option called `stack_num` (default to 1) when initialising the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "FEyE0c7tNfwa"
|
||||
}
|
||||
}
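A short instantiation sketch for the variants mentioned above; the constructor arguments are illustrative, so check the documentation for the full signatures:

```
from tianshou.data import PrioritizedReplayBuffer, ReplayBuffer

# Prioritized experience replay: alpha/beta control how strongly priorities
# bias sampling and how importance weights correct for that bias.
prio_buf = PrioritizedReplayBuffer(size=1000, alpha=0.6, beta=0.4)

# stack_num > 1 makes every sampled observation a stack of the last N steps,
# which is handy when feeding an RNN.
stacked_buf = ReplayBuffer(size=1000, stack_num=4)
```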
|
||||
]
|
||||
}
|
||||
225
notebooks/L3_Vectorized__Environment.ipynb
Normal file
225
notebooks/L3_Vectorized__Environment.ipynb
Normal file
@ -0,0 +1,225 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0T7FYEnlBT6F"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments or their libraries in DRL research, Tianshou chooses to keep a consistent API with [OPENAI Gym](https://gym.openai.com/).\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/rl-loop.jpg\", title=\"The agents interacting with the environment\">\n",
|
||||
"\n",
|
||||
"<a> The agents interacting with the environment </a>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"In Gym, an environment receives an action and returns next observation and reward. This process is slow and sometimes can be the throughput bottleneck in a DRL experiment.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "W5V7z3fVX7_b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Tianshou provides vectorized environment wrapper for a Gym environment. This wrapper allows you to make use of multiple cpu cores in your server to accelerate the data sampling."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "A0NGWZ8adBwt"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.env import SubprocVectorEnv\n",
|
||||
"import numpy as np\n",
|
||||
"import gym\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"num_cpus = [1,2,5]\n",
|
||||
"for num_cpu in num_cpus:\n",
|
||||
" env = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(num_cpu)])\n",
|
||||
" env.reset()\n",
|
||||
" sampled_steps = 0\n",
|
||||
" time_start = time.time()\n",
|
||||
" while sampled_steps < 1000:\n",
|
||||
" act = np.random.choice(2, size=num_cpu)\n",
|
||||
" obs, rew, done, info = env.step(act)\n",
|
||||
" if np.sum(done):\n",
|
||||
" env.reset(np.where(done)[0])\n",
|
||||
" sampled_steps += num_cpu\n",
|
||||
" time_used = time.time() - time_start\n",
|
||||
" print(\"{}s used to sample 1000 steps if using {} cpus.\".format(time_used, num_cpu))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "67wKtkiNi3lb",
|
||||
"outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"0.30551695823669434s used to sample 1000 steps if using 1 cpus.\n",
|
||||
"0.2602052688598633s used to sample 1000 steps if using 2 cpus.\n",
|
||||
"0.15763545036315918s used to sample 1000 steps if using 5 cpus.\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"You may notice that the speed doesn't increase linearly when we add subprocess numbers. There are multiple reasons behind this. One reason is that synchronize exection causes straggler effect. One way to solve this would be to use asynchronous mode. We leave this for further reading if you feel interested.\n",
|
||||
"\n",
|
||||
"Note that SubprocVectorEnv should only be used when the environment exection is slow. In practice, DummyVectorEnv (or raw Gym environment) is actually more efficient for a simple environment like CartPole because now you avoid both straggler effect and the overhead of communication between subprocesses."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "S1b6vxp9nEUS"
|
||||
}
|
||||
},
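For reference, asynchronous mode is configured through the vectorized environment's constructor. This is a hedged sketch reusing the imports above, with arbitrary `wait_num`/`timeout` values:

```
async_envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(5)],
    wait_num=3,   # step() returns as soon as 3 environments have finished ...
    timeout=0.2,  # ... or after 0.2 seconds, whichever comes first
)
```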
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Usages\n",
|
||||
"## Initialisation\n",
|
||||
"Just pass in a list of functions which return the initialised environment upon called."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Z6yPxdqFp18j"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"# In Gym\n",
|
||||
"env = gym.make(\"CartPole-v0\")\n",
|
||||
"\n",
|
||||
"# In Tianshou\n",
|
||||
"def helper_function():\n",
|
||||
" env = gym.make(\"CartPole-v0\")\n",
|
||||
" # other operations such as env.seed(np.random.choice(10))\n",
|
||||
" return env\n",
|
||||
"\n",
|
||||
"envs = DummyVectorEnv([helper_function for _ in range(5)])\n",
|
||||
"print(envs)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ssLcrL_pq24-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## EnvPool supporting\n",
|
||||
"Besides integrated environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its Github page yourself."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "X7p8csjdrwIN"
|
||||
}
|
||||
},
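A minimal sketch of the EnvPool route, assuming the `envpool` package is installed (the gym-style factory shown here follows EnvPool's documentation):

```
import envpool

# Five CartPole instances managed by EnvPool's C++ backend.
pool_envs = envpool.make_gym("CartPole-v0", num_envs=5)
obs = pool_envs.reset()
```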
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Environment exection and resetting\n",
|
||||
"The only difference between Vectorized environments and standard Gym environments is that passed in actions and returned rewards/observations are also vectorized."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "kvIfqh0vqAR5"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# In Gym, env.reset() returns a single observation.\n",
|
||||
"print(\"In Gym, env.reset() returns a single observation.\")\n",
|
||||
"print(env.reset())\n",
|
||||
"\n",
|
||||
"# In Tianshou, envs.reset() returns stacked observations.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(\"In Tianshou, envs.reset() returns stacked observations.\")\n",
|
||||
"print(envs.reset())\n",
|
||||
"\n",
|
||||
"obs, rew, done, info = envs.step(np.random.choice(2, size=num_cpu))\n",
|
||||
"print(info)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BH1ZnPG6tkdD"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"If we only want to execute several environments. The `id` argument can be used."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qXroB7KluvP9"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(envs.step(np.random.choice(2, size=3), id=[0,3,1]))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ufvFViKTu8d_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
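The same `id` argument also works for resetting only a subset of the environments (a small sketch; the indices are arbitrary):

```
# Reset environments 0, 3 and 1 only; the others keep their current state.
print(envs.reset(id=[0, 3, 1]))
```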
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading\n",
|
||||
"## Other environment wrappers in Tianshou\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n",
|
||||
"* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n",
|
||||
"\n",
|
||||
"Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n",
|
||||
"\n",
|
||||
"## Difference between synchronous and asynchronous mode (How to choose?)\n",
|
||||
"Explanation can be found at the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fekHR1a6X_HB"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
954
notebooks/L4_Policy.ipynb
Normal file
954
notebooks/L4_Policy.ipynb
Normal file
File diff suppressed because one or more lines are too long
365
notebooks/L5_Collector.ipynb
Normal file
365
notebooks/L5_Collector.ipynb
Normal file
File diff suppressed because one or more lines are too long
379
notebooks/L6_Trainer.ipynb
Normal file
379
notebooks/L6_Trainer.ipynb
Normal file
File diff suppressed because one or more lines are too long
352
notebooks/L7_Experiment.ipynb
Normal file
352
notebooks/L7_Experiment.ipynb
Normal file
@ -0,0 +1,352 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_UaXOSRjDUF9"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Experiment\n",
|
||||
"To conduct this experiment, we need the following building blocks.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* Two vectorized environments, one for training and one for evaluation\n",
|
||||
"* A PPO agent\n",
|
||||
"* A replay buffer to store transition data\n",
|
||||
"* Two collectors to manage the data collecting process, one for training and one for evaluation\n",
|
||||
"* A trainer to manage the training loop\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\", width=500>\n",
|
||||
"\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"Let us do this step by step."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "2QRbCJvDHNAd"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Preparation\n",
|
||||
"Firstly, install Tianshou if you haven't installed it before."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "-Hh4E6i0Hj0I"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "w50BVwaRHg3N"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Import libraries we might need later."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "7E4EhiBeHxD5"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PPOPolicy\n",
|
||||
"from tianshou.trainer import onpolicy_trainer\n",
|
||||
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
||||
"\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"\n",
|
||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ao9gWJDiHgG-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Environment"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "QnRg5y7THRYw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YZERKCGtH8W1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Mpuj5PFnDKVS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('CartPole-v0')\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Policy\n",
|
||||
"Next we need to initialise our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n",
|
||||
"\n",
|
||||
"The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n",
|
||||
"\n",
|
||||
"Luckily, Tianshou already provides basic network modules that we can use in this experiment."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BJtt_Ya8DTAh"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# net is the shared head of the actor and the critic\n",
|
||||
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
||||
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
|
||||
"critic = Critic(net, device=device).to(device)\n",
|
||||
"actor_critic = ActorCritic(actor, critic)\n",
|
||||
"\n",
|
||||
"# optimizer of the actor and the critic\n",
|
||||
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_Vy8uPWXP4m_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Lh2-hwE5Dn9I"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"dist = torch.distributions.Categorical\n",
|
||||
"policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OiJ2GkT0Qnbr"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "okxfj6IEQ-r8"
|
||||
}
|
||||
},
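Concretely, this flag interacts with the module's train/eval mode switch, which the evaluation cell at the end of this notebook also uses (a short sketch):

```
policy.train()  # stochastic actions while collecting training data
policy.eval()   # with deterministic_eval=True, evaluation takes the most probable action
```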
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Collector\n",
|
||||
"We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "n5XAAbuBZarO"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
|
||||
"test_collector = Collector(policy, test_envs)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ezwz0qerZhQM"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ZaoPxOd2hm0b"
|
||||
}
|
||||
},
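For the numbers used above, the total size is split evenly across per-environment sub-buffers (a small sketch; `total_size` and `buffer_num` are the constructor's parameters):

```
buffer = VectorReplayBuffer(total_size=20000, buffer_num=len(train_envs))
# 20000 transitions in total, i.e. 1000 per sub-buffer for the 20 training envs.
```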
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Trainer\n",
|
||||
"Finally, we can use the trainer to help us set up the training loop."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qBoE9pLUiC-8"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"result = onpolicy_trainer(\n",
|
||||
" policy,\n",
|
||||
" train_collector,\n",
|
||||
" test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=50000,\n",
|
||||
" repeat_per_collect=10,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" batch_size=256,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "i45EDnpxQ8gu",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "b1666b88-0bfa-4340-868e-58611872d988"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Epoch #1: 50001it [00:13, 3601.81it/s, env_step=50000, len=143, loss=41.162, loss/clip=0.001, loss/ent=0.583, loss/vf=82.332, n/ep=12, n/st=2000, rew=143.08] \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Results\n",
|
||||
"Print the training result."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ckgINHE2iTFR"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "tJCPgmiyiaaX",
|
||||
"outputId": "40123ae3-3365-4782-9563-46c43812f10f"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"{'duration': '14.17s', 'train_time/model': '8.80s', 'test_step': 2094, 'test_episode': 20, 'test_time': '0.27s', 'test_speed': '7770.16 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 942, 'train_time/collector': '5.10s', 'train_speed': '3597.32 step/s'}\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also test our trained agent."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "A-MJ9avMibxN"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let's watch its performance!\n",
|
||||
"policy.eval()\n",
|
||||
"result = test_collector.collect(n_episode=1, render=False)\n",
|
||||
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mnMANFcciiAQ",
|
||||
"outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Final reward: 200.0, length: 200.0\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||