Removed notebook outputs
This commit is contained in:
parent 6b6ce0fdf1
commit efaadec6a1
@ -38,134 +38,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "qvplhjduVDs6",
|
||||
"ExecuteTime": {
|
||||
"end_time": "2023-10-12T15:51:01.680688825Z",
|
||||
"start_time": "2023-10-12T15:48:15.090023052Z"
|
||||
}
|
||||
"id": "qvplhjduVDs6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Collecting tianshou==0.4.8\r\n",
|
||||
" Downloading tianshou-0.4.8-py3-none-any.whl (150 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m150.4/150.4 kB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting gym>=0.15.4 (from tianshou==0.4.8)\r\n",
|
||||
" Downloading gym-0.26.2.tar.gz (721 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m721.7/721.7 kB\u001B[0m \u001B[31m11.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n",
|
||||
"\u001B[?25h Installing build dependencies ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25h Getting requirements to build wheel ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25h Preparing metadata (pyproject.toml) ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25hRequirement already satisfied: tqdm in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (4.66.1)\r\n",
|
||||
"Requirement already satisfied: numpy>1.16.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (1.24.4)\r\n",
|
||||
"Requirement already satisfied: tensorboard>=2.5.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.14.1)\r\n",
|
||||
"Requirement already satisfied: torch>=1.4.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.1.0)\r\n",
|
||||
"Requirement already satisfied: numba>=0.51.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (0.57.1)\r\n",
|
||||
"Requirement already satisfied: h5py>=2.10.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (3.10.0)\r\n",
|
||||
"Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym>=0.15.4->tianshou==0.4.8) (2.2.1)\r\n",
|
||||
"Collecting gym-notices>=0.0.4 (from gym>=0.15.4->tianshou==0.4.8)\r\n",
|
||||
" Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n",
|
||||
"Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from numba>=0.51.0->tianshou==0.4.8) (0.40.1)\r\n",
|
||||
"Requirement already satisfied: absl-py>=0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.0.0)\r\n",
|
||||
"Requirement already satisfied: grpcio>=1.48.2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.59.0)\r\n",
|
||||
"Requirement already satisfied: google-auth<3,>=1.6.3 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.23.3)\r\n",
|
||||
"Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.0.0)\r\n",
|
||||
"Requirement already satisfied: markdown>=2.6.8 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.5)\r\n",
|
||||
"Requirement already satisfied: protobuf>=3.19.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.20.3)\r\n",
|
||||
"Requirement already satisfied: requests<3,>=2.21.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.31.0)\r\n",
|
||||
"Requirement already satisfied: setuptools>=41.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (68.2.2)\r\n",
|
||||
"Requirement already satisfied: six>1.9 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.16.0)\r\n",
|
||||
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (0.7.1)\r\n",
|
||||
"Requirement already satisfied: werkzeug>=1.0.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.0.0)\r\n",
|
||||
"Requirement already satisfied: filelock in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.12.4)\r\n",
|
||||
"Requirement already satisfied: typing-extensions in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (4.8.0)\r\n",
|
||||
"Requirement already satisfied: sympy in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (1.12)\r\n",
|
||||
"Requirement already satisfied: networkx in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1)\r\n",
|
||||
"Requirement already satisfied: jinja2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1.2)\r\n",
|
||||
"Requirement already satisfied: fsspec in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (2023.9.2)\r\n",
|
||||
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m23.7/23.7 MB\u001B[0m \u001B[31m17.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m823.6/823.6 kB\u001B[0m \u001B[31m20.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m14.1/14.1 MB\u001B[0m \u001B[31m14.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Obtaining dependency information for nvidia-cudnn-cu12==8.9.2.26 from https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata\r\n",
|
||||
" Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n",
|
||||
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m410.6/410.6 MB\u001B[0m \u001B[31m5.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m121.6/121.6 MB\u001B[0m \u001B[31m9.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m56.5/56.5 MB\u001B[0m \u001B[31m13.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m124.2/124.2 MB\u001B[0m \u001B[31m10.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m196.0/196.0 MB\u001B[0m \u001B[31m8.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m209.8/209.8 MB\u001B[0m \u001B[31m8.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.1/99.1 kB\u001B[0m \u001B[31m28.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\r\n",
|
||||
"\u001B[?25hCollecting triton==2.1.0 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Obtaining dependency information for triton==2.1.0 from https://files.pythonhosted.org/packages/5c/c1/54fffb2eb13d293d9a429fead3646752ea190de0229bcf3d591ba2481263/triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata\r\n",
|
||||
" Downloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\r\n",
|
||||
"Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->tianshou==0.4.8)\r\n",
|
||||
" Obtaining dependency information for nvidia-nvjitlink-cu12 from https://files.pythonhosted.org/packages/0a/f8/5193b57555cbeecfdb6ade643df0d4218cc6385485492b6e2f64ceae53bb/nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata\r\n",
|
||||
" Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n",
|
||||
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (5.3.1)\r\n",
|
||||
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.3.0)\r\n",
|
||||
"Requirement already satisfied: rsa<5,>=3.1.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (4.9)\r\n",
|
||||
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (1.3.1)\r\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.3.0)\r\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.4)\r\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2.0.6)\r\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2023.7.22)\r\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.1.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from werkzeug>=1.0.1->tensorboard>=2.5.0->tianshou==0.4.8) (2.1.3)\r\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from sympy->torch>=1.4.0->tianshou==0.4.8) (1.3.0)\r\n",
|
||||
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.5.0)\r\n",
|
||||
"Requirement already satisfied: oauthlib>=3.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (3.2.2)\r\n",
|
||||
"Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m731.7/731.7 MB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m:00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hDownloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m89.2/89.2 MB\u001B[0m \u001B[31m13.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hDownloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\r\n",
|
||||
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m20.2/20.2 MB\u001B[0m \u001B[31m15.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
||||
"\u001B[?25hBuilding wheels for collected packages: gym\r\n",
|
||||
" Building wheel for gym (pyproject.toml) ... \u001B[?25ldone\r\n",
|
||||
"\u001B[?25h Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827621 sha256=612698033ee83c54db52d872001a111f5f0adf14dd996065edff561305ac2266\r\n",
|
||||
" Stored in directory: /home/ccagnetta/.cache/pip/wheels/1c/77/9e/9af5470201a0b0543937933ee99ba884cd237d2faefe8f4d37\r\n",
|
||||
"Successfully built gym\r\n",
|
||||
"Installing collected packages: gym-notices, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, gym, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, tianshou\r\n",
|
||||
" Attempting uninstall: triton\r\n",
|
||||
" Found existing installation: triton 2.0.0\r\n",
|
||||
" Uninstalling triton-2.0.0:\r\n",
|
||||
" Successfully uninstalled triton-2.0.0\r\n",
|
||||
" Attempting uninstall: tianshou\r\n",
|
||||
" Found existing installation: tianshou 0.5.1\r\n",
|
||||
" Uninstalling tianshou-0.5.1:\r\n",
|
||||
" Successfully uninstalled tianshou-0.5.1\r\n",
|
||||
"Successfully installed gym-0.26.2 gym-notices-0.0.8 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 tianshou-0.4.8 triton-2.1.0\r\n",
|
||||
"Requirement already satisfied: gym in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (0.26.2)\r\n",
|
||||
"Requirement already satisfied: numpy>=1.18.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (1.24.4)\r\n",
|
||||
"Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (2.2.1)\r\n",
|
||||
"Requirement already satisfied: gym-notices>=0.0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (0.0.8)\r\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
@ -248,23 +125,7 @@
|
||||
"outputId": "b792fc24-f42c-426a-9d83-fe1a4f3f91f1"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Epoch #1: 50001it [00:19, 2529.50it/s, env_step=50000, len=87, loss=80.895, loss/clip=-0.009, loss/ent=0.566, loss/vf=161.818, n/ep=15, n/st=2000, rew=87.27] \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n",
|
||||
"{'duration': '20.26s', 'train_time/model': '13.75s', 'test_step': 2159, 'test_episode': 20, 'test_time': '0.48s', 'test_speed': '4496.33 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 944, 'train_time/collector': '6.03s', 'train_speed': '2527.97 step/s'}\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@ -282,15 +143,7 @@
|
||||
"outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Final reward: 200.0, length: 200.0\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
File diff suppressed because it is too large
@ -1,575 +1,394 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4TCEkXj7LFe2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4TCEkXj7LFe2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"Replay Buffer is a very common module in DRL implementations. In Tianshou, you can consider Buffer module as as a specialized form of Batch, which helps you track all data trajectories and provide utilities such as sampling method besides the basic storage.\n",
|
||||
"\n",
|
||||
"There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. The later one is specially designed for parallelized environments (will introduce in tutorial L3)."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xoPiGVD8LNma"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Usages"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OdesCAxANehZ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Basic usages as a batch\n",
|
||||
"Usually a buffer stores all the data in a batch with circular-queue style."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fUbLl9T_SrTR"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.data import Batch, ReplayBuffer\n",
|
||||
"# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n",
|
||||
"print(\"========================================\")\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add 3 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add another 10 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 13):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"Replay Buffer is a very common module in DRL implementations. In Tianshou, you can consider Buffer module as as a specialized form of Batch, which helps you track all data trajectories and provide utilities such as sampling method besides the basic storage.\n",
|
||||
"\n",
|
||||
"There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. The later one is specially designed for parallelized environments (will introduce in tutorial L3)."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xoPiGVD8LNma"
|
||||
}
|
||||
"id": "mocZ6IqZTH62",
|
||||
"outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "H8B85Y5yUfTy"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf[-1])\n",
|
||||
"print(buf[-3:])\n",
|
||||
"# Try more methods you find useful in Batch yourself."
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Usages"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OdesCAxANehZ"
|
||||
}
|
||||
"id": "cOX-ADOPNeEK",
|
||||
"outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "vqldap-2WQBh"
|
||||
}
|
||||
},
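Before the in-memory round-trip below, here is a minimal sketch of persisting the buffer to local disk with plain `pickle` (the file name `replay_buffer.pkl` is a hypothetical example; your Tianshou version may also offer dedicated save/load helpers):

```python
import pickle

path = "replay_buffer.pkl"  # hypothetical location on local disk

# Persist the buffer; trajectory bookkeeping survives pickling.
with open(path, "wb") as f:
    pickle.dump(buf, f)

# Later, e.g. in an offline DRL script, restore it.
with open(path, "rb") as f:
    restored_buf = pickle.load(f)

print(len(restored_buf))  # same number of stored transitions as before
```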
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"_buf = pickle.loads(pickle.dumps(buf))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Ppx0L3niNT5K"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Understanding reserved keys for buffer\n",
|
||||
"As I have explained, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. So, for convenience, we reserve certain seven reserved keys in Batch.\n",
|
||||
"\n",
|
||||
"* `obs`\n",
|
||||
"* `act`\n",
|
||||
"* `rew`\n",
|
||||
"* `done`\n",
|
||||
"* `obs_next`\n",
|
||||
"* `info`\n",
|
||||
"* `policy`\n",
|
||||
"\n",
|
||||
"The meaning of these seven reserved keys are consistent with the meaning in [OPENAI Gym](https://gym.openai.com/). We would recommend you simply use these seven keys when adding batched data into ReplayBuffer, because\n",
|
||||
"some of them are tracked in ReplayBuffer (e.g. \"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n",
|
||||
"buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n",
|
||||
"```\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Eqezp0OyXn6J"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Data sampling\n",
|
||||
"We keep a replay buffer in DRL for one purpose:\"sample data from it for training\". `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fullfill this need."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ueAbTspsc6jo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"buf.sample(5)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Basic usages as a batch\n",
|
||||
"Usually a buffer stores all the data in a batch with circular-queue style."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fUbLl9T_SrTR"
|
||||
}
|
||||
"id": "P5xnYOhrchDl",
|
||||
"outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Trajectory tracking\n",
|
||||
"Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n",
|
||||
"\n",
|
||||
"First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "IWyaOSKOcgK4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from numpy import False_\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"# Add the first trajectory (length is 3) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==2 else False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the second trajectory (length is 5) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 8):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==7 else False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(8, 13):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.data import Batch, ReplayBuffer\n",
|
||||
"# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n",
|
||||
"print(\"========================================\")\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add 3 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"\n",
|
||||
"# add another 10 steps of data into ReplayBuffer sequentially\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 13):\n",
|
||||
" buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mocZ6IqZTH62",
|
||||
"outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"========================================\n",
|
||||
"ReplayBuffer()\n",
|
||||
"maxsize: 10, data length: 0\n",
|
||||
"========================================\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, False, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),\n",
|
||||
" obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 3\n",
|
||||
"========================================\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, False, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n",
|
||||
" obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 10\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
"id": "H0qRb6HLfhLB",
|
||||
"outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### episode length and rewards tracking\n",
|
||||
"Notice that `ReplayBuffer.add()` returns a tuple of 4 numbers every time it returns, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "dO7PWdb_hkXA"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Episode index management\n",
|
||||
"In the ReplayBuffer above, we can get access to any data step by indexing.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xbVc90z8itH0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf)\n",
|
||||
"data = buf[6]\n",
|
||||
"print(data)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "H8B85Y5yUfTy"
|
||||
}
|
||||
"id": "4mKwo54MjupY",
|
||||
"outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Now we know that step \"6\" is not the start of an episode (it should be step 4, 4-7 is the second trajectory we add into the ReplayBuffer), we wonder what is the earliest index of the that episode.\n",
|
||||
"\n",
|
||||
"This may seem easy but actually it is not. We cannot simply look at the \"done\" flag now, because we can see that since the third-added trajectory is not finished yet, step \"4\" is surrounded by flag \"False\". There are many things to consider. Things could get more nasty if you are using more advanced ReplayBuffer like VectorReplayBuffer, because now the data is not stored in a simple circular-queue.\n",
|
||||
"\n",
|
||||
"Luckily, all ReplayBuffer instances help you identify step indexes through a unified API."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "p5Co_Fmzj8Sw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Search for the previous index of index \"6\"\n",
|
||||
"now_index = 6\n",
|
||||
"while True:\n",
|
||||
" prev_index = buf.prev(now_index)\n",
|
||||
" print(prev_index)\n",
|
||||
" if prev_index == now_index:\n",
|
||||
" break\n",
|
||||
" else:\n",
|
||||
" now_index = prev_index"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf[-1])\n",
|
||||
"print(buf[-3:])\n",
|
||||
"# Try more methods you find useful in Batch yourself."
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "cOX-ADOPNeEK",
|
||||
"outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Batch(\n",
|
||||
" obs: array(9),\n",
|
||||
" act: array(9),\n",
|
||||
" rew: array(9.),\n",
|
||||
" done: array(False),\n",
|
||||
" obs_next: array(10),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" obs: array([7, 8, 9]),\n",
|
||||
" act: array([7, 8, 9]),\n",
|
||||
" rew: array([7., 8., 9.]),\n",
|
||||
" done: array([False, False, False]),\n",
|
||||
" obs_next: array([ 8, 9, 10]),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
"id": "DcJ0LEX6mxHg",
|
||||
"outputId": "7830f5fb-96d9-4298-d09b-24e64b2f633c"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us indentify the last index of an episode regardless of which kind of ReplayBuffer we are using."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "4Wlb57V4lQyQ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# next step of indexes [4,5,6,7,8,9] are:\n",
|
||||
"print(buf.next([4,5,6,7,8,9]))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "vqldap-2WQBh"
|
||||
}
|
||||
"id": "zl5TRMo7oOy5",
|
||||
"outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YJ9CcWZXoOXw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf.unfinished_index())"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"_buf = pickle.loads(pickle.dumps(buf))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Ppx0L3niNT5K"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Understanding reserved keys for buffer\n",
|
||||
"As I have explained, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. So, for convenience, we reserve certain seven reserved keys in Batch.\n",
|
||||
"\n",
|
||||
"* `obs`\n",
|
||||
"* `act`\n",
|
||||
"* `rew`\n",
|
||||
"* `done`\n",
|
||||
"* `obs_next`\n",
|
||||
"* `info`\n",
|
||||
"* `policy`\n",
|
||||
"\n",
|
||||
"The meaning of these seven reserved keys are consistent with the meaning in [OPENAI Gym](https://gym.openai.com/). We would recommend you simply use these seven keys when adding batched data into ReplayBuffer, because\n",
|
||||
"some of them are tracked in ReplayBuffer (e.g. \"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n",
|
||||
"buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n",
|
||||
"```\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Eqezp0OyXn6J"
|
||||
}
|
||||
},
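As an illustrative sketch (not part of the original notebook), the snippet below shows the recommended pattern: stick to the seven reserved keys and tuck any custom field into `info`. The field name `grad_norm` is a hypothetical example.

```python
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
buf.add(
    Batch(
        obs=0,
        act=0,
        rew=1.0,
        done=False,
        obs_next=1,
        info={"grad_norm": 0.5},  # hypothetical extra field, carried inside `info`
    )
)
print(buf[0].info)  # the custom field travels with the transition
```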
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Data sampling\n",
|
||||
"We keep a replay buffer in DRL for one purpose:\"sample data from it for training\". `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fullfill this need."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ueAbTspsc6jo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"buf.sample(5)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "P5xnYOhrchDl",
|
||||
"outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(Batch(\n",
|
||||
" obs: array([10, 11, 4, 3, 8]),\n",
|
||||
" act: array([10, 11, 4, 3, 8]),\n",
|
||||
" rew: array([10., 11., 4., 3., 8.]),\n",
|
||||
" done: array([False, False, False, False, False]),\n",
|
||||
" obs_next: array([11, 12, 5, 4, 9]),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
" ), array([0, 1, 4, 3, 8]))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 5
|
||||
}
|
||||
]
|
||||
},
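A short usage note on the output above: `sample()` returns both the sampled `Batch` and the indices of those transitions inside the buffer (the indices are what prioritized replay uses to update priorities). A minimal sketch of consuming the pair:

```python
# `buf` is the 10-slot ReplayBuffer filled above.
batch, indices = buf.sample(5)

print(batch.obs)  # observations of the 5 sampled transitions
print(indices)    # their positions in the underlying circular storage
```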
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Trajectory tracking\n",
|
||||
"Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n",
|
||||
"\n",
|
||||
"First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "IWyaOSKOcgK4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from numpy import False_\n",
|
||||
"buf = ReplayBuffer(size=10)\n",
|
||||
"# Add the first trajectory (length is 3) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==2 else False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the second trajectory (length is 5) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(3, 8):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=True if i==7 else False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
||||
"# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n",
|
||||
"print(\"========================================\")\n",
|
||||
"for i in range(8, 13):\n",
|
||||
" result = buf.add(Batch(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={}))\n",
|
||||
" print(result)\n",
|
||||
"print(buf)\n",
|
||||
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "H0qRb6HLfhLB",
|
||||
"outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"========================================\n",
|
||||
"(array([0]), array([0.]), array([0]), array([0]))\n",
|
||||
"(array([1]), array([0.]), array([0]), array([0]))\n",
|
||||
"(array([2]), array([3.]), array([3]), array([0]))\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
" done: array([False, False, True, False, False, False, False, False, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),\n",
|
||||
" obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 3\n",
|
||||
"========================================\n",
|
||||
"(array([3]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([4]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([5]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([6]), array([0.]), array([0]), array([3]))\n",
|
||||
"(array([7]), array([25.]), array([5]), array([3]))\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),\n",
|
||||
" obs: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),\n",
|
||||
" done: array([False, False, True, False, False, False, False, True, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([0., 1., 2., 3., 4., 5., 6., 7., 0., 0.]),\n",
|
||||
" obs_next: array([1, 2, 3, 4, 5, 6, 7, 8, 0, 0]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 8\n",
|
||||
"========================================\n",
|
||||
"(array([8]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([9]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([0]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([1]), array([0.]), array([0]), array([8]))\n",
|
||||
"(array([2]), array([0.]), array([0]), array([8]))\n",
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, True, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n",
|
||||
" obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n",
|
||||
")\n",
|
||||
"maxsize: 10, data length: 10\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### episode length and rewards tracking\n",
|
||||
"Notice that `ReplayBuffer.add()` returns a tuple of 4 numbers every time it returns, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "dO7PWdb_hkXA"
|
||||
}
|
||||
},
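As a brief sketch of consuming that return value (an illustrative addition, not original notebook content), you could log episode statistics only when a trajectory actually finishes:

```python
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(5):
    done = i == 4  # pretend the episode ends at the fifth step
    current_index, ep_rew, ep_len, ep_start = buf.add(
        Batch(obs=i, act=i, rew=1.0, done=done, obs_next=i + 1, info={})
    )
    if ep_len[0] > 0:  # non-zero only when a trajectory just finished
        print(f"episode done: reward={ep_rew[0]}, length={ep_len[0]}, start={ep_start[0]}")
```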
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Episode index management\n",
|
||||
"In the ReplayBuffer above, we can get access to any data step by indexing.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "xbVc90z8itH0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf)\n",
|
||||
"data = buf[6]\n",
|
||||
"print(data)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4mKwo54MjupY",
|
||||
"outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"ReplayBuffer(\n",
|
||||
" info: Batch(),\n",
|
||||
" act: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" obs: array([10, 11, 12, 3, 4, 5, 6, 7, 8, 9]),\n",
|
||||
" done: array([False, False, False, False, False, False, False, True, False,\n",
|
||||
" False]),\n",
|
||||
" rew: array([10., 11., 12., 3., 4., 5., 6., 7., 8., 9.]),\n",
|
||||
" obs_next: array([11, 12, 13, 4, 5, 6, 7, 8, 9, 10]),\n",
|
||||
")\n",
|
||||
"Batch(\n",
|
||||
" obs: array(6),\n",
|
||||
" act: array(6),\n",
|
||||
" rew: array(6.),\n",
|
||||
" done: array(False),\n",
|
||||
" obs_next: array(7),\n",
|
||||
" info: Batch(),\n",
|
||||
" policy: Batch(),\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Now we know that step \"6\" is not the start of an episode (it should be step 4, 4-7 is the second trajectory we add into the ReplayBuffer), we wonder what is the earliest index of the that episode.\n",
|
||||
"\n",
|
||||
"This may seem easy but actually it is not. We cannot simply look at the \"done\" flag now, because we can see that since the third-added trajectory is not finished yet, step \"4\" is surrounded by flag \"False\". There are many things to consider. Things could get more nasty if you are using more advanced ReplayBuffer like VectorReplayBuffer, because now the data is not stored in a simple circular-queue.\n",
|
||||
"\n",
|
||||
"Luckily, all ReplayBuffer instances help you identify step indexes through a unified API."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "p5Co_Fmzj8Sw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Search for the previous index of index \"6\"\n",
|
||||
"now_index = 6\n",
|
||||
"while True:\n",
|
||||
" prev_index = buf.prev(now_index)\n",
|
||||
" print(prev_index)\n",
|
||||
" if prev_index == now_index:\n",
|
||||
" break\n",
|
||||
" else:\n",
|
||||
" now_index = prev_index"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "DcJ0LEX6mxHg",
|
||||
"outputId": "7830f5fb-96d9-4298-d09b-24e64b2f633c"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"5\n",
|
||||
"4\n",
|
||||
"3\n",
|
||||
"3\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us indentify the last index of an episode regardless of which kind of ReplayBuffer we are using."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "4Wlb57V4lQyQ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# next step of indexes [4,5,6,7,8,9] are:\n",
|
||||
"print(buf.next([4,5,6,7,8,9]))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "zl5TRMo7oOy5",
|
||||
"outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"[5 6 7 7 9 0]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YJ9CcWZXoOXw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(buf.unfinished_index())"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Xkawk97NpItg",
|
||||
"outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"[2]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "8_lMr0j3pOmn"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading\n",
|
||||
"## Other Buffer Module\n",
|
||||
"\n",
|
||||
"* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n",
|
||||
"* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n",
|
||||
"* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n",
|
||||
"\n",
|
||||
"Check the documentation and the source code for more details.\n",
|
||||
"\n",
|
||||
"## Support for steps stacking to use RNN in DRL.\n",
|
||||
"There is an option called `stack_num` (default to 1) when initialising the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "FEyE0c7tNfwa"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"id": "Xkawk97NpItg",
|
||||
"outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "8_lMr0j3pOmn"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading\n",
|
||||
"## Other Buffer Module\n",
|
||||
"\n",
|
||||
"* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n",
|
||||
"* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n",
|
||||
"* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n",
|
||||
"\n",
|
||||
"Check the documentation and the source code for more details.\n",
|
||||
"\n",
|
||||
"## Support for steps stacking to use RNN in DRL.\n",
|
||||
"There is an option called `stack_num` (default to 1) when initialising the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "FEyE0c7tNfwa"
|
||||
}
|
||||
}
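A hedged sketch of the two extensions mentioned above; the constructor arguments shown (`alpha`, `beta`, `stack_num`) are assumptions based on the documentation, so check the API reference of your installed version:

```python
from tianshou.data import PrioritizedReplayBuffer, ReplayBuffer

# Prioritized experience replay: alpha/beta control how strongly priorities
# skew sampling and how importance weights are corrected (assumed signature).
prio_buf = PrioritizedReplayBuffer(size=1000, alpha=0.6, beta=0.4)

# Step stacking for recurrent policies: each sampled observation becomes a
# stack of the last `stack_num` steps from the same episode.
rnn_buf = ReplayBuffer(size=1000, stack_num=4)
```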
|
||||
]
|
||||
}
|
||||
|
@ -1,225 +1,215 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0T7FYEnlBT6F"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0T7FYEnlBT6F"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remember to install tianshou first\n",
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments or their libraries in DRL research, Tianshou chooses to keep a consistent API with [OPENAI Gym](https://gym.openai.com/).\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/rl-loop.jpg\", title=\"The agents interacting with the environment\">\n",
|
||||
"\n",
|
||||
"<a> The agents interacting with the environment </a>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"In Gym, an environment receives an action and returns next observation and reward. This process is slow and sometimes can be the throughput bottleneck in a DRL experiment.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "W5V7z3fVX7_b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Tianshou provides vectorized environment wrapper for a Gym environment. This wrapper allows you to make use of multiple cpu cores in your server to accelerate the data sampling."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "A0NGWZ8adBwt"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.env import SubprocVectorEnv\n",
|
||||
"import numpy as np\n",
|
||||
"import gym\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"num_cpus = [1,2,5]\n",
|
||||
"for num_cpu in num_cpus:\n",
|
||||
" env = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(num_cpu)])\n",
|
||||
" env.reset()\n",
|
||||
" sampled_steps = 0\n",
|
||||
" time_start = time.time()\n",
|
||||
" while sampled_steps < 1000:\n",
|
||||
" act = np.random.choice(2, size=num_cpu)\n",
|
||||
" obs, rew, done, info = env.step(act)\n",
|
||||
" if np.sum(done):\n",
|
||||
" env.reset(np.where(done)[0])\n",
|
||||
" sampled_steps += num_cpu\n",
|
||||
" time_used = time.time() - time_start\n",
|
||||
" print(\"{}s used to sample 1000 steps if using {} cpus.\".format(time_used, num_cpu))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments and environment libraries in DRL research, Tianshou chooses to keep a consistent API with [OpenAI Gym](https://gym.openai.com/).\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/rl-loop.jpg\" title=\"The agents interacting with the environment\">\n",
|
||||
"\n",
|
||||
"<a> The agents interacting with the environment </a>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"In Gym, an environment receives an action and returns the next observation and reward. This process is slow and can sometimes be the throughput bottleneck in a DRL experiment.\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "W5V7z3fVX7_b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Tianshou provides a vectorized environment wrapper for Gym environments. This wrapper allows you to make use of multiple CPU cores on your server to accelerate data sampling."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "A0NGWZ8adBwt"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.env import SubprocVectorEnv\n",
|
||||
"import numpy as np\n",
|
||||
"import gym\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"num_cpus = [1,2,5]\n",
|
||||
"for num_cpu in num_cpus:\n",
|
||||
" env = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(num_cpu)])\n",
|
||||
" env.reset()\n",
|
||||
" sampled_steps = 0\n",
|
||||
" time_start = time.time()\n",
|
||||
" while sampled_steps < 1000:\n",
|
||||
" act = np.random.choice(2, size=num_cpu)\n",
|
||||
" obs, rew, done, info = env.step(act)\n",
|
||||
" if np.sum(done):\n",
|
||||
" env.reset(np.where(done)[0])\n",
|
||||
" sampled_steps += num_cpu\n",
|
||||
" time_used = time.time() - time_start\n",
|
||||
" print(\"{}s used to sample 1000 steps if using {} cpus.\".format(time_used, num_cpu))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "67wKtkiNi3lb",
|
||||
"outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"0.30551695823669434s used to sample 1000 steps if using 1 cpus.\n",
|
||||
"0.2602052688598633s used to sample 1000 steps if using 2 cpus.\n",
|
||||
"0.15763545036315918s used to sample 1000 steps if using 5 cpus.\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"You may notice that the speed doesn't increase linearly as we add more subprocesses. There are multiple reasons behind this. One reason is that synchronized execution causes a straggler effect: every step waits for the slowest worker. One way to mitigate this is asynchronous mode (a short sketch follows this cell). We leave this for further reading if you are interested.\n",
|
||||
"\n",
|
||||
"Note that SubprocVectorEnv should only be used when environment execution is slow. In practice, DummyVectorEnv (or a raw Gym environment) is actually more efficient for a simple environment like CartPole, because you avoid both the straggler effect and the overhead of communication between subprocesses."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "S1b6vxp9nEUS"
|
||||
}
|
||||
},
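A hedged sketch of the asynchronous mode mentioned above. The `wait_num`/`timeout` keyword arguments are assumptions drawn from the Parallel Sampling cheatsheet, not something demonstrated in this notebook:

from tianshou.env import SubprocVectorEnv
import gym
import numpy as np

# Assumed sketch: with wait_num=3, each step() returns as soon as 3 of the 5
# workers have finished, so one slow environment no longer stalls the others
# (mitigating the straggler effect).
async_envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(5)],
    wait_num=3,
    timeout=0.01,
)
async_envs.reset()
# The first step must provide an action for every environment; afterwards,
# actions are only needed for the envs reported ready via info["env_id"].
obs, rew, done, info = async_envs.step(np.random.choice(2, size=5))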
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Usages\n",
|
||||
"## Initialisation\n",
|
||||
"Just pass in a list of functions, each of which returns an initialised environment when called."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Z6yPxdqFp18j"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"# In Gym\n",
|
||||
"env = gym.make(\"CartPole-v0\")\n",
|
||||
"\n",
|
||||
"# In Tianshou\n",
|
||||
"def helper_function():\n",
|
||||
" env = gym.make(\"CartPole-v0\")\n",
|
||||
" # other operations such as env.seed(np.random.choice(10))\n",
|
||||
" return env\n",
|
||||
"\n",
|
||||
"envs = DummyVectorEnv([helper_function for _ in range(5)])\n",
|
||||
"print(envs)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ssLcrL_pq24-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## EnvPool support\n",
|
||||
"Besides its built-in environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its GitHub page yourself; a minimal usage sketch follows this cell."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "X7p8csjdrwIN"
|
||||
}
|
||||
},
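A minimal sketch of what using EnvPool looks like (assumes the `envpool` package is installed separately and exposes `envpool.make_gym`; this notebook does not install it):

import envpool

# EnvPool environments are created already vectorized, so they can play the
# same role as the Tianshou wrappers shown above.
envs = envpool.make_gym("CartPole-v0", num_envs=4)
obs = envs.reset()
print(obs.shape)  # expected (4, 4): 4 environments, 4-dimensional observation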
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Environment execution and resetting\n",
|
||||
"The only difference between vectorized environments and standard Gym environments is that the actions passed in and the observations/rewards returned are also vectorized, with one entry per environment."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "kvIfqh0vqAR5"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# In Gym, env.reset() returns a single observation.\n",
|
||||
"print(\"In Gym, env.reset() returns a single observation.\")\n",
|
||||
"print(env.reset())\n",
|
||||
"\n",
|
||||
"# In Tianshou, envs.reset() returns stacked observations.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(\"In Tianshou, envs.reset() returns stacked observations.\")\n",
|
||||
"print(envs.reset())\n",
|
||||
"\n",
|
||||
"obs, rew, done, info = envs.step(np.random.choice(2, size=len(envs)))  # one random action per environment\n",
|
||||
"print(info)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BH1ZnPG6tkdD"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"If we only want to execute some of the environments, the `id` argument can be used."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qXroB7KluvP9"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(envs.step(np.random.choice(2, size=3), id=[0,3,1]))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ufvFViKTu8d_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading\n",
|
||||
"## Other environment wrappers in Tianshou\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* ShmemVectorEnv: based on SubprocVectorEnv, but uses shared memory instead of pipes;\n",
|
||||
"* RayVectorEnv: uses Ray for concurrency and is currently the only choice for parallel simulation across a cluster of multiple machines.\n",
|
||||
"\n",
|
||||
"Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details; a short usage sketch for these wrappers follows this cell.\n",
|
||||
"\n",
|
||||
"## Difference between synchronous and asynchronous mode (How to choose?)\n",
|
||||
"An explanation can be found in the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fekHR1a6X_HB"
|
||||
}
|
||||
}
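A short usage sketch for the wrappers listed above (ShmemVectorEnv shares the SubprocVectorEnv constructor signature; RayVectorEnv additionally assumes `ray` is installed, so it is left commented out):

import gym
from tianshou.env import ShmemVectorEnv, RayVectorEnv

env_fns = [lambda: gym.make('CartPole-v0') for _ in range(4)]

# Same interface as SubprocVectorEnv, but observations are exchanged through
# shared memory rather than pipes.
shmem_envs = ShmemVectorEnv(env_fns)

# Ray-based variant; uncomment after `pip install ray`. It is the option that
# can distribute simulation across several machines.
# ray_envs = RayVectorEnv(env_fns)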
|
||||
]
|
||||
}
|
||||
"id": "67wKtkiNi3lb",
|
||||
"outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"You may notice that the speed doesn't increase linearly as we add more subprocesses. There are multiple reasons behind this. One reason is that synchronized execution causes a straggler effect: every step waits for the slowest worker. One way to mitigate this is asynchronous mode. We leave this for further reading if you are interested.\n",
|
||||
"\n",
|
||||
"Note that SubprocVectorEnv should only be used when environment execution is slow. In practice, DummyVectorEnv (or a raw Gym environment) is actually more efficient for a simple environment like CartPole, because you avoid both the straggler effect and the overhead of communication between subprocesses."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "S1b6vxp9nEUS"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Usages\n",
|
||||
"## Initialisation\n",
|
||||
"Just pass in a list of functions, each of which returns an initialised environment when called."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Z6yPxdqFp18j"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"# In Gym\n",
|
||||
"env = gym.make(\"CartPole-v0\")\n",
|
||||
"\n",
|
||||
"# In Tianshou\n",
|
||||
"def helper_function():\n",
|
||||
" env = gym.make(\"CartPole-v0\")\n",
|
||||
" # other operations such as env.seed(np.random.choice(10))\n",
|
||||
" return env\n",
|
||||
"\n",
|
||||
"envs = DummyVectorEnv([helper_function for _ in range(5)])\n",
|
||||
"print(envs)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ssLcrL_pq24-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## EnvPool support\n",
|
||||
"Besides its built-in environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its GitHub page yourself."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "X7p8csjdrwIN"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Environment execution and resetting\n",
|
||||
"The only difference between vectorized environments and standard Gym environments is that the actions passed in and the observations/rewards returned are also vectorized, with one entry per environment."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "kvIfqh0vqAR5"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# In Gym, env.reset() returns a single observation.\n",
|
||||
"print(\"In Gym, env.reset() returns a single observation.\")\n",
|
||||
"print(env.reset())\n",
|
||||
"\n",
|
||||
"# In Tianshou, envs.reset() returns stacked observations.\n",
|
||||
"print(\"========================================\")\n",
|
||||
"print(\"In Tianshou, envs.reset() returns stacked observations.\")\n",
|
||||
"print(envs.reset())\n",
|
||||
"\n",
|
||||
"obs, rew, done, info = envs.step(np.random.choice(2, size=len(envs)))  # one random action per environment\n",
|
||||
"print(info)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BH1ZnPG6tkdD"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"If we only want to execute some of the environments, the `id` argument can be used."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qXroB7KluvP9"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(envs.step(np.random.choice(2, size=3), id=[0,3,1]))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ufvFViKTu8d_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Further Reading\n",
|
||||
"## Other environment wrappers in Tianshou\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* ShmemVectorEnv: based on SubprocVectorEnv, but uses shared memory instead of pipes;\n",
|
||||
"* RayVectorEnv: uses Ray for concurrency and is currently the only choice for parallel simulation across a cluster of multiple machines.\n",
|
||||
"\n",
|
||||
"Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.env.html) for details.\n",
|
||||
"\n",
|
||||
"## Difference between synchronous and asynchronous mode (How to choose?)\n",
|
||||
"An explanation can be found in the [Parallel Sampling](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) tutorial."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fekHR1a6X_HB"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -1,352 +1,321 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"Finally, we can assemble the building blocks that we have come across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use the [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_UaXOSRjDUF9"
|
||||
}
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Overview\n",
|
||||
"Finally, we can assemble the building blocks that we have come across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use the [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_UaXOSRjDUF9"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Experiment\n",
|
||||
"To conduct this experiment, we need the following building blocks.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* Two vectorized environments, one for training and one for evaluation\n",
|
||||
"* A PPO agent\n",
|
||||
"* A replay buffer to store transition data\n",
|
||||
"* Two collectors to manage the data collecting process, one for training and one for evaluation\n",
|
||||
"* A trainer to manage the training loop\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\" width=500>\n",
|
||||
"\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"Let us do this step by step."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "2QRbCJvDHNAd"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Preparation\n",
|
||||
"Firstly, install Tianshou if you haven't installed it before."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "-Hh4E6i0Hj0I"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "w50BVwaRHg3N"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Import libraries we might need later."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "7E4EhiBeHxD5"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PPOPolicy\n",
|
||||
"from tianshou.trainer import onpolicy_trainer\n",
|
||||
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
||||
"\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"\n",
|
||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ao9gWJDiHgG-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Environment"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "QnRg5y7THRYw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We create two vectorized environments, one for training and one for testing. Since each CartPole step executes extremely quickly, there is no need for multi-process wrappers, and we simply use DummyVectorEnv."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YZERKCGtH8W1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Mpuj5PFnDKVS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('CartPole-v0')\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Policy\n",
|
||||
"Next we need to initialise our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic first.\n",
|
||||
"\n",
|
||||
"The actor is a neural network that shares the same network head with the critic. Both networks take the environment observation as input. The actor outputs the action distribution, while the critic outputs a single scalar estimating the value of the current state under the current policy.\n",
|
||||
"\n",
|
||||
"Luckily, Tianshou already provides basic network modules that we can use in this experiment."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BJtt_Ya8DTAh"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# net is the shared head of the actor and the critic\n",
|
||||
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
||||
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
|
||||
"critic = Critic(net, device=device).to(device)\n",
|
||||
"actor_critic = ActorCritic(actor, critic)\n",
|
||||
"\n",
|
||||
"# optimizer of the actor and the critic\n",
|
||||
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_Vy8uPWXP4m_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Once we have defined the actor, the critic, and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete-action-space problem, so the action distribution can be a categorical distribution."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Lh2-hwE5Dn9I"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"dist = torch.distributions.Categorical\n",
|
||||
"policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OiJ2GkT0Qnbr"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"`deterministic_eval=True` means that we sample actions during training, but always take the best (most probable) action during evaluation, with no randomness involved. A small sketch follows this cell."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "okxfj6IEQ-r8"
|
||||
}
|
||||
},
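A small illustrative sketch of the effect of `deterministic_eval` (assumes the `policy` defined above; the dummy 4-dimensional observation below is made up purely for illustration):

import numpy as np
from tianshou.data import Batch

dummy_obs = np.zeros((1, 4), dtype=np.float32)  # hypothetical observation batch
policy.eval()   # evaluation mode: the most probable action is always chosen
print(policy(Batch(obs=dummy_obs, info={})).act)
policy.train()  # training mode: actions are sampled from the categorical distribution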
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Collector\n",
|
||||
"We can set up the collectors now. The train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "n5XAAbuBZarO"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
|
||||
"test_collector = Collector(policy, test_envs)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ezwz0qerZhQM"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We use `VectorReplayBuffer` here because it works more efficiently with vectorized environments; you can simply think of `VectorReplayBuffer` as a list of ordinary replay buffers (a sizing sketch follows this cell)."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ZaoPxOd2hm0b"
|
||||
}
|
||||
},
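A small sizing sketch of the buffer used above (a reading of the `VectorReplayBuffer(total_size, buffer_num)` call in the collector cell; the numbers simply mirror that call):

from tianshou.data import VectorReplayBuffer

# 20000 transitions in total, split across 20 sub-buffers (one per training
# environment), i.e. at most 1000 transitions per sub-buffer.
buf = VectorReplayBuffer(total_size=20000, buffer_num=20)
print(buf.buffer_num, buf.maxsize)  # expected: 20 20000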
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Trainer\n",
|
||||
"Finally, we can use the trainer to help us set up the training loop."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qBoE9pLUiC-8"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"result = onpolicy_trainer(\n",
|
||||
" policy,\n",
|
||||
" train_collector,\n",
|
||||
" test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=50000,\n",
|
||||
" repeat_per_collect=10,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" batch_size=256,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "i45EDnpxQ8gu",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Experiment\n",
|
||||
"To conduct this experiment, we need the following building blocks.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"* Two vectorized environments, one for training and one for evaluation\n",
|
||||
"* A PPO agent\n",
|
||||
"* A replay buffer to store transition data\n",
|
||||
"* Two collectors to manage the data collecting process, one for training and one for evaluation\n",
|
||||
"* A trainer to manage the training loop\n",
|
||||
"\n",
|
||||
"<div align=center>\n",
|
||||
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\" width=500>\n",
|
||||
"\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"Let us do this step by step."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "2QRbCJvDHNAd"
|
||||
}
|
||||
"outputId": "b1666b88-0bfa-4340-868e-58611872d988"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Results\n",
|
||||
"Print the training result."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ckgINHE2iTFR"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Preparation\n",
|
||||
"Firstly, install Tianshou if you haven't installed it before."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "-Hh4E6i0Hj0I"
|
||||
}
|
||||
"id": "tJCPgmiyiaaX",
|
||||
"outputId": "40123ae3-3365-4782-9563-46c43812f10f"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also test our trained agent."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "A-MJ9avMibxN"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let's watch its performance!\n",
|
||||
"policy.eval()\n",
|
||||
"result = test_collector.collect(n_episode=1, render=False)\n",
|
||||
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install tianshou==0.4.8\n",
|
||||
"!pip install gym"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "w50BVwaRHg3N"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Import libraries we might need later."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "7E4EhiBeHxD5"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
||||
"from tianshou.env import DummyVectorEnv\n",
|
||||
"from tianshou.policy import PPOPolicy\n",
|
||||
"from tianshou.trainer import onpolicy_trainer\n",
|
||||
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
||||
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
||||
"\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"\n",
|
||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ao9gWJDiHgG-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Environment"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "QnRg5y7THRYw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We create two vectorized environments, one for training and one for testing. Since each CartPole step executes extremely quickly, there is no need for multi-process wrappers, and we simply use DummyVectorEnv."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YZERKCGtH8W1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Mpuj5PFnDKVS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('CartPole-v0')\n",
|
||||
"train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n",
|
||||
"test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Policy\n",
|
||||
"Next we need to initialise our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic first.\n",
|
||||
"\n",
|
||||
"The actor is a neural network that shares the same network head with the critic. Both networks take the environment observation as input. The actor outputs the action distribution, while the critic outputs a single scalar estimating the value of the current state under the current policy.\n",
|
||||
"\n",
|
||||
"Luckily, Tianshou already provides basic network modules that we can use in this experiment."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BJtt_Ya8DTAh"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# net is the shared head of the actor and the critic\n",
|
||||
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
||||
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
|
||||
"critic = Critic(net, device=device).to(device)\n",
|
||||
"actor_critic = ActorCritic(actor, critic)\n",
|
||||
"\n",
|
||||
"# optimizer of the actor and the critic\n",
|
||||
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_Vy8uPWXP4m_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Once we have defined the actor, the critic, and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete-action-space problem, so the action distribution can be a categorical distribution."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Lh2-hwE5Dn9I"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"dist = torch.distributions.Categorical\n",
|
||||
"policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OiJ2GkT0Qnbr"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"`deterministic_eval=True` means that we sample actions during training, but always take the best (most probable) action during evaluation, with no randomness involved."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "okxfj6IEQ-r8"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Collector\n",
|
||||
"We can set up the collectors now. The train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "n5XAAbuBZarO"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
|
||||
"test_collector = Collector(policy, test_envs)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ezwz0qerZhQM"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We use `VectorReplayBuffer` here because it works more efficiently with vectorized environments; you can simply think of `VectorReplayBuffer` as a list of ordinary replay buffers."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ZaoPxOd2hm0b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Trainer\n",
|
||||
"Finally, we can use the trainer to help us set up the training loop."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qBoE9pLUiC-8"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"result = onpolicy_trainer(\n",
|
||||
" policy,\n",
|
||||
" train_collector,\n",
|
||||
" test_collector,\n",
|
||||
" max_epoch=10,\n",
|
||||
" step_per_epoch=50000,\n",
|
||||
" repeat_per_collect=10,\n",
|
||||
" episode_per_test=10,\n",
|
||||
" batch_size=256,\n",
|
||||
" step_per_collect=2000,\n",
|
||||
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "i45EDnpxQ8gu",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"outputId": "b1666b88-0bfa-4340-868e-58611872d988"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Epoch #1: 50001it [00:13, 3601.81it/s, env_step=50000, len=143, loss=41.162, loss/clip=0.001, loss/ent=0.583, loss/vf=82.332, n/ep=12, n/st=2000, rew=143.08] \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Results\n",
|
||||
"Print the training result."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ckgINHE2iTFR"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "tJCPgmiyiaaX",
|
||||
"outputId": "40123ae3-3365-4782-9563-46c43812f10f"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"{'duration': '14.17s', 'train_time/model': '8.80s', 'test_step': 2094, 'test_episode': 20, 'test_time': '0.27s', 'test_speed': '7770.16 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 942, 'train_time/collector': '5.10s', 'train_speed': '3597.32 step/s'}\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also test our trained agent."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "A-MJ9avMibxN"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let's watch its performance!\n",
|
||||
"policy.eval()\n",
|
||||
"result = test_collector.collect(n_episode=1, render=False)\n",
|
||||
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "mnMANFcciiAQ",
|
||||
"outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Final reward: 200.0, length: 200.0\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
"id": "mnMANFcciiAQ",
|
||||
"outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|