From efaadec6a10ee4c5d42d50a25c0a232bc1be026e Mon Sep 17 00:00:00 2001 From: carlocagnetta Date: Tue, 17 Oct 2023 13:59:37 +0200 Subject: [PATCH] Removed notebook outputs --- notebooks/L0_overview.ipynb | 157 +- notebooks/L1_Batch.ipynb | 1086 +++++------- notebooks/L2_Buffer.ipynb | 949 +++++------ notebooks/L3_Vectorized__Environment.ipynb | 434 +++-- notebooks/L4_Policy.ipynb | 1740 +++++++++----------- notebooks/L5_Collector.ipynb | 614 +++---- notebooks/L6_Trainer.ipynb | 598 +++---- notebooks/L7_Experiment.ipynb | 663 ++++---- 8 files changed, 2572 insertions(+), 3669 deletions(-) diff --git a/notebooks/L0_overview.ipynb b/notebooks/L0_overview.ipynb index c7d95a6..7a4a72d 100644 --- a/notebooks/L0_overview.ipynb +++ b/notebooks/L0_overview.ipynb @@ -38,134 +38,11 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { - "id": "qvplhjduVDs6", - "ExecuteTime": { - "end_time": "2023-10-12T15:51:01.680688825Z", - "start_time": "2023-10-12T15:48:15.090023052Z" - } + "id": "qvplhjduVDs6" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting tianshou==0.4.8\r\n", - " Downloading tianshou-0.4.8-py3-none-any.whl (150 kB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m150.4/150.4 kB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n", - "\u001B[?25hCollecting gym>=0.15.4 (from tianshou==0.4.8)\r\n", - " Downloading gym-0.26.2.tar.gz (721 kB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m721.7/721.7 kB\u001B[0m \u001B[31m11.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n", - "\u001B[?25h Installing build dependencies ... \u001B[?25ldone\r\n", - "\u001B[?25h Getting requirements to build wheel ... \u001B[?25ldone\r\n", - "\u001B[?25h Preparing metadata (pyproject.toml) ... 
\u001B[?25ldone\r\n", - "\u001B[?25hRequirement already satisfied: tqdm in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (4.66.1)\r\n", - "Requirement already satisfied: numpy>1.16.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (1.24.4)\r\n", - "Requirement already satisfied: tensorboard>=2.5.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.14.1)\r\n", - "Requirement already satisfied: torch>=1.4.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.1.0)\r\n", - "Requirement already satisfied: numba>=0.51.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (0.57.1)\r\n", - "Requirement already satisfied: h5py>=2.10.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (3.10.0)\r\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym>=0.15.4->tianshou==0.4.8) (2.2.1)\r\n", - "Collecting gym-notices>=0.0.4 (from gym>=0.15.4->tianshou==0.4.8)\r\n", - " Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n", - "Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from numba>=0.51.0->tianshou==0.4.8) (0.40.1)\r\n", - "Requirement already satisfied: absl-py>=0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.0.0)\r\n", - "Requirement already satisfied: grpcio>=1.48.2 in 
/home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.59.0)\r\n", - "Requirement already satisfied: google-auth<3,>=1.6.3 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.23.3)\r\n", - "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.0.0)\r\n", - "Requirement already satisfied: markdown>=2.6.8 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.5)\r\n", - "Requirement already satisfied: protobuf>=3.19.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.20.3)\r\n", - "Requirement already satisfied: requests<3,>=2.21.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.31.0)\r\n", - "Requirement already satisfied: setuptools>=41.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (68.2.2)\r\n", - "Requirement already satisfied: six>1.9 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.16.0)\r\n", - "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (0.7.1)\r\n", - "Requirement already satisfied: werkzeug>=1.0.1 in 
/home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.0.0)\r\n", - "Requirement already satisfied: filelock in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.12.4)\r\n", - "Requirement already satisfied: typing-extensions in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (4.8.0)\r\n", - "Requirement already satisfied: sympy in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (1.12)\r\n", - "Requirement already satisfied: networkx in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1)\r\n", - "Requirement already satisfied: jinja2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1.2)\r\n", - "Requirement already satisfied: fsspec in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (2023.9.2)\r\n", - "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m23.7/23.7 MB\u001B[0m \u001B[31m17.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m 
\u001B[32m823.6/823.6 kB\u001B[0m \u001B[31m20.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m14.1/14.1 MB\u001B[0m \u001B[31m14.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Obtaining dependency information for nvidia-cudnn-cu12==8.9.2.26 from https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata\r\n", - " Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n", - "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m410.6/410.6 MB\u001B[0m \u001B[31m5.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m121.6/121.6 MB\u001B[0m \u001B[31m9.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m 
\u001B[32m56.5/56.5 MB\u001B[0m \u001B[31m13.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m124.2/124.2 MB\u001B[0m \u001B[31m10.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m196.0/196.0 MB\u001B[0m \u001B[31m8.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m209.8/209.8 MB\u001B[0m \u001B[31m8.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.1/99.1 kB\u001B[0m \u001B[31m28.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\r\n", - "\u001B[?25hCollecting triton==2.1.0 (from torch>=1.4.0->tianshou==0.4.8)\r\n", - " Obtaining dependency information for triton==2.1.0 from https://files.pythonhosted.org/packages/5c/c1/54fffb2eb13d293d9a429fead3646752ea190de0229bcf3d591ba2481263/triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata\r\n", - " 
Downloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\r\n", - "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->tianshou==0.4.8)\r\n", - " Obtaining dependency information for nvidia-nvjitlink-cu12 from https://files.pythonhosted.org/packages/0a/f8/5193b57555cbeecfdb6ade643df0d4218cc6385485492b6e2f64ceae53bb/nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata\r\n", - " Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (5.3.1)\r\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.3.0)\r\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (4.9)\r\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (1.3.1)\r\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.3.0)\r\n", - "Requirement already satisfied: idna<4,>=2.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.4)\r\n", - 
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2.0.6)\r\n", - "Requirement already satisfied: certifi>=2017.4.17 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2023.7.22)\r\n", - "Requirement already satisfied: MarkupSafe>=2.1.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from werkzeug>=1.0.1->tensorboard>=2.5.0->tianshou==0.4.8) (2.1.3)\r\n", - "Requirement already satisfied: mpmath>=0.19 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from sympy->torch>=1.4.0->tianshou==0.4.8) (1.3.0)\r\n", - "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.5.0)\r\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (3.2.2)\r\n", - "Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m731.7/731.7 MB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m:00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hDownloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m89.2/89.2 MB\u001B[0m \u001B[31m13.3 MB/s\u001B[0m eta 
\u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hDownloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\r\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m20.2/20.2 MB\u001B[0m \u001B[31m15.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n", - "\u001B[?25hBuilding wheels for collected packages: gym\r\n", - " Building wheel for gym (pyproject.toml) ... \u001B[?25ldone\r\n", - "\u001B[?25h Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827621 sha256=612698033ee83c54db52d872001a111f5f0adf14dd996065edff561305ac2266\r\n", - " Stored in directory: /home/ccagnetta/.cache/pip/wheels/1c/77/9e/9af5470201a0b0543937933ee99ba884cd237d2faefe8f4d37\r\n", - "Successfully built gym\r\n", - "Installing collected packages: gym-notices, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, gym, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, tianshou\r\n", - " Attempting uninstall: triton\r\n", - " Found existing installation: triton 2.0.0\r\n", - " Uninstalling triton-2.0.0:\r\n", - " Successfully uninstalled triton-2.0.0\r\n", - " Attempting uninstall: tianshou\r\n", - " Found existing installation: tianshou 0.5.1\r\n", - " Uninstalling tianshou-0.5.1:\r\n", - " Successfully uninstalled tianshou-0.5.1\r\n", - "Successfully installed gym-0.26.2 gym-notices-0.0.8 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 tianshou-0.4.8 triton-2.1.0\r\n", - "Requirement already satisfied: gym in 
/home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (0.26.2)\r\n", - "Requirement already satisfied: numpy>=1.18.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (1.24.4)\r\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (2.2.1)\r\n", - "Requirement already satisfied: gym-notices>=0.0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (0.0.8)\r\n" - ] - } - ], + "outputs": [], "source": [ "!pip install tianshou==0.4.8\n", "!pip install gym" @@ -248,23 +125,7 @@ "outputId": "b792fc24-f42c-426a-9d83-fe1a4f3f91f1" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #1: 50001it [00:19, 2529.50it/s, env_step=50000, len=87, loss=80.895, loss/clip=-0.009, loss/ent=0.566, loss/vf=161.818, n/ep=15, n/st=2000, rew=87.27] \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n", - "{'duration': '20.26s', 'train_time/model': '13.75s', 'test_step': 2159, 'test_episode': 20, 'test_time': '0.48s', 'test_speed': '4496.33 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 944, 'train_time/collector': '6.03s', 'train_speed': '2527.97 step/s'}\n" - ] - } - ] + "outputs": [] }, { "cell_type": "code", @@ -282,15 +143,7 @@ "outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Final reward: 200.0, length: 200.0\n" - ] - } - ] + "outputs": [] }, { "cell_type": "markdown", diff --git a/notebooks/L1_Batch.ipynb b/notebooks/L1_Batch.ipynb index dfafa57..b297a67 100644 --- 
a/notebooks/L1_Batch.ipynb +++ b/notebooks/L1_Batch.ipynb @@ -1,716 +1,390 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] }, - "cells": [ - { - "cell_type": "code", - "source": [ - "# Remember to install tianshou first\n", - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ], - "metadata": { - "id": "sHZTxH6m2FpG" - }, - "execution_count": null, - "outputs": [] + "kernelspec": { + "name": "python3", + "language": "python", + "display_name": "Python 3 (ipykernel)" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "# Remember to install tianshou first\n", + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ], + "metadata": { + "id": "sHZTxH6m2FpG" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "In this tutorial, we will introduce the **Batch** to you, which is the most basic data structure in Tianshou. You can simply considered Batch as a numpy version of python dictionary." + ], + "metadata": { + "id": "69y6AHvq1S3f" + } + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "from tianshou.data import Batch\n", + "data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))\n", + "print(data)\n", + "print(data.b)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "In this tutorial, we will introduce the **Batch** to you, which is the most basic data structure in Tianshou. You can simply considered Batch as a numpy version of python dictionary." 
- ], - "metadata": { - "id": "69y6AHvq1S3f" - } + "id": "NkfiIe_y2FI-", + "outputId": "5008275f-8f77-489a-af64-b35af4448589" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "A batch is simply a dictionary which stores all passed in data as key-value pairs, and automatically turns the value into a numpy array if possible.\n", + "\n", + "## Why we need Batch in Tianshou?\n", + "The motivation behind the implementation of Batch module is simple. In DRL, you need to handle a lot of dictionary-format data. For instance, most algorithms would require you to store state, action, and reward data for every step when interacting with the environment. All these data can be organised as a dictionary and a Batch module helps Tianshou unify the interface of a diverse set of algorithms. Plus, Batch supports advanced indexing, concatenation and splitting, and formatted printing just like any other numpy array, which may be very helpful for developers.\n", + "
\n", + "\n", + "\n", + " Data flow is converted into a Batch in Tianshou \n", + "
\n", + "\n" + ], + "metadata": { + "id": "S6e6OuXe3UT-" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Basic Usages" + ], + "metadata": { + "id": "_Xenx64M9HhV" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Initialisation\n", + "Batch can be converted directly from a python dictionary, and all data structure will be converted to numpy array if possible." + ], + "metadata": { + "id": "4YGX_f1Z9Uil" + } + }, + { + "cell_type": "code", + "source": [ + "# converted from a python library\n", + "print(\"========================================\")\n", + "batch1 = Batch({'a': [4, 4], 'b': (5, 5)})\n", + "print(batch1)\n", + "\n", + "# initialisation of batch2 is equivalent to batch1\n", + "print(\"========================================\")\n", + "batch2 = Batch(a=[4, 4], b=(5, 5))\n", + "print(batch2)\n", + "\n", + "# the dictionary can be nested, and it will be turned into a nested Batch\n", + "print(\"========================================\")\n", + "data = {\n", + " 'action': np.array([1.0, 2.0, 3.0]),\n", + " 'reward': 3.66,\n", + " 'obs': {\n", + " \"rgb_obs\": np.zeros((3, 3)),\n", + " \"flatten_obs\": np.ones(5),\n", + " },\n", + " }\n", + "\n", + "\n", + "batch3 = Batch(data, extra=\"extra_string\")\n", + "print(batch3)\n", + "# batch3.obs is also a Batch\n", + "print(type(batch3.obs))\n", + "print(batch3.obs.rgb_obs)\n", + "\n", + "# a list of dictionary/Batch will automatically be concatenated/stacked, providing convenience if you\n", + "# want to use parallelized environments to collect data.\n", + "print(\"========================================\")\n", + "batch4 = Batch([data] * 3)\n", + "print(batch4)\n", + "print(batch4.obs.rgb_obs.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "import numpy as np\n", - "from tianshou.data import Batch\n", - "data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))\n", - "print(data)\n", - "print(data.b)" - ], - 
"metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "NkfiIe_y2FI-", - "outputId": "5008275f-8f77-489a-af64-b35af4448589" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " a: array(4),\n", - " b: array([5, 5]),\n", - " c: '2312312',\n", - " d: array(['a', '-2', '-3'], dtype=object),\n", - ")\n", - "[5 5]\n" - ] - } - ] + "id": "Jl3-4BRbp3MM", + "outputId": "a8b225f6-2893-4716-c694-3c2ff558b7f0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Getting access to data\n", + "You can conveniently search or change the key-value pair in the Batch just as if it is a python dictionary." + ], + "metadata": { + "id": "JCf6bqY3uf5L" + } + }, + { + "cell_type": "code", + "source": [ + "batch1 = Batch({'a': [4, 4], 'b': (5, 5)})\n", + "print(batch1)\n", + "# add or delete key-value pair in batch1\n", + "print(\"========================================\")\n", + "batch1.c = Batch(c1=np.arange(3), c2=False)\n", + "del batch1.a\n", + "print(batch1)\n", + "\n", + "# access value by key\n", + "print(\"========================================\")\n", + "assert batch1[\"c\"] is batch1.c\n", + "print(\"c\" in batch1)\n", + "\n", + "# traverse the Batch\n", + "print(\"========================================\")\n", + "for key, value in batch1.items():\n", + " print(str(key) + \": \" + str(value))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "A batch is simply a dictionary which stores all passed in data as key-value pairs, and automatically turns the value into a numpy array if possible.\n", - "\n", - "## Why we need Batch in Tianshou?\n", - "The motivation behind the implementation of Batch module is simple. In DRL, you need to handle a lot of dictionary-format data. 
For instance most algorithms would reuqire you to store state, action, and reward data for every step when interacting with the environment. All these data can be organised as a dictionary and a Batch module helps Tianshou unify the interface of a diverse set of algorithms. Plus, Batch supports advanced indexing, concantenation and splitting, formatting print just like any other numpy array, which may be very helpful for developers.\n", - "
\n", - "\n", - "\n", - " Data flow is converted into a Batch in Tianshou \n", - "
\n", - "\n" - ], - "metadata": { - "id": "S6e6OuXe3UT-" - } + "id": "2TNIY90-vU9b", + "outputId": "de52ffe9-03c2-45f2-d95a-4071132daa4a" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Indexing and Slicing\n", + "If all values in Batch share the same shape in certain dimensions. Batch can support advanced indexing and slicing just like a normal numpy array." + ], + "metadata": { + "id": "bVywStbV9jD2" + } + }, + { + "cell_type": "code", + "source": [ + "# Let us suppose we've got 4 environments, each returns a step of data\n", + "step_datas = [\n", + " {\n", + " \"act\": np.random.randint(10),\n", + " \"rew\": 0.0,\n", + " \"obs\": np.ones((3, 3)),\n", + " \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n", + " } for _ in range(4)\n", + " ]\n", + "batch = Batch(step_datas)\n", + "print(batch)\n", + "print(batch.shape)\n", + "\n", + "# advanced indexing is supported, if we only want to select data in a given set of environments\n", + "print(\"========================================\")\n", + "print(batch[0])\n", + "print(batch[[0,3]])\n", + "\n", + "# slicing is also supported\n", + "print(\"========================================\")\n", + "print(batch[-2:])\n", + "\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Basic Usages" - ], - "metadata": { - "id": "_Xenx64M9HhV" - } + "id": "gKza3OJnzc_D", + "outputId": "4f240bfe-4a69-4c1b-b40e-983c5c4d0cbc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Aggregation and Splitting\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Again, just like a numpy array. Play the example code below." 
+ ], + "metadata": { + "id": "1vUwQ-Hw9jtu" + } + }, + { + "cell_type": "code", + "source": [ + "# concat batches with compatible keys\n", + "# try incompatible keys yourself if you feel curious\n", + "print(\"========================================\")\n", + "b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])\n", + "b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])\n", + "b12_cat_out = Batch.cat([b1, b2])\n", + "print(b1)\n", + "print(b2)\n", + "print(b12_cat_out)\n", + "\n", + "# stack batches with compatible keys\n", + "# try incompatible keys yourself if you feel curious\n", + "print(\"========================================\")\n", + "b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n", + "b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n", + "b34_stack = Batch.stack((b3, b4), axis=1)\n", + "print(b3)\n", + "print(b4)\n", + "print(b34_stack)\n", + "\n", + "# split the batch into small batches of size 1, breaking the order of the data\n", + "print(\"========================================\")\n", + "print(type(b34_stack.split(1)))\n", + "print(list(b34_stack.split(1, shuffle=True)))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Initialisation\n", - "Batch can be converted directly from a python dictionary, and all data structure will be converted to numpy array if possible." - ], - "metadata": { - "id": "4YGX_f1Z9Uil" - } + "id": "f5UkReyn3_kb", + "outputId": "e7bb3324-7f20-4810-a328-479117efca55" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Data type converting\n", + "Besides numpy array, Batch actually also supports Torch Tensor. The usages are exactly the same. Cool, isn't it?" 
+ ], + "metadata": { + "id": "Smc_W1Cx6zRS" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "batch1 = Batch(a=np.arange(2), b=torch.zeros((2,2)))\n", + "batch2 = Batch(a=np.arange(2), b=torch.ones((2,2)))\n", + "batch_cat = Batch.cat([batch1, batch2, batch1])\n", + "print(batch_cat)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "# converted from a python library\n", - "print(\"========================================\")\n", - "batch1 = Batch({'a': [4, 4], 'b': (5, 5)})\n", - "print(batch1)\n", - "\n", - "# initialisation of batch2 is equivalent to batch1\n", - "print(\"========================================\")\n", - "batch2 = Batch(a=[4, 4], b=(5, 5))\n", - "print(batch2)\n", - "\n", - "# the dictionary can be nested, and it will be turned into a nested Batch\n", - "print(\"========================================\")\n", - "data = {\n", - " 'action': np.array([1.0, 2.0, 3.0]),\n", - " 'reward': 3.66,\n", - " 'obs': {\n", - " \"rgb_obs\": np.zeros((3, 3)),\n", - " \"flatten_obs\": np.ones(5),\n", - " },\n", - " }\n", - "\n", - "\n", - "batch3 = Batch(data, extra=\"extra_string\")\n", - "print(batch3)\n", - "# batch3.obs is also a Batch\n", - "print(type(batch3.obs))\n", - "print(batch3.obs.rgb_obs)\n", - "\n", - "# a list of dictionary/Batch will automatically be concatenated/stacked, providing convenience if you\n", - "# want to use parallelized environments to collect data.\n", - "print(\"========================================\")\n", - "batch4 = Batch([data] * 3)\n", - "print(batch4)\n", - "print(batch4.obs.rgb_obs.shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Jl3-4BRbp3MM", - "outputId": "a8b225f6-2893-4716-c694-3c2ff558b7f0" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "========================================\n", - "Batch(\n", - " a: array([4, 
4]),\n", - " b: array([5, 5]),\n", - ")\n", - "========================================\n", - "Batch(\n", - " a: array([4, 4]),\n", - " b: array([5, 5]),\n", - ")\n", - "========================================\n", - "Batch(\n", - " action: array([1., 2., 3.]),\n", - " reward: array(3.66),\n", - " obs: Batch(\n", - " rgb_obs: array([[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]]),\n", - " flatten_obs: array([1., 1., 1., 1., 1.]),\n", - " ),\n", - " extra: 'extra_string',\n", - ")\n", - "\n", - "[[0. 0. 0.]\n", - " [0. 0. 0.]\n", - " [0. 0. 0.]]\n", - "========================================\n", - "Batch(\n", - " obs: Batch(\n", - " rgb_obs: array([[[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]],\n", - " \n", - " [[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]],\n", - " \n", - " [[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]]]),\n", - " flatten_obs: array([[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]]),\n", - " ),\n", - " reward: array([3.66, 3.66, 3.66]),\n", - " action: array([[1., 2., 3.],\n", - " [1., 2., 3.],\n", - " [1., 2., 3.]]),\n", - ")\n", - "(3, 3, 3)\n" - ] - } - ] + "id": "Y6im_Mtb7Ody", + "outputId": "898e82c4-b940-4c35-a0f9-dedc4a9bc500" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "You can convert the data type easily, if you no longer want to use hybrid data type anymore." + ], + "metadata": { + "id": "1wfTUVKb6xki" + } + }, + { + "cell_type": "code", + "source": [ + "batch_cat.to_numpy()\n", + "print(batch_cat)\n", + "batch_cat.to_torch()\n", + "print(batch_cat)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Getting access to data\n", - "You can conveniently search or change the key-value pair in the Batch just as if it is a python dictionary." 
- ], - "metadata": { - "id": "JCf6bqY3uf5L" - } + "id": "F7WknVs98DHD", + "outputId": "cfd0712a-1df3-4208-e6cc-9149840bdc40" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Batch is even serializable, just in case you may need to save it to disk or restore it." + ], + "metadata": { + "id": "NTFVle1-9Biz" + } + }, + { + "cell_type": "code", + "source": [ + "import pickle\n", + "batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", + "batch_pk = pickle.loads(pickle.dumps(batch))\n", + "print(batch_pk)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "batch1 = Batch({'a': [4, 4], 'b': (5, 5)})\n", - "print(batch1)\n", - "# add or delete key-value pair in batch1\n", - "print(\"========================================\")\n", - "batch1.c = Batch(c1=np.arange(3), c2=False)\n", - "del batch1.a\n", - "print(batch1)\n", - "\n", - "# access value by key\n", - "print(\"========================================\")\n", - "assert batch1[\"c\"] is batch1.c\n", - "print(\"c\" in batch1)\n", - "\n", - "# traverse the Batch\n", - "print(\"========================================\")\n", - "for key, value in batch1.items():\n", - " print(str(key) + \": \" + str(value))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2TNIY90-vU9b", - "outputId": "de52ffe9-03c2-45f2-d95a-4071132daa4a" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " a: array([4, 4]),\n", - " b: array([5, 5]),\n", - ")\n", - "========================================\n", - "Batch(\n", - " b: array([5, 5]),\n", - " c: Batch(\n", - " c1: array([0, 1, 2]),\n", - " c2: array(False),\n", - " ),\n", - ")\n", - "========================================\n", - "True\n", - "========================================\n", - "b: [5 5]\n", - "c: Batch(\n", - " c1: 
array([0, 1, 2]),\n", - " c2: array(False),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Indexing and Slicing\n", - "If all values in Batch share the same shape in certain dimensions. Batch can support advanced indexing and slicing just like a normal numpy array." - ], - "metadata": { - "id": "bVywStbV9jD2" - } - }, - { - "cell_type": "code", - "source": [ - "# Let us suppose we've got 4 environments, each returns a step of data\n", - "step_datas = [\n", - " {\n", - " \"act\": np.random.randint(10),\n", - " \"rew\": 0.0,\n", - " \"obs\": np.ones((3, 3)),\n", - " \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n", - " } for _ in range(4)\n", - " ]\n", - "batch = Batch(step_datas)\n", - "print(batch)\n", - "print(batch.shape)\n", - "\n", - "# advanced indexing is supported, if we only want to select data in a given set of environments\n", - "print(\"========================================\")\n", - "print(batch[0])\n", - "print(batch[[0,3]])\n", - "\n", - "# slicing is also supported\n", - "print(\"========================================\")\n", - "print(batch[-2:])\n", - "\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "gKza3OJnzc_D", - "outputId": "4f240bfe-4a69-4c1b-b40e-983c5c4d0cbc" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " obs: array([[[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]],\n", - " \n", - " [[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]],\n", - " \n", - " [[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]],\n", - " \n", - " [[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]]]),\n", - " rew: array([0., 0., 0., 0.]),\n", - " info: Batch(\n", - " done: array([0, 1, 1, 0]),\n", - " failed: array([False, False, False, False]),\n", - " ),\n", - " act: array([0, 5, 1, 8]),\n", - ")\n", - "[4]\n", - "========================================\n", - 
"Batch(\n", - " obs: array([[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]]),\n", - " rew: 0.0,\n", - " info: Batch(\n", - " done: 0,\n", - " failed: False,\n", - " ),\n", - " act: 0,\n", - ")\n", - "Batch(\n", - " obs: array([[[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]],\n", - " \n", - " [[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]]]),\n", - " rew: array([0., 0.]),\n", - " info: Batch(\n", - " done: array([0, 0]),\n", - " failed: array([False, False]),\n", - " ),\n", - " act: array([0, 8]),\n", - ")\n", - "========================================\n", - "Batch(\n", - " obs: array([[[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]],\n", - " \n", - " [[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]]]),\n", - " rew: array([0., 0.]),\n", - " info: Batch(\n", - " done: array([1, 0]),\n", - " failed: array([False, False]),\n", - " ),\n", - " act: array([1, 8]),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Aggregation and Splitting\n" - ], - "metadata": { - "id": "C7N9kU_Q9jXm" - } - }, - { - "cell_type": "markdown", - "source": [ - "Again, just like a numpy array. Play the example code below." 
- ], - "metadata": { - "id": "1vUwQ-Hw9jtu" - } - }, - { - "cell_type": "code", - "source": [ - "# concat batches with compatible keys\n", - "# try incompatible keys yourself if you feel curious\n", - "print(\"========================================\")\n", - "b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])\n", - "b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])\n", - "b12_cat_out = Batch.cat([b1, b2])\n", - "print(b1)\n", - "print(b2)\n", - "print(b12_cat_out)\n", - "\n", - "# stack batches with compatible keys\n", - "# try incompatible keys yourself if you feel curious\n", - "print(\"========================================\")\n", - "b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n", - "b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n", - "b34_stack = Batch.stack((b3, b4), axis=1)\n", - "print(b3)\n", - "print(b4)\n", - "print(b34_stack)\n", - "\n", - "# split the batch into small batches of size 1, breaking the order of the data\n", - "print(\"========================================\")\n", - "print(type(b34_stack.split(1)))\n", - "print(list(b34_stack.split(1, shuffle=True)))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "f5UkReyn3_kb", - "outputId": "e7bb3324-7f20-4810-a328-479117efca55" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "========================================\n", - "Batch(\n", - " a: Batch(\n", - " d: Batch(\n", - " e: array([3.]),\n", - " ),\n", - " b: array([1.]),\n", - " ),\n", - ")\n", - "Batch(\n", - " a: Batch(\n", - " d: Batch(\n", - " e: array([6.]),\n", - " ),\n", - " b: array([4.]),\n", - " ),\n", - ")\n", - "Batch(\n", - " a: Batch(\n", - " d: Batch(\n", - " e: array([3., 6.]),\n", - " ),\n", - " b: array([1., 4.]),\n", - " ),\n", - ")\n", - "========================================\n", - "Batch(\n", - " a: array([[0., 0.],\n", - " [0., 
0.],\n", - " [0., 0.]]),\n", - " b: array([[1., 1., 1.],\n", - " [1., 1., 1.]]),\n", - " c: Batch(\n", - " d: array([[1],\n", - " [2]]),\n", - " ),\n", - ")\n", - "Batch(\n", - " a: array([[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]]),\n", - " b: array([[1., 1., 1.],\n", - " [1., 1., 1.]]),\n", - " c: Batch(\n", - " d: array([[0],\n", - " [3]]),\n", - " ),\n", - ")\n", - "Batch(\n", - " c: Batch(\n", - " d: array([[[1],\n", - " [0]],\n", - " \n", - " [[2],\n", - " [3]]]),\n", - " ),\n", - " a: array([[[0., 0.],\n", - " [1., 1.]],\n", - " \n", - " [[0., 0.],\n", - " [1., 1.]],\n", - " \n", - " [[0., 0.],\n", - " [1., 1.]]]),\n", - " b: array([[[1., 1., 1.],\n", - " [1., 1., 1.]],\n", - " \n", - " [[1., 1., 1.],\n", - " [1., 1., 1.]]]),\n", - ")\n", - "========================================\n", - "\n", - "[Batch(\n", - " c: Batch(\n", - " d: array([[[1],\n", - " [0]]]),\n", - " ),\n", - " a: array([[[0., 0.],\n", - " [1., 1.]]]),\n", - " b: array([[[1., 1., 1.],\n", - " [1., 1., 1.]]]),\n", - "), Batch(\n", - " c: Batch(\n", - " d: array([[[2],\n", - " [3]]]),\n", - " ),\n", - " a: array([[[0., 0.],\n", - " [1., 1.]]]),\n", - " b: array([[[1., 1., 1.],\n", - " [1., 1., 1.]]]),\n", - ")]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Data type converting\n", - "Besides numpy array, Batch actually also supports Torch Tensor. The usages are exactly the same. Cool, isn't it?" 
- ], - "metadata": { - "id": "Smc_W1Cx6zRS" - } - }, - { - "cell_type": "code", - "source": [ - "import torch\n", - "batch1 = Batch(a=np.arange(2), b=torch.zeros((2,2)))\n", - "batch2 = Batch(a=np.arange(2), b=torch.ones((2,2)))\n", - "batch_cat = Batch.cat([batch1, batch2, batch1])\n", - "print(batch_cat)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y6im_Mtb7Ody", - "outputId": "898e82c4-b940-4c35-a0f9-dedc4a9bc500" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " a: array([0, 1, 0, 1, 0, 1]),\n", - " b: tensor([[0., 0.],\n", - " [0., 0.],\n", - " [1., 1.],\n", - " [1., 1.],\n", - " [0., 0.],\n", - " [0., 0.]]),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "You can convert the data type easily, if you no longer want to use hybrid data type anymore." - ], - "metadata": { - "id": "1wfTUVKb6xki" - } - }, - { - "cell_type": "code", - "source": [ - "batch_cat.to_numpy()\n", - "print(batch_cat)\n", - "batch_cat.to_torch()\n", - "print(batch_cat)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "F7WknVs98DHD", - "outputId": "cfd0712a-1df3-4208-e6cc-9149840bdc40" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " a: array([0, 1, 0, 1, 0, 1]),\n", - " b: array([[0., 0.],\n", - " [0., 0.],\n", - " [1., 1.],\n", - " [1., 1.],\n", - " [0., 0.],\n", - " [0., 0.]], dtype=float32),\n", - ")\n", - "Batch(\n", - " a: tensor([0, 1, 0, 1, 0, 1]),\n", - " b: tensor([[0., 0.],\n", - " [0., 0.],\n", - " [1., 1.],\n", - " [1., 1.],\n", - " [0., 0.],\n", - " [0., 0.]]),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Batch is even serializable, just in case you may need to save it to disk or restore it." 
- ], - "metadata": { - "id": "NTFVle1-9Biz" - } - }, - { - "cell_type": "code", - "source": [ - "import pickle\n", - "batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", - "batch_pk = pickle.loads(pickle.dumps(batch))\n", - "print(batch_pk)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Lnf17OXv9YRb", - "outputId": "753753f2-3f66-4d4b-b4ff-d57f9c40d1da" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " obs: Batch(\n", - " a: array(0.),\n", - " c: tensor([1., 2.]),\n", - " ),\n", - " np: array([[0., 0., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.]]),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "# Further Reading" - ], - "metadata": { - "id": "-vPMiPZ-9kJN" - } - }, - { - "cell_type": "markdown", - "source": [ - "Would like to learn more advanced usages of Batch? Feel curious about how data is organised inside the Batch? Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html) and other [tutorials](https://tianshou.readthedocs.io/en/master/tutorials/batch.html#) for more details." - ], - "metadata": { - "id": "8Oc1p8ud9kcu" - } - } - ] -} \ No newline at end of file + "id": "Lnf17OXv9YRb", + "outputId": "753753f2-3f66-4d4b-b4ff-d57f9c40d1da" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Further Reading" + ], + "metadata": { + "id": "-vPMiPZ-9kJN" + } + }, + { + "cell_type": "markdown", + "source": [ + "Would like to learn more advanced usages of Batch? Feel curious about how data is organised inside the Batch? Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html) and other [tutorials](https://tianshou.readthedocs.io/en/master/tutorials/batch.html#) for more details." 
+ ], + "metadata": { + "id": "8Oc1p8ud9kcu" + } + } + ] +} diff --git a/notebooks/L2_Buffer.ipynb b/notebooks/L2_Buffer.ipynb index a559444..e5ae2d8 100644 --- a/notebooks/L2_Buffer.ipynb +++ b/notebooks/L2_Buffer.ipynb @@ -1,575 +1,394 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] }, - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4TCEkXj7LFe2" - }, - "outputs": [], - "source": [ - "# Remember to install tianshou first\n", - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4TCEkXj7LFe2" + }, + "outputs": [], + "source": [ + "# Remember to install tianshou first\n", + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "Replay Buffer is a very common module in DRL implementations. In Tianshou, you can consider Buffer module as as a specialized form of Batch, which helps you track all data trajectories and provide utilities such as sampling method besides the basic storage.\n", + "\n", + "There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. The later one is specially designed for parallelized environments (will introduce in tutorial L3)." 
+ ], + "metadata": { + "id": "xoPiGVD8LNma" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Usages" + ], + "metadata": { + "id": "OdesCAxANehZ" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Basic usages as a batch\n", + "Usually a buffer stores all the data in a batch with circular-queue style." + ], + "metadata": { + "id": "fUbLl9T_SrTR" + } + }, + { + "cell_type": "code", + "source": [ + "from tianshou.data import Batch, ReplayBuffer\n", + "# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n", + "print(\"========================================\")\n", + "buf = ReplayBuffer(size=10)\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", + "\n", + "# add 3 steps of data into ReplayBuffer sequentially\n", + "print(\"========================================\")\n", + "for i in range(3):\n", + " buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", + "\n", + "# add another 10 steps of data into ReplayBuffer sequentially\n", + "print(\"========================================\")\n", + "for i in range(3, 13):\n", + " buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "Replay Buffer is a very common module in DRL implementations. In Tianshou, you can consider Buffer module as as a specialized form of Batch, which helps you track all data trajectories and provide utilities such as sampling method besides the basic storage.\n", - "\n", - "There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. 
The later one is specially designed for parallelized environments (will introduce in tutorial L3)." - ], - "metadata": { - "id": "xoPiGVD8LNma" - } + "id": "mocZ6IqZTH62", + "outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc." + ], + "metadata": { + "id": "H8B85Y5yUfTy" + } + }, + { + "cell_type": "code", + "source": [ + "print(buf[-1])\n", + "print(buf[-3:])\n", + "# Try more methods you find useful in Batch yourself." + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Usages" - ], - "metadata": { - "id": "OdesCAxANehZ" - } + "id": "cOX-ADOPNeEK", + "outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings." + ], + "metadata": { + "id": "vqldap-2WQBh" + } + }, + { + "cell_type": "code", + "source": [ + "import pickle\n", + "_buf = pickle.loads(pickle.dumps(buf))" + ], + "metadata": { + "id": "Ppx0L3niNT5K" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Understanding reserved keys for buffer\n", + "As I have explained, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. So, for convenience, we reserve certain seven reserved keys in Batch.\n", + "\n", + "* `obs`\n", + "* `act`\n", + "* `rew`\n", + "* `done`\n", + "* `obs_next`\n", + "* `info`\n", + "* `policy`\n", + "\n", + "The meaning of these seven reserved keys are consistent with the meaning in [OPENAI Gym](https://gym.openai.com/). 
We would recommend you simply use these seven keys when adding batched data into ReplayBuffer, because\n", + "some of them are tracked in ReplayBuffer (e.g. \"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n", + "\n", + "```\n", + "buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n", + "buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n", + "```\n" + ], + "metadata": { + "id": "Eqezp0OyXn6J" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Data sampling\n", + "We keep a replay buffer in DRL for one purpose:\"sample data from it for training\". `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fullfill this need." + ], + "metadata": { + "id": "ueAbTspsc6jo" + } + }, + { + "cell_type": "code", + "source": [ + "buf.sample(5)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Basic usages as a batch\n", - "Usually a buffer stores all the data in a batch with circular-queue style." - ], - "metadata": { - "id": "fUbLl9T_SrTR" - } + "id": "P5xnYOhrchDl", + "outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Trajectory tracking\n", + "Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n", + "\n", + "First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet." 
+ ], + "metadata": { + "id": "IWyaOSKOcgK4" + } + }, + { + "cell_type": "code", + "source": [ + "from numpy import False_\n", + "buf = ReplayBuffer(size=10)\n", + "# Add the first trajectory (length is 3) into ReplayBuffer\n", + "print(\"========================================\")\n", + "for i in range(3):\n", + " result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==2 else False, obs_next=i + 1, info={}))\n", + " print(result)\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", + "# Add the second trajectory (length is 5) into ReplayBuffer\n", + "print(\"========================================\")\n", + "for i in range(3, 8):\n", + " result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==7 else False, obs_next=i + 1, info={}))\n", + " print(result)\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", + "# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n", + "print(\"========================================\")\n", + "for i in range(8, 13):\n", + " result = buf.add(Batch(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={}))\n", + " print(result)\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "from tianshou.data import Batch, ReplayBuffer\n", - "# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n", - "print(\"========================================\")\n", - "buf = ReplayBuffer(size=10)\n", - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", - "\n", - "# add 3 steps of data into ReplayBuffer sequentially\n", - "print(\"========================================\")\n", - "for i in range(3):\n", - " buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n", 
- "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", - "\n", - "# add another 10 steps of data into ReplayBuffer sequentially\n", - "print(\"========================================\")\n", - "for i in range(3, 13):\n", - " buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n", - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mocZ6IqZTH62", - "outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "========================================\n", - "ReplayBuffer()\n", - "maxsize: 10, data length: 0\n", - "========================================\n", - "ReplayBuffer(\n", - " info: Batch(),\n", - " act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n", - " obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n", - " done: array([False, False, False, False, False, False, False, False, False,\n", - " False]),\n", - " rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),\n", - " obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),\n", - ")\n", - "maxsize: 10, data length: 3\n", - "========================================\n", - "ReplayBuffer(\n", - " info: Batch(),\n", - " act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n", - " obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n", - " done: array([False, False, False, False, False, False, False, False, False,\n", - " False]),\n", - " rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n", - " obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n", - ")\n", - "maxsize: 10, data length: 10\n" - ] - } - ] + "id": "H0qRb6HLfhLB", + "outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### episode length and rewards tracking\n", + "Notice that 
`ReplayBuffer.add()` returns a tuple of 4 numbers every time it returns, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n", + "\n" + ], + "metadata": { + "id": "dO7PWdb_hkXA" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Episode index management\n", + "In the ReplayBuffer above, we can get access to any data step by indexing.\n" + ], + "metadata": { + "id": "xbVc90z8itH0" + } + }, + { + "cell_type": "code", + "source": [ + "print(buf)\n", + "data = buf[6]\n", + "print(data)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc." - ], - "metadata": { - "id": "H8B85Y5yUfTy" - } + "id": "4mKwo54MjupY", + "outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we know that step \"6\" is not the start of an episode (it should be step 4, 4-7 is the second trajectory we add into the ReplayBuffer), we wonder what is the earliest index of the that episode.\n", + "\n", + "This may seem easy but actually it is not. We cannot simply look at the \"done\" flag now, because we can see that since the third-added trajectory is not finished yet, step \"4\" is surrounded by flag \"False\". There are many things to consider. Things could get more nasty if you are using more advanced ReplayBuffer like VectorReplayBuffer, because now the data is not stored in a simple circular-queue.\n", + "\n", + "Luckily, all ReplayBuffer instances help you identify step indexes through a unified API." 
+ ], + "metadata": { + "id": "p5Co_Fmzj8Sw" + } + }, + { + "cell_type": "code", + "source": [ + "# Search for the previous index of index \"6\"\n", + "now_index = 6\n", + "while True:\n", + " prev_index = buf.prev(now_index)\n", + " print(prev_index)\n", + " if prev_index == now_index:\n", + " break\n", + " else:\n", + " now_index = prev_index" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "print(buf[-1])\n", - "print(buf[-3:])\n", - "# Try more methods you find useful in Batch yourself." - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cOX-ADOPNeEK", - "outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " obs: array(9),\n", - " act: array(9),\n", - " rew: array(9.),\n", - " done: array(False),\n", - " obs_next: array(10),\n", - " info: Batch(),\n", - " policy: Batch(),\n", - ")\n", - "Batch(\n", - " obs: array([7, 8, 9]),\n", - " act: array([7, 8, 9]),\n", - " rew: array([7., 8., 9.]),\n", - " done: array([False, False, False]),\n", - " obs_next: array([ 8, 9, 10]),\n", - " info: Batch(),\n", - " policy: Batch(),\n", - ")\n" - ] - } - ] + "id": "DcJ0LEX6mxHg", + "outputId": "7830f5fb-96d9-4298-d09b-24e64b2f633c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us indentify the last index of an episode regardless of which kind of ReplayBuffer we are using." 
+ ], + "metadata": { + "id": "4Wlb57V4lQyQ" + } + }, + { + "cell_type": "code", + "source": [ + "# next step of indexes [4,5,6,7,8,9] are:\n", + "print(buf.next([4,5,6,7,8,9]))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings." - ], - "metadata": { - "id": "vqldap-2WQBh" - } + "id": "zl5TRMo7oOy5", + "outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory." + ], + "metadata": { + "id": "YJ9CcWZXoOXw" + } + }, + { + "cell_type": "code", + "source": [ + "print(buf.unfinished_index())" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "import pickle\n", - "_buf = pickle.loads(pickle.dumps(buf))" - ], - "metadata": { - "id": "Ppx0L3niNT5K" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Understanding reserved keys for buffer\n", - "As I have explained, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. So, for convenience, we reserve certain seven reserved keys in Batch.\n", - "\n", - "* `obs`\n", - "* `act`\n", - "* `rew`\n", - "* `done`\n", - "* `obs_next`\n", - "* `info`\n", - "* `policy`\n", - "\n", - "The meaning of these seven reserved keys are consistent with the meaning in [OPENAI Gym](https://gym.openai.com/). We would recommend you simply use these seven keys when adding batched data into ReplayBuffer, because\n", - "some of them are tracked in ReplayBuffer (e.g. 
\"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n", - "\n", - "```\n", - "buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n", - "buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n", - "```\n" - ], - "metadata": { - "id": "Eqezp0OyXn6J" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Data sampling\n", - "We keep a replay buffer in DRL for one purpose:\"sample data from it for training\". `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fullfill this need." - ], - "metadata": { - "id": "ueAbTspsc6jo" - } - }, - { - "cell_type": "code", - "source": [ - "buf.sample(5)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "P5xnYOhrchDl", - "outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(Batch(\n", - " obs: array([10, 11, 4, 3, 8]),\n", - " act: array([10, 11, 4, 3, 8]),\n", - " rew: array([10., 11., 4., 3., 8.]),\n", - " done: array([False, False, False, False, False]),\n", - " obs_next: array([11, 12, 5, 4, 9]),\n", - " info: Batch(),\n", - " policy: Batch(),\n", - " ), array([0, 1, 4, 3, 8]))" - ] - }, - "metadata": {}, - "execution_count": 5 - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Trajectory tracking\n", - "Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n", - "\n", - "First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet." 
- ], - "metadata": { - "id": "IWyaOSKOcgK4" - } - }, - { - "cell_type": "code", - "source": [ - "from numpy import False_\n", - "buf = ReplayBuffer(size=10)\n", - "# Add the first trajectory (length is 3) into ReplayBuffer\n", - "print(\"========================================\")\n", - "for i in range(3):\n", - " result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==2 else False, obs_next=i + 1, info={}))\n", - " print(result)\n", - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", - "# Add the second trajectory (length is 5) into ReplayBuffer\n", - "print(\"========================================\")\n", - "for i in range(3, 8):\n", - " result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==7 else False, obs_next=i + 1, info={}))\n", - " print(result)\n", - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", - "# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n", - "print(\"========================================\")\n", - "for i in range(8, 13):\n", - " result = buf.add(Batch(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={}))\n", - " print(result)\n", - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "H0qRb6HLfhLB", - "outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "========================================\n", - "(array([0]), array([0.]), array([0]), array([0]))\n", - "(array([1]), array([0.]), array([0]), array([0]))\n", - "(array([2]), array([3.]), array([3]), array([0]))\n", - "ReplayBuffer(\n", - " info: Batch(),\n", - " act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n", - " obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n", - " done: array([False, False, True, False, False, False, False, 
False, False,\n", - " False]),\n", - " rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),\n", - " obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),\n", - ")\n", - "maxsize: 10, data length: 3\n", - "========================================\n", - "(array([3]), array([0.]), array([0]), array([3]))\n", - "(array([4]), array([0.]), array([0]), array([3]))\n", - "(array([5]), array([0.]), array([0]), array([3]))\n", - "(array([6]), array([0.]), array([0]), array([3]))\n", - "(array([7]), array([25.]), array([5]), array([3]))\n", - "ReplayBuffer(\n", - " info: Batch(),\n", - " act: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),\n", - " obs: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),\n", - " done: array([False, False, True, False, False, False, False, True, False,\n", - " False]),\n", - " rew: array([0., 1., 2., 3., 4., 5., 6., 7., 0., 0.]),\n", - " obs_next: array([1, 2, 3, 4, 5, 6, 7, 8, 0, 0]),\n", - ")\n", - "maxsize: 10, data length: 8\n", - "========================================\n", - "(array([8]), array([0.]), array([0]), array([8]))\n", - "(array([9]), array([0.]), array([0]), array([8]))\n", - "(array([0]), array([0.]), array([0]), array([8]))\n", - "(array([1]), array([0.]), array([0]), array([8]))\n", - "(array([2]), array([0.]), array([0]), array([8]))\n", - "ReplayBuffer(\n", - " info: Batch(),\n", - " act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n", - " obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n", - " done: array([False, False, False, False, False, False, False, True, False,\n", - " False]),\n", - " rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n", - " obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n", - ")\n", - "maxsize: 10, data length: 10\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "### episode length and rewards tracking\n", - "Notice that `ReplayBuffer.add()` returns a tuple of 4 numbers every time it returns, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. 
`episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n", - "\n" - ], - "metadata": { - "id": "dO7PWdb_hkXA" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Episode index management\n", - "In the ReplayBuffer above, we can get access to any data step by indexing.\n" - ], - "metadata": { - "id": "xbVc90z8itH0" - } - }, - { - "cell_type": "code", - "source": [ - "print(buf)\n", - "data = buf[6]\n", - "print(data)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4mKwo54MjupY", - "outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "ReplayBuffer(\n", - " info: Batch(),\n", - " act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n", - " obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n", - " done: array([False, False, False, False, False, False, False, True, False,\n", - " False]),\n", - " rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n", - " obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n", - ")\n", - "Batch(\n", - " obs: array(6),\n", - " act: array(6),\n", - " rew: array(6.),\n", - " done: array(False),\n", - " obs_next: array(7),\n", - " info: Batch(),\n", - " policy: Batch(),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Now we know that step \"6\" is not the start of an episode (it should be step 4, 4-7 is the second trajectory we add into the ReplayBuffer), we wonder what is the earliest index of the that episode.\n", - "\n", - "This may seem easy but actually it is not. We cannot simply look at the \"done\" flag now, because we can see that since the third-added trajectory is not finished yet, step \"4\" is surrounded by flag \"False\". There are many things to consider. 
Things could get more nasty if you are using more advanced ReplayBuffer like VectorReplayBuffer, because now the data is not stored in a simple circular-queue.\n", - "\n", - "Luckily, all ReplayBuffer instances help you identify step indexes through a unified API." - ], - "metadata": { - "id": "p5Co_Fmzj8Sw" - } - }, - { - "cell_type": "code", - "source": [ - "# Search for the previous index of index \"6\"\n", - "now_index = 6\n", - "while True:\n", - " prev_index = buf.prev(now_index)\n", - " print(prev_index)\n", - " if prev_index == now_index:\n", - " break\n", - " else:\n", - " now_index = prev_index" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "DcJ0LEX6mxHg", - "outputId": "7830f5fb-96d9-4298-d09b-24e64b2f633c" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "5\n", - "4\n", - "3\n", - "3\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us indentify the last index of an episode regardless of which kind of ReplayBuffer we are using." - ], - "metadata": { - "id": "4Wlb57V4lQyQ" - } - }, - { - "cell_type": "code", - "source": [ - "# next step of indexes [4,5,6,7,8,9] are:\n", - "print(buf.next([4,5,6,7,8,9]))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zl5TRMo7oOy5", - "outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[5 6 7 7 9 0]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory." 
- ], - "metadata": { - "id": "YJ9CcWZXoOXw" - } - }, - { - "cell_type": "code", - "source": [ - "print(buf.unfinished_index())" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Xkawk97NpItg", - "outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[2]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface." - ], - "metadata": { - "id": "8_lMr0j3pOmn" - } - }, - { - "cell_type": "markdown", - "source": [ - "# Further Reading\n", - "## Other Buffer Module\n", - "\n", - "* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n", - "* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n", - "* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n", - "\n", - "Check the documentation and the source code for more details.\n", - "\n", - "## Support for steps stacking to use RNN in DRL.\n", - "There is an option called `stack_num` (default to 1) when initialising the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details." 
- ], - "metadata": { - "id": "FEyE0c7tNfwa" - } - } - ] -} \ No newline at end of file + "id": "Xkawk97NpItg", + "outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface." + ], + "metadata": { + "id": "8_lMr0j3pOmn" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Further Reading\n", + "## Other Buffer Module\n", + "\n", + "* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n", + "* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n", + "* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n", + "\n", + "Check the documentation and the source code for more details.\n", + "\n", + "## Support for steps stacking to use RNN in DRL.\n", + "There is an option called `stack_num` (default to 1) when initialising the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details." 
+ ], + "metadata": { + "id": "FEyE0c7tNfwa" + } + } + ] +} diff --git a/notebooks/L3_Vectorized__Environment.ipynb b/notebooks/L3_Vectorized__Environment.ipynb index ad07bcd..6ecad98 100644 --- a/notebooks/L3_Vectorized__Environment.ipynb +++ b/notebooks/L3_Vectorized__Environment.ipynb @@ -1,225 +1,215 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] }, - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0T7FYEnlBT6F" - }, - "outputs": [], - "source": [ - "# Remember to install tianshou first\n", - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0T7FYEnlBT6F" + }, + "outputs": [], + "source": [ + "# Remember to install tianshou first\n", + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments or their libraries in DRL research, Tianshou chooses to keep a consistent API with [OPENAI Gym](https://gym.openai.com/).\n", + "\n", + "
\n", + "\n", + "\n", + " The agents interacting with the environment \n", + "
\n", + "\n", + "In Gym, an environment receives an action and returns next observation and reward. This process is slow and sometimes can be the throughput bottleneck in a DRL experiment.\n" + ], + "metadata": { + "id": "W5V7z3fVX7_b" + } + }, + { + "cell_type": "markdown", + "source": [ + "Tianshou provides vectorized environment wrapper for a Gym environment. This wrapper allows you to make use of multiple cpu cores in your server to accelerate the data sampling." + ], + "metadata": { + "id": "A0NGWZ8adBwt" + } + }, + { + "cell_type": "code", + "source": [ + "from tianshou.env import SubprocVectorEnv\n", + "import numpy as np\n", + "import gym\n", + "import time\n", + "\n", + "num_cpus = [1,2,5]\n", + "for num_cpu in num_cpus:\n", + " env = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(num_cpu)])\n", + " env.reset()\n", + " sampled_steps = 0\n", + " time_start = time.time()\n", + " while sampled_steps < 1000:\n", + " act = np.random.choice(2, size=num_cpu)\n", + " obs, rew, done, info = env.step(act)\n", + " if np.sum(done):\n", + " env.reset(np.where(done)[0])\n", + " sampled_steps += num_cpu\n", + " time_used = time.time() - time_start\n", + " print(\"{}s used to sample 1000 steps if using {} cpus.\".format(time_used, num_cpu))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments or their libraries in DRL research, Tianshou chooses to keep a consistent API with [OPENAI Gym](https://gym.openai.com/).\n", - "\n", - "
\n", - "\n", - "\n", - " The agents interacting with the environment \n", - "
\n", - "\n", - "In Gym, an environment receives an action and returns next observation and reward. This process is slow and sometimes can be the throughput bottleneck in a DRL experiment.\n" - ], - "metadata": { - "id": "W5V7z3fVX7_b" - } - }, - { - "cell_type": "markdown", - "source": [ - "Tianshou provides vectorized environment wrapper for a Gym environment. This wrapper allows you to make use of multiple cpu cores in your server to accelerate the data sampling." - ], - "metadata": { - "id": "A0NGWZ8adBwt" - } - }, - { - "cell_type": "code", - "source": [ - "from tianshou.env import SubprocVectorEnv\n", - "import numpy as np\n", - "import gym\n", - "import time\n", - "\n", - "num_cpus = [1,2,5]\n", - "for num_cpu in num_cpus:\n", - " env = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(num_cpu)])\n", - " env.reset()\n", - " sampled_steps = 0\n", - " time_start = time.time()\n", - " while sampled_steps < 1000:\n", - " act = np.random.choice(2, size=num_cpu)\n", - " obs, rew, done, info = env.step(act)\n", - " if np.sum(done):\n", - " env.reset(np.where(done)[0])\n", - " sampled_steps += num_cpu\n", - " time_used = time.time() - time_start\n", - " print(\"{}s used to sample 1000 steps if using {} cpus.\".format(time_used, num_cpu))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "67wKtkiNi3lb", - "outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0.30551695823669434s used to sample 1000 steps if using 1 cpus.\n", - "0.2602052688598633s used to sample 1000 steps if using 2 cpus.\n", - "0.15763545036315918s used to sample 1000 steps if using 5 cpus.\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "You may notice that the speed doesn't increase linearly when we add subprocess numbers. There are multiple reasons behind this. 
One reason is that synchronize exection causes straggler effect. One way to solve this would be to use asynchronous mode. We leave this for further reading if you feel interested.\n", - "\n", - "Note that SubprocVectorEnv should only be used when the environment exection is slow. In practice, DummyVectorEnv (or raw Gym environment) is actually more efficient for a simple environment like CartPole because now you avoid both straggler effect and the overhead of communication between subprocesses." - ], - "metadata": { - "id": "S1b6vxp9nEUS" - } - }, - { - "cell_type": "markdown", - "source": [ - "# Usages\n", - "## Initialisation\n", - "Just pass in a list of functions which return the initialised environment upon called." - ], - "metadata": { - "id": "Z6yPxdqFp18j" - } - }, - { - "cell_type": "code", - "source": [ - "from tianshou.env import DummyVectorEnv\n", - "# In Gym\n", - "env = gym.make(\"CartPole-v0\")\n", - "\n", - "# In Tianshou\n", - "def helper_function():\n", - " env = gym.make(\"CartPole-v0\")\n", - " # other operations such as env.seed(np.random.choice(10))\n", - " return env\n", - "\n", - "envs = DummyVectorEnv([helper_function for _ in range(5)])\n", - "print(envs)" - ], - "metadata": { - "id": "ssLcrL_pq24-" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## EnvPool supporting\n", - "Besides integrated environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its Github page yourself." - ], - "metadata": { - "id": "X7p8csjdrwIN" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Environment exection and resetting\n", - "The only difference between Vectorized environments and standard Gym environments is that passed in actions and returned rewards/observations are also vectorized." 
- ], - "metadata": { - "id": "kvIfqh0vqAR5" - } - }, - { - "cell_type": "code", - "source": [ - "# In Gym, env.reset() returns a single observation.\n", - "print(\"In Gym, env.reset() returns a single observation.\")\n", - "print(env.reset())\n", - "\n", - "# In Tianshou, envs.reset() returns stacked observations.\n", - "print(\"========================================\")\n", - "print(\"In Tianshou, envs.reset() returns stacked observations.\")\n", - "print(envs.reset())\n", - "\n", - "obs, rew, done, info = envs.step(np.random.choice(2, size=num_cpu))\n", - "print(info)" - ], - "metadata": { - "id": "BH1ZnPG6tkdD" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "If we only want to execute several environments. The `id` argument can be used." - ], - "metadata": { - "id": "qXroB7KluvP9" - } - }, - { - "cell_type": "code", - "source": [ - "print(envs.step(np.random.choice(2, size=3), id=[0,3,1]))" - ], - "metadata": { - "id": "ufvFViKTu8d_" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Further Reading\n", - "## Other environment wrappers in Tianshou\n", - "\n", - "\n", - "* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n", - "* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n", - "\n", - "Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n", - "\n", - "## Difference between synchronous and asynchronous mode (How to choose?)\n", - "Explanation can be found at the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial." 
- ], - "metadata": { - "id": "fekHR1a6X_HB" - } - } - ] -} \ No newline at end of file + "id": "67wKtkiNi3lb", + "outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "You may notice that the speed doesn't increase linearly when we add subprocess numbers. There are multiple reasons behind this. One reason is that synchronize exection causes straggler effect. One way to solve this would be to use asynchronous mode. We leave this for further reading if you feel interested.\n", + "\n", + "Note that SubprocVectorEnv should only be used when the environment exection is slow. In practice, DummyVectorEnv (or raw Gym environment) is actually more efficient for a simple environment like CartPole because now you avoid both straggler effect and the overhead of communication between subprocesses." + ], + "metadata": { + "id": "S1b6vxp9nEUS" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Usages\n", + "## Initialisation\n", + "Just pass in a list of functions which return the initialised environment upon called." + ], + "metadata": { + "id": "Z6yPxdqFp18j" + } + }, + { + "cell_type": "code", + "source": [ + "from tianshou.env import DummyVectorEnv\n", + "# In Gym\n", + "env = gym.make(\"CartPole-v0\")\n", + "\n", + "# In Tianshou\n", + "def helper_function():\n", + " env = gym.make(\"CartPole-v0\")\n", + " # other operations such as env.seed(np.random.choice(10))\n", + " return env\n", + "\n", + "envs = DummyVectorEnv([helper_function for _ in range(5)])\n", + "print(envs)" + ], + "metadata": { + "id": "ssLcrL_pq24-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## EnvPool supporting\n", + "Besides integrated environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its Github page yourself." 
+ ], + "metadata": { + "id": "X7p8csjdrwIN" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Environment execution and resetting\n", + "The only difference between Vectorized environments and standard Gym environments is that passed-in actions and returned rewards/observations are also vectorized." + ], + "metadata": { + "id": "kvIfqh0vqAR5" + } + }, + { + "cell_type": "code", + "source": [ + "# In Gym, env.reset() returns a single observation.\n", + "print(\"In Gym, env.reset() returns a single observation.\")\n", + "print(env.reset())\n", + "\n", + "# In Tianshou, envs.reset() returns stacked observations.\n", + "print(\"========================================\")\n", + "print(\"In Tianshou, envs.reset() returns stacked observations.\")\n", + "print(envs.reset())\n", + "\n", + "obs, rew, done, info = envs.step(np.random.choice(2, size=num_cpu))\n", + "print(info)" + ], + "metadata": { + "id": "BH1ZnPG6tkdD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "If we only want to execute several environments, the `id` argument can be used."
+ ], + "metadata": { + "id": "qXroB7KluvP9" + } + }, + { + "cell_type": "code", + "source": [ + "print(envs.step(np.random.choice(2, size=3), id=[0,3,1]))" + ], + "metadata": { + "id": "ufvFViKTu8d_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Further Reading\n", + "## Other environment wrappers in Tianshou\n", + "\n", + "\n", + "* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n", + "* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n", + "\n", + "Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n", + "\n", + "## Difference between synchronous and asynchronous mode (How to choose?)\n", + "Explanation can be found at the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial." + ], + "metadata": { + "id": "fekHR1a6X_HB" + } + } + ] +} diff --git a/notebooks/L4_Policy.ipynb b/notebooks/L4_Policy.ipynb index a274386..834b47d 100644 --- a/notebooks/L4_Policy.ipynb +++ b/notebooks/L4_Policy.ipynb @@ -1,954 +1,806 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] }, - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cesUkq8hA373" - }, - "outputs": [], - "source": [ - "# Remember to install tianshou first\n", - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": 
"cesUkq8hA373" + }, + "outputs": [], + "source": [ + "# Remember to install tianshou first\n", + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the agent part. In Tianshou, both the agent and the core DRL algorithm are implementated in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. See supported algorithms [here](https://github.com/thu-ml/tianshou).\n", + "\n", + "
\n", + "\n", + "\n", + " The agents interacting with the environment \n", + "
\n", + "\n", + "All Policy modules inherit from a BasePolicy Class and share the same interface." + ], + "metadata": { + "id": "PNM9wqstBSY_" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Creating you own Policy\n", + "We will use the a simple REINFORCE algorithm Policy to show the implementation of a Policy Module. The Policy we implement here will be a highly scaled-down version of [PGPolicy](https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/modelfree/pg.py) in Tianshou." + ], + "metadata": { + "id": "ZqdHYdoJJS51" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Initialisation\n", + "Firstly we create the `REINFORCEPolicy` by inheriting from `BasePolicy` in Tianshou." + ], + "metadata": { + "id": "PWFBgZ4TJkfz" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "from typing import Any, Dict, List, Optional, Type, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n", + "from tianshou.policy import BasePolicy\n", + "\n", + "class REINFORCEPolicy(BasePolicy):\n", + " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", + " def __init__(self):\n", + " super().__init__()" + ], + "metadata": { + "id": "cDlSjASbJmy-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "As we have mentioned, the Policy Module mainly does two things:\n", + "\n", + "\n", + "1. `policy.forward()` receives observation and other information (stored in a Batch) from the environment and returns a new Batch containing the action.\n", + "2. `policy.update()` receives training data sampled from the replay buffer and updates itself, and then returns logging details.\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + " policy.forward() and policy.update() \n", + "
\n", + "\n", + "We also need to take care of the following things:\n", + "\n", + "\n", + "\n", + "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n", + "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n", + "\n", + "Then we get the implementation below.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "qc1RnIBbLCDN" + } + }, + { + "cell_type": "code", + "source": [ + "class REINFORCEPolicy(BasePolicy):\n", + " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", + " def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,):\n", + " super().__init__()\n", + " self.actor = model\n", + " self.optim = optim\n", + "\n", + " def forward(self, batch: Batch) -> Batch:\n", + " \"\"\"Compute action over the given batch data.\"\"\"\n", + " act = None\n", + " return Batch(act=act)\n", + "\n", + " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", + " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", + " pass\n", + "\n", + " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " \"\"\"Perform the back-propagation.\"\"\"\n", + " return" + ], + "metadata": { + "id": "6j32PSKUQ23w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Policy.forward()\n", + "According to the equation of REINFORCE algorithm in Spinning Up's [documentation](https://spinningup.openai.com/en/latest/algorithms/vpg.html), we need to map the observation to an action distribution in action space using neural network (`self.actor`).\n", + "\n", + "
\n", + "\n", + "\n", + "
\n", + "\n", + "Let's us suppose the action space is discrete, and the distribution is a simple categorical distribution.\n", + "\n" + ], + "metadata": { + "id": "tjtqjt8WRY5e" + } + }, + { + "cell_type": "code", + "source": [ + "def forward(self, batch: Batch) -> Batch:\n", + " \"\"\"Compute action over the given batch data.\"\"\"\n", + " self.dist_fn = torch.distributions.Categorical\n", + " logits = self.actor(batch.obs)\n", + " dist = self.dist_fn(logits)\n", + " act = dist.sample()\n", + " return Batch(act=act, dist=dist)" + ], + "metadata": { + "id": "uE4YDE-_RwgN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Policy.process_fn()\n", + "Now that we have defined our actor, if given training data we can set up a loss function and optimize our neural network. However, before that, we must first calculate episodic returns for every step in our training data to construct the REINFORCE loss function.\n", + "\n", + "Calculating episodic return is not hard, given `ReplayBuffer.next()` allows us to access every reward to go in an episode. A more convenient way would be to simply use the built-in method `BasePolicy.compute_episodic_return()` inherited from BasePolicy.\n" + ], + "metadata": { + "id": "CultfOeuTx2V" + } + }, + { + "cell_type": "code", + "source": [ + "def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", + " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", + " returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)\n", + " batch.returns = returns\n", + " return batch" + ], + "metadata": { + "id": "wPAmOD7zV7n2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "`BasePolicy.compute_episodic_return()` could also be used to compute [GAE](https://arxiv.org/abs/1506.02438). Another similar method is `BasePolicy.compute_nstep_return()`. 
Check the [source code](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L304) for more details." + ], + "metadata": { + "id": "XA8OF4GnWWr5" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Policy.learn()\n", + "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally we can construct our loss function and perform the back-propagation." + ], + "metadata": { + "id": "7UsdzNaOXPpC" + } + }, + { + "cell_type": "code", + "source": [ + "def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " \"\"\"Perform the back-propagation.\"\"\"\n", + " logging_losses = []\n", + " for _ in range(repeat):\n", + " for minibatch in batch.split(batch_size, merge_last=True):\n", + " self.optim.zero_grad()\n", + " result = self(minibatch)\n", + " dist = result.dist\n", + " act = to_torch_as(minibatch.act, result.act)\n", + " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", + " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", + " loss = -(log_prob * ret).mean()\n", + " loss.backward()\n", + " self.optim.step()\n", + " logging_losses.append(loss.item())\n", + " return {\"loss\": logging_losses}" + ], + "metadata": { + "id": "aCO-dLXWXtz9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Implementation\n", + "Finally we can assemble the implemented methods and form a REINFORCE Policy."
+ ], + "metadata": { + "id": "1BtuV2W0YJTi" + } + }, + { + "cell_type": "code", + "source": [ + "class REINFORCEPolicy(BasePolicy):\n", + " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", + " def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,):\n", + " super().__init__()\n", + " self.actor = model\n", + " self.optim = optim\n", + " # action distribution\n", + " self.dist_fn = torch.distributions.Categorical\n", + "\n", + " def forward(self, batch: Batch) -> Batch:\n", + " \"\"\"Compute action over the given batch data.\"\"\"\n", + " logits, _ = self.actor(batch.obs)\n", + " dist = self.dist_fn(logits)\n", + " act = dist.sample()\n", + " return Batch(act=act, dist=dist)\n", + "\n", + " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", + " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", + " returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)\n", + " batch.returns = returns\n", + " return batch\n", + "\n", + " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " \"\"\"Perform the back-propagation.\"\"\"\n", + " logging_losses = []\n", + " for _ in range(repeat):\n", + " for minibatch in batch.split(batch_size, merge_last=True):\n", + " self.optim.zero_grad()\n", + " result = self(minibatch)\n", + " dist = result.dist\n", + " act = to_torch_as(minibatch.act, result.act)\n", + " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", + " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", + " loss = -(log_prob * ret).mean()\n", + " loss.backward()\n", + " self.optim.step()\n", + " logging_losses.append(loss.item())\n", + " return {\"loss\": logging_losses}\n" + ], + "metadata": { + "id": "Ab0KNQHTOlGo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Use the policy\n", + "Note that `BasePolicy` itself inherits 
from `torch.nn.Module`. As a result, you can consider all Policy modules as a Torch Module. They share similar APIs.\n", + "\n", + "Firstly we will initialise a new REINFORCE policy." + ], + "metadata": { + "id": "xlPAbh0lKti8" + } + }, + { + "cell_type": "code", + "source": [ + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import Actor\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "state_shape = 4\n", + "action_shape = 2\n", + "net = Net(state_shape, hidden_sizes=[16, 16], device=\"cpu\")\n", + "actor = Actor(net, action_shape, device=\"cpu\").to(\"cpu\")\n", + "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", + "\n", + "policy = REINFORCEPolicy(actor, optim)" + ], + "metadata": { + "id": "JkLFA9Z1KjuX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "REINFORCE policy shares same APIs with the Torch Module." + ], + "metadata": { + "id": "LAo_0t2fekUD" + } + }, + { + "cell_type": "code", + "source": [ + "print(policy)\n", + "print(\"========================================\")\n", + "for para in policy.parameters():\n", + " print(para.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the agent part. In Tianshou, both the agent and the core DRL algorithm are implementated in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. See supported algorithms [here](https://github.com/thu-ml/tianshou).\n", - "\n", - "
\n", - "\n", - "\n", - " The agents interacting with the environment \n", - "
\n", - "\n", - "All Policy modules inherit from a BasePolicy Class and share the same interface." - ], - "metadata": { - "id": "PNM9wqstBSY_" - } + "id": "UiuTc8RhJiEi", + "outputId": "9b5bc54c-6303-45f3-ba81-2216a44931e8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Making decision\n", + "Given a batch of observations, the policy can return a batch of actions and other data." + ], + "metadata": { + "id": "-RCrsttYgAG-" + } + }, + { + "cell_type": "code", + "source": [ + "obs_batch = Batch(obs=np.ones(shape=(256, 4)))\n", + "action = policy(obs_batch) # forward() method is called\n", + "print(action)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Creating you own Policy\n", - "We will use the a simple REINFORCE algorithm Policy to show the implementation of a Policy Module. The Policy we implement here will be a highly scaled-down version of [PGPolicy](https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/modelfree/pg.py) in Tianshou." - ], - "metadata": { - "id": "ZqdHYdoJJS51" - } + "id": "0jkBb6AAgUla", + "outputId": "37948844-cdd8-4567-9481-89453c80a157" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Save and Load models\n", + "Naturally, Tianshou Policy can be saved and loaded like a normal Torch Network." + ], + "metadata": { + "id": "swikhnuDfKep" + } + }, + { + "cell_type": "code", + "source": [ + "torch.save(policy.state_dict(), 'policy.pth')\n", + "assert policy.load_state_dict(torch.load('policy.pth'))" + ], + "metadata": { + "id": "tYOoWM_OJRnA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Algorithm Updating\n", + "We have to collect some data and save them in the ReplayBuffer before updating our agent(policy). 
Typically we use collector to collect data, but we leave this part till later when we have learned the Collector in Tianshou. For now we generate some **fake** data." + ], + "metadata": { + "id": "gp8PzOYsg5z-" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Generating fake data\n", + "Firstly, we need to \"pretend\" that we are using the \"Policy\" to collect data. We plan to collect 10 data so that we can update our algorithm." + ], + "metadata": { + "id": "XrrPxOUAYShR" + } + }, + { + "cell_type": "code", + "source": [ + "import gym\n", + "from tianshou.data import Batch, ReplayBuffer\n", + "# a buffer is initialised with its maxsize set to 12.\n", + "print(\"========================================\")\n", + "buf = ReplayBuffer(size=12)\n", + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", + "env = gym.make(\"CartPole-v0\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Initialisation\n", - "Firstly we create the `REINFORCEPolicy` by inheriting from `BasePolicy` in Tianshou." - ], - "metadata": { - "id": "PWFBgZ4TJkfz" - } + "id": "a14CmzSfYh5C", + "outputId": "aaf45a1f-5e21-4bc8-cbe3-8ce798258af0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps because we are performing too badly)."
+ ], + "metadata": { + "id": "8S94cV7yZITR" + } + }, + { + "cell_type": "code", + "source": [ + "obs = env.reset()\n", + "for i in range(3):\n", + " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", + " obs_next, rew, done, info = env.step(act)\n", + " # pretend ending at step 3\n", + " done = True if i==2 else False\n", + " info[\"id\"] = i\n", + " buf.add(Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info))\n", + " obs = obs_next" + ], + "metadata": { + "id": "a_mtvbmBZbfs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we are pretending to collect the second episode. At step 7 the second episode still doesn't end, but we are unwilling to wait, so we stop collecting to update the algorithm." + ], + "metadata": { + "id": "pkxq4gu9bGkt" + } + }, + { + "cell_type": "code", + "source": [ + "obs = env.reset()\n", + "for i in range(3, 10):\n", + " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", + " obs_next, rew, done, info = env.step(act)\n", + " # pretend this episode never ends\n", + " done = False\n", + " info[\"id\"] = i\n", + " buf.add(Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info))\n", + " obs = obs_next" + ], + "metadata": { + "id": "pAoKe02ybG68" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Our replay buffer looks like this now."
+ ], + "metadata": { + "id": "MKM6aWMucv-M" + } + }, + { + "cell_type": "code", + "source": [ + "print(buf)\n", + "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "\n", - "from typing import Any, Dict, List, Optional, Type, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n", - "from tianshou.policy import BasePolicy\n", - "\n", - "class REINFORCEPolicy(BasePolicy):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - " def __init__(self):\n", - " super().__init__()" - ], - "metadata": { - "id": "cDlSjASbJmy-" - }, - "execution_count": null, - "outputs": [] + "id": "CSJEEWOqXdTU", + "outputId": "2b3bb75c-f219-4e82-ca78-0ea6173a91f9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Updates\n", + "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train." + ], + "metadata": { + "id": "55VWhWpkdfEb" + } + }, + { + "cell_type": "code", + "source": [ + "# 0 means sample all data from the buffer\n", + "# batch_size=10 defines the training batch size\n", + "# repeat=6 means repeat the training for 6 times\n", + "policy.update(0, buf, batch_size=10, repeat=6)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "As we have mentioned, the Policy Module mainly does two things:\n", - "\n", - "\n", - "1. `policy.forward()` receives observation and other information (stored in a Batch) from the environment and returns a new Batch containing the action.\n", - "2. `policy.update()` receives training data sampled from the replay buffer and updates itself, and then returns logging details.\n", - "\n", - "\n", - "
\n", - "\n", - "\n", - " policy.forward() and policy.update() \n", - "
\n", - "\n", - "We also need to take care of the following things:\n", - "\n", - "\n", - "\n", - "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n", - "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n", - "\n", - "Then we get the implementation below.\n", - "\n", - "\n", - "\n" - ], - "metadata": { - "id": "qc1RnIBbLCDN" - } + "id": "i_O1lJDWdeoc", + "outputId": "b154741a-d6dc-46cb-898f-6e84fa14e5a7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Not that difficult, right?" + ], + "metadata": { + "id": "enqlFQLSJrQl" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Further Reading\n", + "\n", + "\n" + ], + "metadata": { + "id": "QJ5krjrcbuiA" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Pre-defined Networks\n", + "Tianshou provides numberous pre-defined networks usually used in DRL so that you don't have to bother yourself. Check this [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#pre-defined-networks) for details." + ], + "metadata": { + "id": "pmWi3HuXWcV8" + } + }, + { + "cell_type": "markdown", + "source": [ + "## How to compute GAE on your own?\n", + "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", + "\n", + "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. 
However, it is still important that we know the details behind this.\n", + "\n", + "To compute GAE advantage, the usage of `self.compute_episodic_return()` may goes like:" + ], + "metadata": { + "id": "UPVl5LBEWJ0t" + } + }, + { + "cell_type": "code", + "source": [ + "batch, indices = buf.sample(0) # 0 means sampling all the data from the buffer\n", + "returns, advantage = BasePolicy.compute_episodic_return(batch, buf, indices, v_s_=np.zeros(10), v_s=np.zeros(10), gamma=1.0, gae_lambda=1.0)\n", + "print(returns)\n", + "print(advantage)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "class REINFORCEPolicy(BasePolicy):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - " def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,):\n", - " super().__init__()\n", - " self.actor = model\n", - " self.optim = optim\n", - "\n", - " def forward(self, batch: Batch) -> Batch:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " act = None\n", - " return Batch(act=act)\n", - "\n", - " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " pass\n", - "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", - " \"\"\"Perform the back-propagation.\"\"\"\n", - " return" - ], - "metadata": { - "id": "6j32PSKUQ23w" - }, - "execution_count": null, - "outputs": [] + "id": "D34GlVvPNz08", + "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. As we know, we need to estimate the value function of every observation to compute GAE advantage. 
So the passed-in `v_s` is the value of `batch.obs` and `v_s_` is the value of `batch.obs_next`; these are usually computed by:\n", + "\n", + "`v_s = critic(batch.obs)`,\n", + "\n", + "`v_s_ = critic(batch.obs_next)`,\n", + "\n", + "where both `v_s` and `v_s_` are 10-dimensional arrays and `critic` is usually a neural network.\n", + "\n", + "After we've got all those values, GAE can be computed following the equation below." + ], + "metadata": { + "id": "h_5Dt6XwQLXV" + } + }, + { + "cell_type": "markdown", + "source": [ + "\\begin{aligned}\n", + "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", + "\\end{aligned}\n", + "\n", + "while\n", + "\n", + "\\begin{equation}\n", + "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n", + "\\end{equation}\n" + ], + "metadata": { + "id": "ooHNIICGUO19" + } + }, + { + "cell_type": "markdown", + "source": [ + "But if you follow this equation exactly as I referred to it from the paper, you will probably get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation." + ], + "metadata": { + "id": "eV6XZaouU7EV" + } + }, + { + "cell_type": "markdown", + "source": [ + "**First** is that Gym always returns you an `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it."
+ ], + "metadata": { + "id": "FCxD9gNNVYbd" + } + }, + { + "cell_type": "code", + "source": [ + "import copy\n", + "# Assume v_s_ is got by calling critic(bacth.obs_next)\n", + "v_s_ = np.ones(10)\n", + "v_s_ *= ~batch.done\n", + "print(v_s_)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Policy.forward()\n", - "According to the equation of REINFORCE algorithm in Spinning Up's [documentation](https://spinningup.openai.com/en/latest/algorithms/vpg.html), we need to map the observation to an action distribution in action space using neural network (`self.actor`).\n", - "\n", - "
\n", - "\n", - "\n", - "
\n", - "\n", - "Let's us suppose the action space is discrete, and the distribution is a simple categorical distribution.\n", - "\n" - ], - "metadata": { - "id": "tjtqjt8WRY5e" - } + "id": "rNZNUNgQVvRJ", + "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "After the fix above, we will perhaps get a more accurate estimate.\n", + "\n", + "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. We must know all those unfinished steps so that we know when to stop bootstraping.\n", + "\n", + "Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories." + ], + "metadata": { + "id": "2EtMi18QWXTN" + } + }, + { + "cell_type": "code", + "source": [ + "unfinished_indexes = buf.unfinished_index()\n", + "print(unfinished_indexes)\n", + "done_indexes = np.where(batch.done)[0]\n", + "print(done_indexes)\n", + "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", + "print(stop_bootstrap_ids)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "def forward(self, batch: Batch) -> Batch:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " self.dist_fn = torch.distributions.Categorical\n", - " logits = self.actor(batch.obs)\n", - " dist = self.dist_fn(logits)\n", - " act = dist.sample()\n", - " return Batch(act=act, dist=dist)" - ], - "metadata": { - "id": "uE4YDE-_RwgN" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Policy.process_fn()\n", - "Now that we have defined our actor, if given training data we can set up a loss function and optimize our neural 
network. However, before that, we must first calculate episodic returns for every step in our training data to construct the REINFORCE loss function.\n", - "\n", - "Calculating episodic return is not hard, given `ReplayBuffer.next()` allows us to access every reward to go in an episode. A more convenient way would be to simply use the built-in method `BasePolicy.compute_episodic_return()` inherited from BasePolicy.\n" - ], - "metadata": { - "id": "CultfOeuTx2V" - } - }, - { - "cell_type": "code", - "source": [ - "def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)\n", - " batch.returns = returns\n", - " return batch" - ], - "metadata": { - "id": "wPAmOD7zV7n2" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "`BasePolicy.compute_episodic_return()` could also be used to compute [GAE](https://arxiv.org/abs/1506.02438). Another similar method is `BasePolicy.compute_nstep_return()`. Check the [source code](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L304) for more details." - ], - "metadata": { - "id": "XA8OF4GnWWr5" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Policy.learn()\n", - "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finall we can construct our loss function and perform the back-propagation." 
- ], - "metadata": { - "id": "7UsdzNaOXPpC" - } - }, - { - "cell_type": "code", - "source": [ - "def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", - " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", - " for _ in range(repeat):\n", - " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", - " result = self(minibatch)\n", - " dist = result.dist\n", - " act = to_torch_as(minibatch.act, result.act)\n", - " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", - " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", - " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" - ], - "metadata": { - "id": "aCO-dLXWXtz9" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Implementation\n", - "Finally we can assemble the implemented methods and form a REINFORCE Policy." 
- ], - "metadata": { - "id": "1BtuV2W0YJTi" - } - }, - { - "cell_type": "code", - "source": [ - "class REINFORCEPolicy(BasePolicy):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - " def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,):\n", - " super().__init__()\n", - " self.actor = model\n", - " self.optim = optim\n", - " # action distribution\n", - " self.dist_fn = torch.distributions.Categorical\n", - "\n", - " def forward(self, batch: Batch) -> Batch:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " logits, _ = self.actor(batch.obs)\n", - " dist = self.dist_fn(logits)\n", - " act = dist.sample()\n", - " return Batch(act=act, dist=dist)\n", - "\n", - " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)\n", - " batch.returns = returns\n", - " return batch\n", - "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", - " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", - " for _ in range(repeat):\n", - " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", - " result = self(minibatch)\n", - " dist = result.dist\n", - " act = to_torch_as(minibatch.act, result.act)\n", - " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", - " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", - " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}\n" - ], - "metadata": { - "id": "Ab0KNQHTOlGo" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Use the policy\n", - "Note that `BasePolicy` itself inherits 
from `torch.nn.Module`. As a result, you can consider all Policy modules as a Torch Module. They share similar APIs.\n", - "\n", - "Firstly we will initialise a new REINFORCE policy." - ], - "metadata": { - "id": "xlPAbh0lKti8" - } - }, - { - "cell_type": "code", - "source": [ - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "state_shape = 4\n", - "action_shape = 2\n", - "net = Net(state_shape, hidden_sizes=[16, 16], device=\"cpu\")\n", - "actor = Actor(net, action_shape, device=\"cpu\").to(\"cpu\")\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", - "\n", - "policy = REINFORCEPolicy(actor, optim)" - ], - "metadata": { - "id": "JkLFA9Z1KjuX" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "REINFORCE policy shares same APIs with the Torch Module." - ], - "metadata": { - "id": "LAo_0t2fekUD" - } - }, - { - "cell_type": "code", - "source": [ - "print(policy)\n", - "print(\"========================================\")\n", - "for para in policy.parameters():\n", - " print(para.shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UiuTc8RhJiEi", - "outputId": "9b5bc54c-6303-45f3-ba81-2216a44931e8" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "REINFORCEPolicy(\n", - " (actor): Actor(\n", - " (preprocess): Net(\n", - " (model): MLP(\n", - " (model): Sequential(\n", - " (0): Linear(in_features=4, out_features=16, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=16, out_features=16, bias=True)\n", - " (3): ReLU()\n", - " )\n", - " )\n", - " )\n", - " (last): MLP(\n", - " (model): Sequential(\n", - " (0): Linear(in_features=16, out_features=2, bias=True)\n", - " )\n", - " )\n", - " )\n", - ")\n", - "========================================\n", - "torch.Size([16, 4])\n", - 
"torch.Size([16])\n", - "torch.Size([16, 16])\n", - "torch.Size([16])\n", - "torch.Size([2, 16])\n", - "torch.Size([2])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Making decision\n", - "Given a batch of observations, the policy can return a batch of actions and other data." - ], - "metadata": { - "id": "-RCrsttYgAG-" - } - }, - { - "cell_type": "code", - "source": [ - "obs_batch = Batch(obs=np.ones(shape=(256, 4)))\n", - "action = policy(obs_batch) # forward() method is called\n", - "print(action)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0jkBb6AAgUla", - "outputId": "37948844-cdd8-4567-9481-89453c80a157" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch(\n", - " act: tensor([1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0,\n", - " 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,\n", - " 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,\n", - " 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1,\n", - " 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0,\n", - " 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,\n", - " 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n", - " 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n", - " 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0,\n", - " 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1,\n", - " 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0]),\n", - " dist: Categorical(probs: torch.Size([256, 2])),\n", - ")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Save and Load models\n", - "Naturally, Tianshou Policy can be saved and loaded like a normal Torch Network." 
- ], - "metadata": { - "id": "swikhnuDfKep" - } - }, - { - "cell_type": "code", - "source": [ - "torch.save(policy.state_dict(), 'policy.pth')\n", - "assert policy.load_state_dict(torch.load('policy.pth'))" - ], - "metadata": { - "id": "tYOoWM_OJRnA" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Algorithm Updating\n", - "We have to collect some data and save them in the ReplayBuffer before updating our agent(policy). Typically we use collector to collect data, but we leave this part till later when we have learned the Collector in Tianshou. For now we generate some **fake** data." - ], - "metadata": { - "id": "gp8PzOYsg5z-" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Generating fake data\n", - "Firstly, we need to \"pretend\" that we are using the \"Policy\" to collect data. We plan to collect 10 data so that we can update our algorithm." - ], - "metadata": { - "id": "XrrPxOUAYShR" - } - }, - { - "cell_type": "code", - "source": [ - "import gym\n", - "from tianshou.data import Batch, ReplayBuffer\n", - "# a buffer is initialised with its maxsize set to 20.\n", - "print(\"========================================\")\n", - "buf = ReplayBuffer(size=12)\n", - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n", - "env = gym.make(\"CartPole-v0\")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a14CmzSfYh5C", - "outputId": "aaf45a1f-5e21-4bc8-cbe3-8ce798258af0" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "========================================\n", - "ReplayBuffer()\n", - "maxsize: 12, data length: 0\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps because we are performing too badly)." 
- ], - "metadata": { - "id": "8S94cV7yZITR" - } - }, - { - "cell_type": "code", - "source": [ - "obs = env.reset()\n", - "for i in range(3):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", - " obs_next, rew, done, info = env.step(act)\n", - " # pretend ending at step 3\n", - " done = True if i==2 else False\n", - " info[\"id\"] = i\n", - " buf.add(Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info))\n", - " obs = obs_next" - ], - "metadata": { - "id": "a_mtvbmBZbfs" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Now we are pretending to collect the second episode. At step 7 the second episode still does't end, but we are unwilling to wait, so we stop collecting to update the algorithm." - ], - "metadata": { - "id": "pkxq4gu9bGkt" - } - }, - { - "cell_type": "code", - "source": [ - "obs = env.reset()\n", - "for i in range(3, 10):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", - " obs_next, rew, done, info = env.step(act)\n", - " # pretend this episode never end\n", - " done = False\n", - " info[\"id\"] = i\n", - " buf.add(Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info))\n", - " obs = obs_next" - ], - "metadata": { - "id": "pAoKe02ybG68" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Our replay buffer looks like this now." 
- ], - "metadata": { - "id": "MKM6aWMucv-M" - } - }, - { - "cell_type": "code", - "source": [ - "print(buf)\n", - "print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CSJEEWOqXdTU", - "outputId": "2b3bb75c-f219-4e82-ca78-0ea6173a91f9" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "ReplayBuffer(\n", - " done: array([False, False, True, False, False, False, False, False, False,\n", - " False, False, False]),\n", - " obs: array([[-0.01684963, -0.00896152, 0.00930936, 0.00748042],\n", - " [-0.01684963, -0.00896152, 0.00930936, 0.00748042],\n", - " [-0.01684963, -0.00896152, 0.00930936, 0.00748042],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [-0.04934945, 0.01028611, -0.01101364, -0.0451668 ],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , 0. 
]]),\n", - " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]),\n", - " info: Batch(\n", - " id: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0]),\n", - " ),\n", - " obs_next: array([[-0.01702886, -0.20421573, 0.00945897, 0.30308596],\n", - " [-0.02111317, -0.00922985, 0.01552069, 0.01340113],\n", - " [-0.02129777, -0.20457091, 0.01578871, 0.31094032],\n", - " [-0.04914372, -0.18467619, -0.01191698, 0.24402097],\n", - " [-0.05283725, -0.37962592, -0.00703656, 0.53292129],\n", - " [-0.06042977, -0.5746482 , 0.00362187, 0.82337874],\n", - " [-0.07192273, -0.37957599, 0.02008944, 0.53183716],\n", - " [-0.07951425, -0.18474228, 0.03072618, 0.24555147],\n", - " [-0.0832091 , -0.3802893 , 0.03563721, 0.54776563],\n", - " [-0.09081488, -0.57589331, 0.04659253, 0.85146047],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , 0. ]]),\n", - " act: array([0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),\n", - ")\n", - "maxsize: 12, data length: 10\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Updates\n", - "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train." 
- ], - "metadata": { - "id": "55VWhWpkdfEb" - } - }, - { - "cell_type": "code", - "source": [ - "# 0 means sample all data from the buffer\n", - "# batch_size=10 defines the training batch size\n", - "# repeat=6 means repeat the training for 6 times\n", - "policy.update(0, buf, batch_size=10, repeat=6)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "i_O1lJDWdeoc", - "outputId": "b154741a-d6dc-46cb-898f-6e84fa14e5a7" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "{'loss': [2.414336681365967,\n", - " 2.412271499633789,\n", - " 2.410210609436035,\n", - " 2.4081532955169678,\n", - " 2.406100273132324,\n", - " 2.404050827026367]}" - ] - }, - "metadata": {}, - "execution_count": 45 - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Not that difficult, right?" - ], - "metadata": { - "id": "enqlFQLSJrQl" - } - }, - { - "cell_type": "markdown", - "source": [ - "# Further Reading\n", - "\n", - "\n" - ], - "metadata": { - "id": "QJ5krjrcbuiA" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Pre-defined Networks\n", - "Tianshou provides numberous pre-defined networks usually used in DRL so that you don't have to bother yourself. Check this [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#pre-defined-networks) for details." - ], - "metadata": { - "id": "pmWi3HuXWcV8" - } - }, - { - "cell_type": "markdown", - "source": [ - "## How to compute GAE on your own?\n", - "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", - "\n", - "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. 
However, it is still important that we know the details behind this.\n", - "\n", - "To compute GAE advantage, the usage of `self.compute_episodic_return()` may goes like:" - ], - "metadata": { - "id": "UPVl5LBEWJ0t" - } - }, - { - "cell_type": "code", - "source": [ - "batch, indices = buf.sample(0) # 0 means sampling all the data from the buffer\n", - "returns, advantage = BasePolicy.compute_episodic_return(batch, buf, indices, v_s_=np.zeros(10), v_s=np.zeros(10), gamma=1.0, gae_lambda=1.0)\n", - "print(returns)\n", - "print(advantage)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "D34GlVvPNz08", - "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[3. 2. 1. 7. 6. 5. 4. 3. 2. 1.]\n", - "[3. 2. 1. 7. 6. 5. 4. 3. 2. 1.]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. As we know, we need to estimate the value function of every observation to compute GAE advantage. so the passed in `v_s` is the value of bacth.obs, `v_s_` is the value of bacth.obs_next this is usually computed by:\n", - "\n", - "`v_s = critic(bacth.obs)`,\n", - "\n", - "`v_s_ = critic(bacth.obs_next)`,\n", - "\n", - "where uboth `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n", - "\n", - "After we've got all those values, GAE can be computed following the equation below." 
- ], - "metadata": { - "id": "h_5Dt6XwQLXV" - } - }, - { - "cell_type": "markdown", - "source": [ - "\\begin{aligned}\n", - "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", - "\\end{aligned}\n", - "\n", - "while\n", - "\n", - "\\begin{equation}\n", - "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n", - "\\end{equation}\n" - ], - "metadata": { - "id": "ooHNIICGUO19" - } - }, - { - "cell_type": "markdown", - "source": [ - "But, if you do follow this equation I refered from the paper. You probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation." - ], - "metadata": { - "id": "eV6XZaouU7EV" - } - }, - { - "cell_type": "markdown", - "source": [ - "**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it." - ], - "metadata": { - "id": "FCxD9gNNVYbd" - } - }, - { - "cell_type": "code", - "source": [ - "import copy\n", - "# Assume v_s_ is got by calling critic(bacth.obs_next)\n", - "v_s_ = np.ones(10)\n", - "v_s_ *= ~batch.done\n", - "print(v_s_)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "rNZNUNgQVvRJ", - "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[1. 1. 0. 1. 1. 1. 1. 1. 1. 1.]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "After the fix above, we will perhaps get a more accurate estimate.\n", - "\n", - "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. 
We must know all those unfinished steps so that we know when to stop bootstraping.\n", - "\n", - "Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories." - ], - "metadata": { - "id": "2EtMi18QWXTN" - } - }, - { - "cell_type": "code", - "source": [ - "unfinished_indexes = buf.unfinished_index()\n", - "print(unfinished_indexes)\n", - "done_indexes = np.where(batch.done)[0]\n", - "print(done_indexes)\n", - "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", - "print(stop_bootstrap_ids)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "saluvX4JU6bC", - "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[9]\n", - "[2]\n", - "[9 2]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "**Thirdly**, there are some special indexes which are marked by done flag. However, its value for obs_next should not be zero. This is because these steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kind of steps are always marked with `info['TimeLimit.truncated']=True` in Gym." 
 - ], - "metadata": { - "id": "qp6vVE4dYWv1" - } - }, - { - "cell_type": "markdown", - "source": [ - "As a result, we need to rewrite the equation above\n", - "\n", - "`v_s_ *= ~batch.done`" - ], - "metadata": { - "id": "tWkqXRJfZTvV" - } - }, - { - "cell_type": "markdown", - "source": [ - "to\n", - "\n", - "```\n", - "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", - "v_s_ *= mask\n", - "\n", - "```\n", - "\n", - "\n", - "\n" - ], - "metadata": { - "id": "kms-QtxKZe-M" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Summary\n", - "If you already felt bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `BasePolicy.compute_episodic_return()`.\n", - "\n", - "If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and implementation of `BasePolicy.value_mask()` and `BasePolicy.compute_episodic_return()` for details." - ], - "metadata": { - "id": "u_aPPoKraBu6" - } - }, - { - "cell_type": "markdown", - "source": [ - "\n", - "![timelimit.svg]()\n", - "\n", - "![22.PNG]()" - ], - "metadata": { - "id": "2cPnUXRBWKD9" - } - } - ] -} \ No newline at end of file + "id": "saluvX4JU6bC", + "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "**Thirdly**, there are some special indexes which are marked by the done flag. However, their value for obs_next should not be zero. This is because these steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kinds of steps are always marked with `info['TimeLimit.truncated']=True` in Gym." 
 + ], + "metadata": { + "id": "qp6vVE4dYWv1" + } + }, + { + "cell_type": "markdown", + "source": [ + "As a result, we need to rewrite the equation above\n", + "\n", + "`v_s_ *= ~batch.done`" + ], + "metadata": { + "id": "tWkqXRJfZTvV" + } + }, + { + "cell_type": "markdown", + "source": [ + "to\n", + "\n", + "```\n", + "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", + "v_s_ *= mask\n", + "\n", + "```\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "kms-QtxKZe-M" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Summary\n", + "If you already feel bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `BasePolicy.compute_episodic_return()`.\n", + "\n", + "If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and the implementation of `BasePolicy.value_mask()` and `BasePolicy.compute_episodic_return()` for details." 
 + ], + "metadata": { + "id": "u_aPPoKraBu6" + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "![timelimit.svg]()\n", + "\n", + "![22.PNG]()" + ], + "metadata": { + "id": "2cPnUXRBWKD9" + } + } + ] +} diff --git a/notebooks/L5_Collector.ipynb b/notebooks/L5_Collector.ipynb index 17ea547..e6e6818 100644 --- a/notebooks/L5_Collector.ipynb +++ b/notebooks/L5_Collector.ipynb @@ -1,365 +1,263 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] }, - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "AKoInktmr-3t" - }, - "outputs": [], - "source": [ - "# Remember to install tianshou first\n", - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AKoInktmr-3t" + }, + "outputs": [], + "source": [ + "# Remember to install tianshou first\n", + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "From its literal meaning, we can easily know that the Collector in Tianshou is used to collect training data. More specifically, the Collector controls the interaction between Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n", + "\n", + "
\n", + "\n", + "
\n", + "\n" + ], + "metadata": { + "id": "M98bqxdMsTXK" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Usages\n", + "Collector can be used both for training (data collecting) and evaluation in Tianshou." + ], + "metadata": { + "id": "OX5cayLv4Ziu" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Policy evaluation\n", + "We need to evaluate our trained policy from time to time in DRL experiments. Collector can help us with this.\n", + "\n", + "First we have to initialise a Collector with an (vectorized) environment and a given policy (agent)." + ], + "metadata": { + "id": "Z6XKbj28u8Ze" + } + }, + { + "cell_type": "code", + "source": [ + "import gym\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from tianshou.data import Collector\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.policy import PGPolicy\n", + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import Actor\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "env = gym.make(\"CartPole-v0\")\n", + "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(2)])\n", + "\n", + "# model\n", + "net = Net(env.observation_space.shape, hidden_sizes=[16,])\n", + "actor = Actor(net, env.action_space.shape)\n", + "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", + "\n", + "policy = PGPolicy(actor, optim, dist_fn=torch.distributions.Categorical)\n", + "test_collector = Collector(policy, test_envs)" + ], + "metadata": { + "id": "w8t9ubO7u69J" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we would like to collect 9 episodes of data to test how our initialised Policy performs." 
+ ], + "metadata": { + "id": "wmt8vuwpzQdR" + } + }, + { + "cell_type": "code", + "source": [ + "collect_result = test_collector.collect(n_episode=9)\n", + "print(collect_result)\n", + "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", + "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", + "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "From its literal meaning, we can easily know that the Collector in Tianshou is used to collect training data. More specificly, the Collector controls the interaction between Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n", - "\n", - "
\n", - "\n", - "
\n", - "\n" - ], - "metadata": { - "id": "M98bqxdMsTXK" - } + "id": "9SuT6MClyjyH", + "outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we wonder what is the performance of a random policy." + ], + "metadata": { + "id": "zX9AQY0M0R3C" + } + }, + { + "cell_type": "code", + "source": [ + "# Reset the collector\n", + "test_collector.reset()\n", + "collect_result = test_collector.collect(n_episode=9, random=True)\n", + "print(collect_result)\n", + "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", + "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", + "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Usages\n", - "Collector can be used both for training (data collecting) and evaluation in Tianshou." - ], - "metadata": { - "id": "OX5cayLv4Ziu" - } + "id": "UEcs8P8P0RLt", + "outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Seems that an initialised policy performs even worse than a random policy without any training." + ], + "metadata": { + "id": "sKQRTiG10ljU" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Data Collecting\n", + "Data collecting is mostly used during training, when we need to store the collected data in a ReplayBuffer." 
+ ], + "metadata": { + "id": "8RKmHIoG1A1k" + } + }, + { + "cell_type": "code", + "source": [ + "from tianshou.data import VectorReplayBuffer\n", + "train_env_num = 4\n", + "buffer_size = 100\n", + "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(train_env_num)])\n", + "replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", + "\n", + "train_collector = Collector(policy, train_envs, replaybuffer)" + ], + "metadata": { + "id": "CB9XB9bF1YPC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we can collect 50 steps of data, which will be automatically saved in the replay buffer. You can still choose to collect a certain number of episodes rather than steps. Try it yourself." + ], + "metadata": { + "id": "rWKDazA42IUQ" + } + }, + { + "cell_type": "code", + "source": [ + "print(len(replaybuffer))\n", + "collect_result = train_collector.collect(n_step=50)\n", + "print(len(replaybuffer))\n", + "print(collect_result)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Policy evaluation\n", - "We need to evaluate our trained policy from time to time in DRL experiments. Collector can help us with this.\n", - "\n", - "First we have to initialise a Collector with an (vectorized) environment and a given policy (agent)." 
- ], - "metadata": { - "id": "Z6XKbj28u8Ze" - } + "id": "-fUtQOnM2Yi1", + "outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(13):\n", + " print(i, replaybuffer.next(i))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "import gym\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.data import Collector\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PGPolicy\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "\n", - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "env = gym.make(\"CartPole-v0\")\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(2)])\n", - "\n", - "# model\n", - "net = Net(env.observation_space.shape, hidden_sizes=[16,])\n", - "actor = Actor(net, env.action_space.shape)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", - "\n", - "policy = PGPolicy(actor, optim, dist_fn=torch.distributions.Categorical)\n", - "test_collector = Collector(policy, test_envs)" - ], - "metadata": { - "id": "w8t9ubO7u69J" - }, - "execution_count": null, - "outputs": [] + "id": "EWO4A7plefwM", + "outputId": "9a6f36d1-2b84-49b0-a03d-a8ebe8acadbf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "replaybuffer.sample(10)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "Now we would like to collect 9 episodes of data to test how our initialised Policy performs." 
- ], - "metadata": { - "id": "wmt8vuwpzQdR" - } - }, - { - "cell_type": "code", - "source": [ - "collect_result = test_collector.collect(n_episode=9)\n", - "print(collect_result)\n", - "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", - "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", - "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9SuT6MClyjyH", - "outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'n/ep': 9, 'n/st': 85, 'rews': array([ 9., 9., 10., 10., 9., 9., 10., 10., 9.]), 'lens': array([ 9, 9, 10, 10, 9, 9, 10, 10, 9]), 'idxs': array([0, 1, 0, 1, 0, 1, 0, 1, 1]), 'rew': 9.444444444444445, 'len': 9.444444444444445, 'rew_std': 0.49690399499995325, 'len_std': 0.49690399499995325}\n", - "Rewards of 9 episodes are [ 9. 9. 10. 10. 9. 9. 10. 10. 9.]\n", - "Average episode reward is 9.444444444444445.\n", - "Average episode length is 9.444444444444445.\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Now we wonder what is the performance of a random policy." 
- ], - "metadata": { - "id": "zX9AQY0M0R3C" - } - }, - { - "cell_type": "code", - "source": [ - "# Reset the collector\n", - "test_collector.reset()\n", - "collect_result = test_collector.collect(n_episode=9, random=True)\n", - "print(collect_result)\n", - "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", - "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", - "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UEcs8P8P0RLt", - "outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'n/ep': 9, 'n/st': 187, 'rews': array([13., 14., 11., 44., 35., 14., 19., 17., 20.]), 'lens': array([13, 14, 11, 44, 35, 14, 19, 17, 20]), 'idxs': array([1, 0, 0, 1, 0, 0, 1, 0, 1]), 'rew': 20.77777777777778, 'len': 20.77777777777778, 'rew_std': 10.580671872993257, 'len_std': 10.580671872993257}\n", - "Rewards of 9 episodes are [13. 14. 11. 44. 35. 14. 19. 17. 20.]\n", - "Average episode reward is 20.77777777777778.\n", - "Average episode length is 20.77777777777778.\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "Seems that an initialised policy performs even worse than a random policy without any training." - ], - "metadata": { - "id": "sKQRTiG10ljU" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Data Collecting\n", - "Data collecting is mostly used during training, when we need to store the collected data in a ReplayBuffer." 
- ], - "metadata": { - "id": "8RKmHIoG1A1k" - } - }, - { - "cell_type": "code", - "source": [ - "from tianshou.data import VectorReplayBuffer\n", - "train_env_num = 4\n", - "buffer_size = 100\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(train_env_num)])\n", - "replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "\n", - "train_collector = Collector(policy, train_envs, replaybuffer)" - ], - "metadata": { - "id": "CB9XB9bF1YPC" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Now we can collect 50 steps of data, which will be automatically saved in the replay buffer. You can still choose to collect a certain number of episodes rather than steps. Try it yourself." - ], - "metadata": { - "id": "rWKDazA42IUQ" - } - }, - { - "cell_type": "code", - "source": [ - "print(len(replaybuffer))\n", - "collect_result = train_collector.collect(n_step=50)\n", - "print(len(replaybuffer))\n", - "print(collect_result)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-fUtQOnM2Yi1", - "outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0\n", - "52\n", - "{'n/ep': 4, 'n/st': 52, 'rews': array([ 8., 10., 10., 10.]), 'lens': array([ 8, 10, 10, 10]), 'idxs': array([25, 0, 50, 75]), 'rew': 9.5, 'len': 9.5, 'rew_std': 0.8660254037844386, 'len_std': 0.8660254037844386}\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "for i in range(13):\n", - " print(i, replaybuffer.next(i))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "EWO4A7plefwM", - "outputId": "9a6f36d1-2b84-49b0-a03d-a8ebe8acadbf" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0 1\n", - "1 2\n", - "2 3\n", - "3 4\n", - "4 5\n", - "5 6\n", - "6 7\n", - "7 8\n", - 
"8 9\n", - "9 9\n", - "10 11\n", - "11 12\n", - "12 12\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "replaybuffer.sample(10)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HW8PpWH9fLCo", - "outputId": "7ca70c50-23b9-4405-9e42-2e5771cd9c78" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(Batch(\n", - " obs: array([[-6.36541036e-02, -1.15297838e+00, 7.85138179e-02,\n", - " 1.77225357e+00],\n", - " [-1.05090645e-02, -3.71521519e-01, -2.96323181e-03,\n", - " 5.76793524e-01],\n", - " [-2.45282997e-02, 4.77850180e-02, 2.21142716e-02,\n", - " 4.96743371e-02],\n", - " [ 1.68433453e-03, 4.47272356e-02, -1.72360346e-02,\n", - " -7.74977680e-03],\n", - " [ 4.29854159e-02, -3.95380051e-01, 3.91005958e-02,\n", - " 6.41183774e-01],\n", - " [ 3.50778149e-02, -5.91024637e-01, 5.19242712e-02,\n", - " 9.45918993e-01],\n", - " [-1.19358173e-01, -1.38179912e+00, 1.04318690e-01,\n", - " 2.10185768e+00],\n", - " [-3.48454722e-02, 1.61145106e-03, -2.49944951e-02,\n", - " -3.82213155e-02],\n", - " [-3.48454722e-02, 1.61145106e-03, -2.49944951e-02,\n", - " -3.82213155e-02],\n", - " [-4.64338716e-02, -5.82710950e-01, -1.02110827e-02,\n", - " 8.16967413e-01]]),\n", - " act: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n", - " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", - " done: array([False, False, False, False, False, False, False, False, False,\n", - " False]),\n", - " obs_next: array([[-0.08671367, -1.34889318, 0.11395889, 2.08827982],\n", - " [-0.01793949, -0.56660181, 0.00857264, 0.86854149],\n", - " [-0.0235726 , -0.14764694, 0.02310776, 0.34925166],\n", - " [ 0.00257888, -0.15014334, -0.01739103, 0.27944553],\n", - " [ 0.03507781, -0.59102464, 0.05192427, 0.94591899],\n", - " [ 0.02325732, -0.78680603, 0.07084265, 1.25445415],\n", - " [-0.14699416, -1.57780202, 0.14635584, 2.42487783],\n", - " [-0.03481324, -0.19314333, -0.02575892, 
0.24647199],\n", - " [-0.03481324, -0.19314333, -0.02575892, 0.24647199],\n", - " [-0.05808809, -0.77769163, 0.00612827, 1.10642118]]),\n", - " info: Batch(\n", - " env_id: array([0, 0, 0, 1, 1, 1, 2, 3, 3, 3]),\n", - " ),\n", - " policy: Batch(),\n", - " ), array([ 6, 2, 10, 33, 27, 28, 57, 75, 75, 78]))" - ] - }, - "metadata": {}, - "execution_count": 10 - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "# Further Reading\n", - "The above collector actually collects 52 data at a time because 52 % 4 = 0. There is one asynchronous collector which allows you collect exactly 50 steps. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#asynccollector) for details." - ], - "metadata": { - "id": "8NP7lOBU3-VS" - } - } - ] -} \ No newline at end of file + "id": "HW8PpWH9fLCo", + "outputId": "7ca70c50-23b9-4405-9e42-2e5771cd9c78" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Further Reading\n", + "The above collector actually collects 52 steps at a time because 52 % 4 = 0. There is one asynchronous collector which allows you to collect exactly 50 steps. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#asynccollector) for details." 
+ ], + "metadata": { + "id": "8NP7lOBU3-VS" + } + } + ] +} diff --git a/notebooks/L6_Trainer.ipynb b/notebooks/L6_Trainer.ipynb index c75ad85..c803460 100644 --- a/notebooks/L6_Trainer.ipynb +++ b/notebooks/L6_Trainer.ipynb @@ -1,379 +1,227 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [ - "S3-tJZy35Ck_", - "XfsuU2AAE52C", - "p-7U_cwgF5Ej", - "_j3aUJZQ7nml" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "S3-tJZy35Ck_", + "XfsuU2AAE52C", + "p-7U_cwgF5Ej", + "_j3aUJZQ7nml" + ] }, - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wDZlC0v348Ym" - }, - "outputs": [], - "source": [ - "# Remember to install tianshou first\n", - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ] + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wDZlC0v348Ym" + }, + "outputs": [], + "source": [ + "# Remember to install tianshou first\n", + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n", + "\n", + "![framework.svg]()\n", + "\n" + ], + "metadata": { + "id": "S3-tJZy35Ck_" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Usages\n", + "In Tianshou v0.4.7, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. 
We will use on-policy trainer as an example and leave the other two for further reading." + ], + "metadata": { + "id": "ifsEQMzZ6mmz" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Pseudocode\n", + "![1.PNG]()\n", + "\n", + "For the on-policy trainer, the main difference is that we clear the buffer after Line 10." + ], + "metadata": { + "id": "XfsuU2AAE52C" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Training without trainer\n", + "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n", + "\n", + "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector." + ], + "metadata": { + "id": "Hcp_o0CCFz12" + } + }, + { + "cell_type": "code", + "source": [ + "import gym\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from tianshou.data import Collector, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.policy import PGPolicy\n", + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import Actor\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "train_env_num = 4\n", + "buffer_size = 2000 # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n", + "\n", + "# Create the environments, used for training and evaluation\n", + "env = gym.make(\"CartPole-v0\")\n", + "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(2)])\n", + "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(train_env_num)])\n", + "\n", + "# Create the Policy instance\n", + "net = Net(env.observation_space.shape, hidden_sizes=[16,])\n", + "actor = Actor(net, env.action_space.shape)\n", + "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", + "policy = PGPolicy(actor, optim, dist_fn=torch.distributions.Categorical)\n", + "\n", + "# Create the replay buffer and the collector\n", + 
"replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", + "test_collector = Collector(policy, test_envs)\n", + "train_collector = Collector(policy, train_envs, replaybuffer)" + ], + "metadata": { + "id": "do-xZ-8B7nVH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy." + ], + "metadata": { + "id": "wiEGiBgQIiFM" + } + }, + { + "cell_type": "code", + "source": [ + "train_collector.reset()\n", + "train_envs.reset()\n", + "test_collector.reset()\n", + "test_envs.reset()\n", + "replaybuffer.reset()\n", + "for i in range(10):\n", + " evaluation_result = test_collector.collect(n_episode=10)\n", + " print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n", + " train_collector.collect(n_step=2000)\n", + " # 0 means taking all data stored in train_collector.buffer\n", + " policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n", + " train_collector.reset_buffer(keep_statistics=True)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n", - "\n", - "![framework.svg]()\n", - "\n" - ], - "metadata": { - "id": "S3-tJZy35Ck_" - } + "id": "JMUNPN5SI_kd", + "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. 
Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop." + ], + "metadata": { + "id": "QXBHIBckMs_2" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Training with trainer\n", + "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular." + ], + "metadata": { + "id": "p-7U_cwgF5Ej" + } + }, + { + "cell_type": "code", + "source": [ + "from tianshou.trainer import onpolicy_trainer\n", + "\n", + "train_collector.reset()\n", + "train_envs.reset()\n", + "test_collector.reset()\n", + "test_envs.reset()\n", + "replaybuffer.reset()\n", + "\n", + "result = onpolicy_trainer(\n", + " policy,\n", + " train_collector,\n", + " test_collector,\n", + " max_epoch=10,\n", + " step_per_epoch=1,\n", + " repeat_per_collect=1,\n", + " episode_per_test=10,\n", + " step_per_collect=2000,\n", + " batch_size=512,\n", + ")\n", + "print(result)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Usages\n", - "In Tianshou v0.4.7, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading." - ], - "metadata": { - "id": "ifsEQMzZ6mmz" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Pseudocode\n", - "![1.PNG]()\n", - "\n", - "For the on-policy trainer, the main difference is that we clear the buffer after Line 10." - ], - "metadata": { - "id": "XfsuU2AAE52C" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Training without trainer\n", - "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n", - "\n", - "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector." 
- ], - "metadata": { - "id": "Hcp_o0CCFz12" - } - }, - { - "cell_type": "code", - "source": [ - "import gym\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.data import Collector, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PGPolicy\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "\n", - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "train_env_num = 4\n", - "buffer_size = 2000 # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n", - "\n", - "# Create the environments, used for training and evaluation\n", - "env = gym.make(\"CartPole-v0\")\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(2)])\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(train_env_num)])\n", - "\n", - "# Create the Policy instance\n", - "net = Net(env.observation_space.shape, hidden_sizes=[16,])\n", - "actor = Actor(net, env.action_space.shape)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", - "policy = PGPolicy(actor, optim, dist_fn=torch.distributions.Categorical)\n", - "\n", - "# Create the replay buffer and the collector\n", - "replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "test_collector = Collector(policy, test_envs)\n", - "train_collector = Collector(policy, train_envs, replaybuffer)" - ], - "metadata": { - "id": "do-xZ-8B7nVH" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy." 
- ], - "metadata": { - "id": "wiEGiBgQIiFM" - } - }, - { - "cell_type": "code", - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replaybuffer.reset()\n", - "for i in range(10):\n", - " evaluation_result = test_collector.collect(n_episode=10)\n", - " print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n", - " train_collector.reset_buffer(keep_statistics=True)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMUNPN5SI_kd", - "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Evaluation reward is 9.6\n", - "Evaluation reward is 9.6\n", - "Evaluation reward is 9.2\n", - "Evaluation reward is 9.1\n", - "Evaluation reward is 9.5\n", - "Evaluation reward is 9.7\n", - "Evaluation reward is 9.6\n", - "Evaluation reward is 9.4\n", - "Evaluation reward is 9.3\n", - "Evaluation reward is 9.1\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop." - ], - "metadata": { - "id": "QXBHIBckMs_2" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Training with trainer\n", - "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular." 
- ], - "metadata": { - "id": "p-7U_cwgF5Ej" - } - }, - { - "cell_type": "code", - "source": [ - "from tianshou.trainer import onpolicy_trainer\n", - "\n", - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replaybuffer.reset()\n", - "\n", - "result = onpolicy_trainer(\n", - " policy,\n", - " train_collector,\n", - " test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=1,\n", - " repeat_per_collect=1,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " batch_size=512,\n", - ")\n", - "print(result)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vcvw9J8RNtFE", - "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #1: 2000it [00:00, 4144.84it/s, env_step=2000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.34]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #1: test_reward: 9.500000 ± 0.500000, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #2: 2000it [00:00, 4208.58it/s, env_step=4000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.41]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #2: test_reward: 9.400000 ± 0.489898, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #3: 2000it [00:00, 4472.80it/s, env_step=6000, len=9, loss=0.000, n/ep=212, n/st=2000, rew=9.39]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #3: test_reward: 9.100000 ± 0.700000, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #4: 2000it [00:00, 4340.62it/s, env_step=8000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.38]\n" - ] - 
}, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #4: test_reward: 9.400000 ± 0.800000, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #5: 2000it [00:00, 4483.35it/s, env_step=10000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.42]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #5: test_reward: 9.400000 ± 1.019804, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #6: 2000it [00:00, 4068.51it/s, env_step=12000, len=9, loss=0.000, n/ep=212, n/st=2000, rew=9.42]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #6: test_reward: 9.400000 ± 0.663325, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #7: 2000it [00:00, 4091.46it/s, env_step=14000, len=9, loss=0.000, n/ep=214, n/st=2000, rew=9.32]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #7: test_reward: 9.300000 ± 0.640312, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #8: 2000it [00:00, 4042.49it/s, env_step=16000, len=9, loss=0.000, n/ep=215, n/st=2000, rew=9.34]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #8: test_reward: 9.600000 ± 0.800000, best_reward: 9.900000 ± 0.700000 in #0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #9: 2000it [00:00, 4400.16it/s, env_step=18000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.38]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #9: test_reward: 9.000000 ± 0.632456, best_reward: 9.900000 ± 0.700000 in #0\n", - "{'duration': '4.79s', 'train_time/model': '0.22s', 'test_step': 940, 'test_episode': 100, 'test_time': '0.46s', 'test_speed': 
'2026.40 step/s', 'best_reward': 9.9, 'best_result': '9.90 ± 0.70', 'train_step': 18000, 'train_episode': 1918, 'train_time/collector': '4.11s', 'train_speed': '4156.80 step/s'}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "# Further Reading\n", - "## Logger usages\n", - "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n", - "\n", - "## Learn more about the APIs of Trainers\n", - "[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)" - ], - "metadata": { - "id": "_j3aUJZQ7nml" - } - } - ] -} \ No newline at end of file + "id": "vcvw9J8RNtFE", + "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Further Reading\n", + "## Logger usages\n", + "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. 
Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n", + "\n", + "## Learn more about the APIs of Trainers\n", + "[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)" + ], + "metadata": { + "id": "_j3aUJZQ7nml" + } + } + ] +} diff --git a/notebooks/L7_Experiment.ipynb b/notebooks/L7_Experiment.ipynb index 0242ef0..c23b4a1 100644 --- a/notebooks/L7_Experiment.ipynb +++ b/notebooks/L7_Experiment.ipynb @@ -1,352 +1,321 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Overview\n", - "Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym." - ], - "metadata": { - "id": "_UaXOSRjDUF9" - } + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Overview\n", + "Finally, we can assemble building blocks that we have come across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use the [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
+ ], + "metadata": { + "id": "_UaXOSRjDUF9" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Experiment\n", + "To conduct this experiment, we need the following building blocks.\n", + "\n", + "\n", + "* Two vectorized environments, one for training and one for evaluation\n", + "* A PPO agent\n", + "* A replay buffer to store transition data\n", + "* Two collectors to manage the data collecting process, one for training and one for evaluation\n", + "* A trainer to manage the training loop\n", + "\n", + "
\n", + "\n", + "\n", + "
\n", + "\n", + "Let us do this step by step." + ], + "metadata": { + "id": "2QRbCJvDHNAd" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Preparation\n", + "Firstly, install Tianshou if you haven't installed it before." + ], + "metadata": { + "id": "-Hh4E6i0Hj0I" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install tianshou==0.4.8\n", + "!pip install gym" + ], + "metadata": { + "id": "w50BVwaRHg3N" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Import libraries we might need later." + ], + "metadata": { + "id": "7E4EhiBeHxD5" + } + }, + { + "cell_type": "code", + "source": [ + "import gym\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from tianshou.data import Collector, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.policy import PPOPolicy\n", + "from tianshou.trainer import onpolicy_trainer\n", + "from tianshou.utils.net.common import ActorCritic, Net\n", + "from tianshou.utils.net.discrete import Actor, Critic\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'" + ], + "metadata": { + "id": "ao9gWJDiHgG-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Environment" + ], + "metadata": { + "id": "QnRg5y7THRYw" + } + }, + { + "cell_type": "markdown", + "source": [ + "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv." 
+ ], + "metadata": { + "id": "YZERKCGtH8W1" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mpuj5PFnDKVS" + }, + "outputs": [], + "source": [ + "env = gym.make('CartPole-v0')\n", + "train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n", + "test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Policy\n", + "Next we need to initialise our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n", + "\n", + "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n", + "\n", + "Luckily, Tianshou already provides basic network modules that we can use in this experiment." + ], + "metadata": { + "id": "BJtt_Ya8DTAh" + } + }, + { + "cell_type": "code", + "source": [ + "# net is the shared head of the actor and the critic\n", + "net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", + "actor = Actor(net, env.action_space.n, device=device).to(device)\n", + "critic = Critic(net, device=device).to(device)\n", + "actor_critic = ActorCritic(actor, critic)\n", + "\n", + "# optimizer of the actor and the critic\n", + "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)" + ], + "metadata": { + "id": "_Vy8uPWXP4m_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
+ ], + "metadata": { + "id": "Lh2-hwE5Dn9I" + } + }, + { + "cell_type": "code", + "source": [ + "dist = torch.distributions.Categorical\n", + "policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)" + ], + "metadata": { + "id": "OiJ2GkT0Qnbr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included." + ], + "metadata": { + "id": "okxfj6IEQ-r8" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Collector\n", + "We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in." + ], + "metadata": { + "id": "n5XAAbuBZarO" + } + }, + { + "cell_type": "code", + "source": [ + "train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n", + "test_collector = Collector(policy, test_envs)" + ], + "metadata": { + "id": "ezwz0qerZhQM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments; you can simply consider `VectorReplayBuffer` as a list of ordinary replay buffers." + ], + "metadata": { + "id": "ZaoPxOd2hm0b" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Trainer\n", + "Finally, we can use the trainer to help us set up the training loop."
+ ], + "metadata": { + "id": "qBoE9pLUiC-8" + } + }, + { + "cell_type": "code", + "source": [ + "result = onpolicy_trainer(\n", + " policy,\n", + " train_collector,\n", + " test_collector,\n", + " max_epoch=10,\n", + " step_per_epoch=50000,\n", + " repeat_per_collect=10,\n", + " episode_per_test=10,\n", + " batch_size=256,\n", + " step_per_collect=2000,\n", + " stop_fn=lambda mean_reward: mean_reward >= 195,\n", + ")" + ], + "metadata": { + "id": "i45EDnpxQ8gu", + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "# Experiment\n", - "To conduct this experiment, we need the following building blocks.\n", - "\n", - "\n", - "* Two vectorized environments, one for training and one for evaluation\n", - "* A PPO agent\n", - "* A replay buffer to store transition data\n", - "* Two collectors to manage the data collecting process, one for training and one for evaluation\n", - "* A trainer to manage the training loop\n", - "\n", - "
\n", - "\n", - "\n", - "
\n", - "\n", - "Let us do this step by step." - ], - "metadata": { - "id": "2QRbCJvDHNAd" - } + "outputId": "b1666b88-0bfa-4340-868e-58611872d988" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Results\n", + "Print the training result." + ], + "metadata": { + "id": "ckgINHE2iTFR" + } + }, + { + "cell_type": "code", + "source": [ + "print(result)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "source": [ - "## Preparation\n", - "Firstly, install Tianshou if you haven't installed it before." - ], - "metadata": { - "id": "-Hh4E6i0Hj0I" - } + "id": "tJCPgmiyiaaX", + "outputId": "40123ae3-3365-4782-9563-46c43812f10f" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We can also test our trained agent." + ], + "metadata": { + "id": "A-MJ9avMibxN" + } + }, + { + "cell_type": "code", + "source": [ + "# Let's watch its performance!\n", + "policy.eval()\n", + "result = test_collector.collect(n_episode=1, render=False)\n", + "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "source": [ - "!pip install tianshou==0.4.8\n", - "!pip install gym" - ], - "metadata": { - "id": "w50BVwaRHg3N" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Import libraries we might need later." 
- ], - "metadata": { - "id": "7E4EhiBeHxD5" - } - }, - { - "cell_type": "code", - "source": [ - "import gym\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.data import Collector, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PPOPolicy\n", - "from tianshou.trainer import onpolicy_trainer\n", - "from tianshou.utils.net.common import ActorCritic, Net\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "device = 'cuda' if torch.cuda.is_available() else 'cpu'" - ], - "metadata": { - "id": "ao9gWJDiHgG-" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Environment" - ], - "metadata": { - "id": "QnRg5y7THRYw" - } - }, - { - "cell_type": "markdown", - "source": [ - "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv." - ], - "metadata": { - "id": "YZERKCGtH8W1" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mpuj5PFnDKVS" - }, - "outputs": [], - "source": [ - "env = gym.make('CartPole-v0')\n", - "train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Policy\n", - "Next we need to initialise our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n", - "\n", - "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. 
The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n", - "\n", - "Luckily, Tianshou already provides basic network modules that we can use in this experiment." - ], - "metadata": { - "id": "BJtt_Ya8DTAh" - } - }, - { - "cell_type": "code", - "source": [ - "# net is the shared head of the actor and the critic\n", - "net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "actor = Actor(net, env.action_space.n, device=device).to(device)\n", - "critic = Critic(net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor, critic)\n", - "\n", - "# optimizer of the actor and the critic\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)" - ], - "metadata": { - "id": "_Vy8uPWXP4m_" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." - ], - "metadata": { - "id": "Lh2-hwE5Dn9I" - } - }, - { - "cell_type": "code", - "source": [ - "dist = torch.distributions.Categorical\n", - "policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)" - ], - "metadata": { - "id": "OiJ2GkT0Qnbr" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included." - ], - "metadata": { - "id": "okxfj6IEQ-r8" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Collector\n", - "We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in." 
- ], - "metadata": { - "id": "n5XAAbuBZarO" - } - }, - { - "cell_type": "code", - "source": [ - "train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n", - "test_collector = Collector(policy, test_envs)" - ], - "metadata": { - "id": "ezwz0qerZhQM" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers." - ], - "metadata": { - "id": "ZaoPxOd2hm0b" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Trainer\n", - "Finally, we can use the trainer to help us set up the training loop." - ], - "metadata": { - "id": "qBoE9pLUiC-8" - } - }, - { - "cell_type": "code", - "source": [ - "result = onpolicy_trainer(\n", - " policy,\n", - " train_collector,\n", - " test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " batch_size=256,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ")" - ], - "metadata": { - "id": "i45EDnpxQ8gu", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "b1666b88-0bfa-4340-868e-58611872d988" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epoch #1: 50001it [00:13, 3601.81it/s, env_step=50000, len=143, loss=41.162, loss/clip=0.001, loss/ent=0.583, loss/vf=82.332, n/ep=12, n/st=2000, rew=143.08] \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Results\n", - "Print the training result." 
- ], - "metadata": { - "id": "ckgINHE2iTFR" - } - }, - { - "cell_type": "code", - "source": [ - "print(result)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tJCPgmiyiaaX", - "outputId": "40123ae3-3365-4782-9563-46c43812f10f" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'duration': '14.17s', 'train_time/model': '8.80s', 'test_step': 2094, 'test_episode': 20, 'test_time': '0.27s', 'test_speed': '7770.16 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 942, 'train_time/collector': '5.10s', 'train_speed': '3597.32 step/s'}\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "We can also test our trained agent." - ], - "metadata": { - "id": "A-MJ9avMibxN" - } - }, - { - "cell_type": "code", - "source": [ - "# Let's watch its performance!\n", - "policy.eval()\n", - "result = test_collector.collect(n_episode=1, render=False)\n", - "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mnMANFcciiAQ", - "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Final reward: 200.0, length: 200.0\n" - ] - } - ] - } - ] -} \ No newline at end of file + "id": "mnMANFcciiAQ", + "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21" + }, + "execution_count": null, + "outputs": [] + } + ] +}