{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3 (ipykernel)"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Overview\n",
"In this toturial, we use guide you step by step to show you how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment."
],
"metadata": {
"id": "r7aE6Rq3cAEE"
}
},
{
"cell_type": "markdown",
"source": [
"## Run the code\n",
"Before we get started, we must first install Tianshou's library and Gym environment by running the commands below. Here I choose a specific version of Tianshou(0.4.8) which is the latest as of the time writing this toturial. APIs in differet versions may vary a little bit but most are the same. Feel free to use other versions in your own project."
],
"metadata": {
"id": "1_mLTSEIcY2c"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "qvplhjduVDs6",
"ExecuteTime": {
"end_time": "2023-10-12T15:51:01.680688825Z",
"start_time": "2023-10-12T15:48:15.090023052Z"
}
},
"outputs": [
],
"source": [
"!pip install tianshou==0.4.8\n",
"!pip install gym"
]
},
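{
"cell_type": "markdown",
"source": [
"Since APIs may differ between versions, it can help to confirm which versions were actually installed before running the rest of the notebook. The cell below is a minimal sketch that uses only the standard library's `importlib.metadata`; the helper name `installed_version` is our own and is not part of Tianshou."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from importlib.metadata import PackageNotFoundError, version\n",
"\n",
"\n",
"def installed_version(package):\n",
"    # Return the installed version string of `package`, or None if it is not installed.\n",
"    try:\n",
"        return version(package)\n",
"    except PackageNotFoundError:\n",
"        return None\n",
"\n",
"\n",
"# Report the packages this tutorial depends on.\n",
"for name in ('tianshou', 'gym', 'torch'):\n",
"    print(name, installed_version(name))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},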
{
"cell_type": "markdown",
"source": [
"Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v0\n",
"problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n",
"exactly what this tutorial is for.\n",
"\n",
"If the script ends normally, you will see the evaluation result printed out before the first\n",
"epoch is done."
],
"metadata": {
"id": "IcFNmCjYeIIU"
}
},
{
"cell_type": "code",
"source": [
"import gym\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import onpolicy_trainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"# environments\n",
"env = gym.make('CartPole-v0')\n",
"train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n",
"test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])\n",
"\n",
"# model & optimizer\n",
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
"critic = Critic(net, device=device).to(device)\n",
"actor_critic = ActorCritic(actor, critic)\n",
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n",
"\n",
"# PPO policy\n",
"dist = torch.distributions.Categorical\n",
"policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)\n",
"\n",
"\n",
"# collector\n",
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
"test_collector = Collector(policy, test_envs)\n",
"\n",
"# trainer\n",
"result = onpolicy_trainer(\n",
" policy,\n",
" train_collector,\n",
" test_collector,\n",
" max_epoch=10,\n",
" step_per_epoch=50000,\n",
" repeat_per_collect=10,\n",
" episode_per_test=10,\n",
" batch_size=256,\n",
" step_per_collect=2000,\n",
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
")\n",
"print(result)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pxY_ZbGmkr6_",
"outputId": "b792fc24-f42c-426a-9d83-fe1a4f3f91f1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #1: 50001it [00:19, 2529.50it/s, env_step=50000, len=87, loss=80.895, loss/clip=-0.009, loss/ent=0.566, loss/vf=161.818, n/ep=15, n/st=2000, rew=87.27] \n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n",
"{'duration': '20.26s', 'train_time/model': '13.75s', 'test_step': 2159, 'test_episode': 20, 'test_time': '0.48s', 'test_speed': '4496.33 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 944, 'train_time/collector': '6.03s', 'train_speed': '2527.97 step/s'}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Let's watch its performance!\n",
"policy.eval()\n",
"result = test_collector.collect(n_episode=1, render=False)\n",
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "G9YEQptYvCgx",
"outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Final reward: 200.0, length: 200.0\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Tutorial Introduction\n",
"\n",
"A common DRL experiment as is shown above may require many components to work together. The agent, the\n",
"environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n",
"training task.\n",
"\n",
"<div align=center>\n",
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\", width=500>\n",
"\n",
"</div>\n"
],
"metadata": {
"id": "xFYlcPo8fpPU"
}
},
{
"cell_type": "markdown",
"source": [
"In Tianshou, all of these main components are factored out as different building blocks, which you\n",
"can use to create your own algorithm and finish your own experiment.\n",
"\n",
"Buiding blocks may include:\n",
"- Batch\n",
"- Replay Buffer\n",
"- Vectorized Environment Wrapper\n",
"- Policy (the agent and the training algorithm)\n",
"- Data Collector\n",
"- Trainer\n",
"- Logger\n",
"\n",
"\n",
"Check this [webpage](https://tianshou.readthedocs.io/en/master/tutorials/dqn.html) to find jupter-notebook-style tutorials that will guide you through all these\n",
"modules one by one. You can also read the [documentation](https://tianshou.readthedocs.io/en/master/) of Tianshou for more detailed explanation and\n",
"advanced usages."
],
"metadata": {
"id": "kV_uOyimj-bk"
}
},
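{
"cell_type": "markdown",
"source": [
"To make the division of labor between these building blocks concrete, here is a toy, pure-Python sketch of how a collector, a buffer, a policy and a trainer can interact. `ToyEnv`, `ToyPolicy`, `collect` and `train` are hypothetical stand-ins written for this illustration only; they are not Tianshou classes, and they greatly simplify what the real modules do."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import random\n",
"\n",
"\n",
"class ToyEnv:\n",
"    # Hypothetical environment: an episode lasts `horizon` steps and\n",
"    # the reward equals the action (1 is rewarded, 0 is not).\n",
"    def __init__(self, horizon=5):\n",
"        self.horizon = horizon\n",
"        self.t = 0\n",
"\n",
"    def reset(self):\n",
"        self.t = 0\n",
"        return self.t\n",
"\n",
"    def step(self, action):\n",
"        self.t += 1\n",
"        done = self.t >= self.horizon\n",
"        return self.t, float(action), done\n",
"\n",
"\n",
"class ToyPolicy:\n",
"    # Hypothetical policy: picks action 1 with probability `p`.\n",
"    def __init__(self, p=0.5, seed=0):\n",
"        self.p = p\n",
"        self.rng = random.Random(seed)\n",
"\n",
"    def act(self, obs):\n",
"        return 1 if self.rng.random() < self.p else 0\n",
"\n",
"    def update(self, batch, lr=0.1):\n",
"        # Move `p` toward the reward-weighted average action in the batch.\n",
"        total_reward = sum(r for _, _, r in batch)\n",
"        if total_reward > 0:\n",
"            target = sum(a * r for _, a, r in batch) / total_reward\n",
"            self.p += lr * (target - self.p)\n",
"\n",
"\n",
"def collect(policy, env, buffer, n_step):\n",
"    # Collector role: run the policy in the env and store transitions in the buffer.\n",
"    obs = env.reset()\n",
"    for _ in range(n_step):\n",
"        action = policy.act(obs)\n",
"        next_obs, reward, done = env.step(action)\n",
"        buffer.append((obs, action, reward))\n",
"        obs = env.reset() if done else next_obs\n",
"\n",
"\n",
"def train(policy, env, n_epoch=3, step_per_collect=10):\n",
"    # Trainer role: alternate data collection and policy updates.\n",
"    for epoch in range(n_epoch):\n",
"        buffer = []  # on-policy style: fresh data for every update\n",
"        collect(policy, env, buffer, step_per_collect)\n",
"        policy.update(buffer)\n",
"        print('epoch', epoch, 'p =', round(policy.p, 3))\n",
"    return policy\n",
"\n",
"\n",
"policy = train(ToyPolicy(), ToyEnv())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},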
{
"cell_type": "markdown",
"source": [
"# Further reading"
],
"metadata": {
"id": "S0mNKwH9i6Ek"
}
},
{
"cell_type": "markdown",
"source": [
"## What if I am not familar with the PPO algorithm itself?\n",
"As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n",
"plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's toturials, we will\n",
"focus on the usages of different modules, but not the algorithms themselves."
],
"metadata": {
"id": "M3NPSUnAov4L"
}
}
]
}