{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"language": "python",
|
|
"display_name": "Python 3 (ipykernel)"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Overview\n",
"In this tutorial, we guide you step by step through how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment."
|
|
],
|
|
"metadata": {
|
|
"id": "r7aE6Rq3cAEE"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Run the code\n",
"Before we get started, we must first install the Tianshou library and the Gym environment by running the commands below. Here I use a specific version of Tianshou (0.4.8), which was the latest at the time of writing this tutorial. APIs may vary slightly between versions, but most are the same. Feel free to use other versions in your own project."
|
|
],
|
|
"metadata": {
|
|
"id": "1_mLTSEIcY2c"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"id": "qvplhjduVDs6",
|
|
"ExecuteTime": {
|
|
"end_time": "2023-10-12T15:51:01.680688825Z",
|
|
"start_time": "2023-10-12T15:48:15.090023052Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Collecting tianshou==0.4.8\r\n",
|
|
" Downloading tianshou-0.4.8-py3-none-any.whl (150 kB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m150.4/150.4 kB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting gym>=0.15.4 (from tianshou==0.4.8)\r\n",
|
|
" Downloading gym-0.26.2.tar.gz (721 kB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m721.7/721.7 kB\u001B[0m \u001B[31m11.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0ma \u001B[36m0:00:01\u001B[0m\r\n",
|
|
"\u001B[?25h Installing build dependencies ... \u001B[?25ldone\r\n",
|
|
"\u001B[?25h Getting requirements to build wheel ... \u001B[?25ldone\r\n",
|
|
"\u001B[?25h Preparing metadata (pyproject.toml) ... \u001B[?25ldone\r\n",
|
|
"\u001B[?25hRequirement already satisfied: tqdm in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (4.66.1)\r\n",
|
|
"Requirement already satisfied: numpy>1.16.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (1.24.4)\r\n",
|
|
"Requirement already satisfied: tensorboard>=2.5.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.14.1)\r\n",
|
|
"Requirement already satisfied: torch>=1.4.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (2.1.0)\r\n",
|
|
"Requirement already satisfied: numba>=0.51.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (0.57.1)\r\n",
|
|
"Requirement already satisfied: h5py>=2.10.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tianshou==0.4.8) (3.10.0)\r\n",
|
|
"Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym>=0.15.4->tianshou==0.4.8) (2.2.1)\r\n",
|
|
"Collecting gym-notices>=0.0.4 (from gym>=0.15.4->tianshou==0.4.8)\r\n",
|
|
" Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n",
|
|
"Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from numba>=0.51.0->tianshou==0.4.8) (0.40.1)\r\n",
|
|
"Requirement already satisfied: absl-py>=0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.0.0)\r\n",
|
|
"Requirement already satisfied: grpcio>=1.48.2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.59.0)\r\n",
|
|
"Requirement already satisfied: google-auth<3,>=1.6.3 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.23.3)\r\n",
|
|
"Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.0.0)\r\n",
|
|
"Requirement already satisfied: markdown>=2.6.8 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.5)\r\n",
|
|
"Requirement already satisfied: protobuf>=3.19.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.20.3)\r\n",
|
|
"Requirement already satisfied: requests<3,>=2.21.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (2.31.0)\r\n",
|
|
"Requirement already satisfied: setuptools>=41.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (68.2.2)\r\n",
|
|
"Requirement already satisfied: six>1.9 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (1.16.0)\r\n",
|
|
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (0.7.1)\r\n",
|
|
"Requirement already satisfied: werkzeug>=1.0.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from tensorboard>=2.5.0->tianshou==0.4.8) (3.0.0)\r\n",
|
|
"Requirement already satisfied: filelock in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.12.4)\r\n",
|
|
"Requirement already satisfied: typing-extensions in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (4.8.0)\r\n",
|
|
"Requirement already satisfied: sympy in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (1.12)\r\n",
|
|
"Requirement already satisfied: networkx in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1)\r\n",
|
|
"Requirement already satisfied: jinja2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (3.1.2)\r\n",
|
|
"Requirement already satisfied: fsspec in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from torch>=1.4.0->tianshou==0.4.8) (2023.9.2)\r\n",
|
|
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m23.7/23.7 MB\u001B[0m \u001B[31m17.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m823.6/823.6 kB\u001B[0m \u001B[31m20.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m14.1/14.1 MB\u001B[0m \u001B[31m14.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Obtaining dependency information for nvidia-cudnn-cu12==8.9.2.26 from https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata\r\n",
|
|
" Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\r\n",
|
|
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m410.6/410.6 MB\u001B[0m \u001B[31m5.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m121.6/121.6 MB\u001B[0m \u001B[31m9.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m56.5/56.5 MB\u001B[0m \u001B[31m13.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m124.2/124.2 MB\u001B[0m \u001B[31m10.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m196.0/196.0 MB\u001B[0m \u001B[31m8.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m209.8/209.8 MB\u001B[0m \u001B[31m8.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m99.1/99.1 kB\u001B[0m \u001B[31m28.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\r\n",
|
|
"\u001B[?25hCollecting triton==2.1.0 (from torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Obtaining dependency information for triton==2.1.0 from https://files.pythonhosted.org/packages/5c/c1/54fffb2eb13d293d9a429fead3646752ea190de0229bcf3d591ba2481263/triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata\r\n",
|
|
" Downloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\r\n",
|
|
"Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->tianshou==0.4.8)\r\n",
|
|
" Obtaining dependency information for nvidia-nvjitlink-cu12 from https://files.pythonhosted.org/packages/0a/f8/5193b57555cbeecfdb6ade643df0d4218cc6385485492b6e2f64ceae53bb/nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata\r\n",
|
|
" Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\r\n",
|
|
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (5.3.1)\r\n",
|
|
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.3.0)\r\n",
|
|
"Requirement already satisfied: rsa<5,>=3.1.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (4.9)\r\n",
|
|
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (1.3.1)\r\n",
|
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.3.0)\r\n",
|
|
"Requirement already satisfied: idna<4,>=2.5 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (3.4)\r\n",
|
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2.0.6)\r\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->tianshou==0.4.8) (2023.7.22)\r\n",
|
|
"Requirement already satisfied: MarkupSafe>=2.1.1 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from werkzeug>=1.0.1->tensorboard>=2.5.0->tianshou==0.4.8) (2.1.3)\r\n",
|
|
"Requirement already satisfied: mpmath>=0.19 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from sympy->torch>=1.4.0->tianshou==0.4.8) (1.3.0)\r\n",
|
|
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.5.0->tianshou==0.4.8) (0.5.0)\r\n",
|
|
"Requirement already satisfied: oauthlib>=3.0.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.5.0->tianshou==0.4.8) (3.2.2)\r\n",
|
|
"Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m731.7/731.7 MB\u001B[0m \u001B[31m3.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m:00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hDownloading triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m89.2/89.2 MB\u001B[0m \u001B[31m13.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hDownloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\r\n",
|
|
"\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m20.2/20.2 MB\u001B[0m \u001B[31m15.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\r\n",
|
|
"\u001B[?25hBuilding wheels for collected packages: gym\r\n",
|
|
" Building wheel for gym (pyproject.toml) ... \u001B[?25ldone\r\n",
|
|
"\u001B[?25h Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827621 sha256=612698033ee83c54db52d872001a111f5f0adf14dd996065edff561305ac2266\r\n",
|
|
" Stored in directory: /home/ccagnetta/.cache/pip/wheels/1c/77/9e/9af5470201a0b0543937933ee99ba884cd237d2faefe8f4d37\r\n",
|
|
"Successfully built gym\r\n",
|
|
"Installing collected packages: gym-notices, triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, gym, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, tianshou\r\n",
|
|
" Attempting uninstall: triton\r\n",
|
|
" Found existing installation: triton 2.0.0\r\n",
|
|
" Uninstalling triton-2.0.0:\r\n",
|
|
" Successfully uninstalled triton-2.0.0\r\n",
|
|
" Attempting uninstall: tianshou\r\n",
|
|
" Found existing installation: tianshou 0.5.1\r\n",
|
|
" Uninstalling tianshou-0.5.1:\r\n",
|
|
" Successfully uninstalled tianshou-0.5.1\r\n",
|
|
"Successfully installed gym-0.26.2 gym-notices-0.0.8 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 tianshou-0.4.8 triton-2.1.0\r\n",
|
|
"Requirement already satisfied: gym in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (0.26.2)\r\n",
|
|
"Requirement already satisfied: numpy>=1.18.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (1.24.4)\r\n",
|
|
"Requirement already satisfied: cloudpickle>=1.2.0 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (2.2.1)\r\n",
|
|
"Requirement already satisfied: gym-notices>=0.0.4 in /home/ccagnetta/.cache/pypoetry/virtualenvs/tianshou-spwmGTuX-py3.11/lib/python3.11/site-packages (from gym) (0.0.8)\r\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!pip install tianshou==0.4.8\n",
|
|
"!pip install gym"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
"Below is a short script that uses a DRL algorithm (PPO) to solve the classic CartPole-v0\n",
|
|
"problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n",
|
|
"exactly what this tutorial is for.\n",
|
|
"\n",
|
|
"If the script ends normally, you will see the evaluation result printed out before the first\n",
|
|
"epoch is done."
|
|
],
|
|
"metadata": {
|
|
"id": "IcFNmCjYeIIU"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"import gym\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"from tianshou.data import Collector, VectorReplayBuffer\n",
|
|
"from tianshou.env import DummyVectorEnv\n",
|
|
"from tianshou.policy import PPOPolicy\n",
|
|
"from tianshou.trainer import onpolicy_trainer\n",
|
|
"from tianshou.utils.net.common import ActorCritic, Net\n",
|
|
"from tianshou.utils.net.discrete import Actor, Critic\n",
|
|
"\n",
|
|
"import warnings\n",
|
|
"warnings.filterwarnings('ignore')\n",
|
|
"\n",
|
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
|
"\n",
|
|
"# environments\n",
|
|
"env = gym.make('CartPole-v0')\n",
|
|
"train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(20)])\n",
|
|
"test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])\n",
|
|
"\n",
|
|
"# model & optimizer\n",
|
|
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
|
|
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
|
|
"critic = Critic(net, device=device).to(device)\n",
|
|
"actor_critic = ActorCritic(actor, critic)\n",
|
|
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n",
|
|
"\n",
|
|
"# PPO policy\n",
|
|
"dist = torch.distributions.Categorical\n",
|
|
"policy = PPOPolicy(actor, critic, optim, dist, action_space=env.action_space, deterministic_eval=True)\n",
|
|
"\n",
|
|
"\n",
|
|
"# collector\n",
|
|
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
|
|
"test_collector = Collector(policy, test_envs)\n",
|
|
"\n",
|
|
"# trainer\n",
|
|
"result = onpolicy_trainer(\n",
|
|
" policy,\n",
|
|
" train_collector,\n",
|
|
" test_collector,\n",
|
|
" max_epoch=10,\n",
|
|
" step_per_epoch=50000,\n",
|
|
" repeat_per_collect=10,\n",
|
|
" episode_per_test=10,\n",
|
|
" batch_size=256,\n",
|
|
" step_per_collect=2000,\n",
|
|
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
|
|
")\n",
|
|
"print(result)"
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "pxY_ZbGmkr6_",
|
|
"outputId": "b792fc24-f42c-426a-9d83-fe1a4f3f91f1"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stderr",
|
|
"text": [
|
|
"Epoch #1: 50001it [00:19, 2529.50it/s, env_step=50000, len=87, loss=80.895, loss/clip=-0.009, loss/ent=0.566, loss/vf=161.818, n/ep=15, n/st=2000, rew=87.27] \n"
|
|
]
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"Epoch #1: test_reward: 200.000000 ± 0.000000, best_reward: 200.000000 ± 0.000000 in #1\n",
|
|
"{'duration': '20.26s', 'train_time/model': '13.75s', 'test_step': 2159, 'test_episode': 20, 'test_time': '0.48s', 'test_speed': '4496.33 step/s', 'best_reward': 200.0, 'best_result': '200.00 ± 0.00', 'train_step': 50000, 'train_episode': 944, 'train_time/collector': '6.03s', 'train_speed': '2527.97 step/s'}\n"
|
|
]
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Let's watch its performance!\n",
|
|
"policy.eval()\n",
|
|
"result = test_collector.collect(n_episode=1, render=False)\n",
|
|
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "G9YEQptYvCgx",
|
|
"outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"Final reward: 200.0, length: 200.0\n"
|
|
]
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Tutorial Introduction\n",
|
|
"\n",
"A typical DRL experiment, such as the one shown above, requires many components to work together. The agent, the\n",
|
|
"environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n",
|
|
"training task.\n",
|
|
"\n",
|
|
"<div align=center>\n",
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\" width=500>\n",
|
|
"\n",
|
|
"</div>\n"
|
|
],
|
|
"metadata": {
|
|
"id": "xFYlcPo8fpPU"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"In Tianshou, all of these main components are factored out as different building blocks, which you\n",
|
|
"can use to create your own algorithm and finish your own experiment.\n",
|
|
"\n",
"Building blocks include:\n",
|
|
"- Batch\n",
|
|
"- Replay Buffer\n",
|
|
"- Vectorized Environment Wrapper\n",
|
|
"- Policy (the agent and the training algorithm)\n",
|
|
"- Data Collector\n",
|
|
"- Trainer\n",
|
|
"- Logger\n",
|
|
"\n",
|
|
"\n",
"Check this [webpage](https://tianshou.readthedocs.io/en/master/tutorials/dqn.html) to find Jupyter-notebook-style tutorials that will guide you through all these\n",
"modules one by one. You can also read the [documentation](https://tianshou.readthedocs.io/en/master/) of Tianshou for more detailed explanations and\n",
"advanced usage."
|
|
],
|
|
"metadata": {
|
|
"id": "kV_uOyimj-bk"
|
|
}
|
|
},
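{
"cell_type": "markdown",
"source": [
"To make the division of labour among these building blocks concrete, the toy loop below mimics the pipeline in plain Python. It is only an illustrative sketch: `ToyEnv`, `ToyPolicy`, `collect` and `train` are hypothetical stand-ins invented for this example, not Tianshou classes or functions, and the policy does no real learning. In Tianshou, the corresponding roles are played by the vectorized environment, the policy, the `Collector` and the trainer used in the script above."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import random\n",
"\n",
"\n",
"class ToyEnv:\n",
"    # Hypothetical stand-in environment: each episode lasts 5 steps, reward 1 per step.\n",
"    def reset(self):\n",
"        self.t = 0\n",
"        return self.t\n",
"\n",
"    def step(self, action):\n",
"        self.t += 1\n",
"        done = self.t >= 5\n",
"        return self.t, 1.0, done\n",
"\n",
"\n",
"class ToyPolicy:\n",
"    # Hypothetical stand-in agent: acts randomly and only counts its updates.\n",
"    def __init__(self):\n",
"        self.n_updates = 0\n",
"\n",
"    def act(self, obs):\n",
"        return random.choice([0, 1])\n",
"\n",
"    def update(self, batch):\n",
"        self.n_updates += 1\n",
"\n",
"\n",
"def collect(policy, env, buffer, n_step):\n",
"    # Collector role: run the policy in the environment and store transitions.\n",
"    obs = env.reset()\n",
"    for _ in range(n_step):\n",
"        act = policy.act(obs)\n",
"        next_obs, rew, done = env.step(act)\n",
"        buffer.append((obs, act, rew, done, next_obs))\n",
"        obs = env.reset() if done else next_obs\n",
"\n",
"\n",
"def train(policy, env, max_epoch, step_per_collect, batch_size):\n",
"    # Trainer role: alternate between collecting data and updating the policy.\n",
"    buffer = []  # replay buffer role: a plain list of transitions\n",
"    for _ in range(max_epoch):\n",
"        collect(policy, env, buffer, step_per_collect)\n",
"        batch = random.sample(buffer, batch_size)  # batch role: a sampled mini-batch\n",
"        policy.update(batch)\n",
"    return len(buffer)\n",
"\n",
"\n",
"toy_policy = ToyPolicy()\n",
"n_stored = train(toy_policy, ToyEnv(), max_epoch=3, step_per_collect=10, batch_size=4)\n",
"print(n_stored, toy_policy.n_updates)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},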
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Further reading"
|
|
],
|
|
"metadata": {
|
|
"id": "S0mNKwH9i6Ek"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
"## What if I am not familiar with the PPO algorithm itself?\n",
"As for the DRL algorithms themselves, we refer you to the [Spinning Up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), which provides\n",
"plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's tutorials, we\n",
"focus on the usage of different modules, not on the algorithms themselves."
|
|
],
|
|
"metadata": {
|
|
"id": "M3NPSUnAov4L"
|
|
}
|
|
}
|
|
]
|
|
}
|