2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "_UaXOSRjDUF9"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Experiment\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "2QRbCJvDHNAd"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Experiment\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "To conduct this experiment, we need the following building blocks.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "*   Two vectorized environments, one for training and one for evaluation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "*   A PPO agent\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "*   A replay buffer to store transition data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "*   Two collectors to manage the data collecting process, one for training and one for evaluation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "*   A trainer to manage the training loop\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<div align=center>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\", width=500>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</div>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Let us do this step by step."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "-Hh4E6i0Hj0I"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Preparation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Firstly, install Tianshou if you haven't installed it before."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "7E4EhiBeHxD5"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Import libraries we might need later."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "editable": true,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "ao9gWJDiHgG-",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "slideshow": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "slide_type": ""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "hide-cell",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "remove-output"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import gymnasium as gym\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.data import Collector, VectorReplayBuffer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.env import DummyVectorEnv\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.policy import PPOPolicy\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "from tianshou.trainer import OnpolicyTrainer\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "from tianshou.utils.net.common import ActorCritic, Net\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.utils.net.discrete import Actor, Critic\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "QnRg5y7THRYw"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Environment"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "YZERKCGtH8W1"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "Mpuj5PFnDKVS"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "env = gym.make(\"CartPole-v1\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "BJtt_Ya8DTAh"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Policy\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Luckily, Tianshou already provides basic network modules that we can use in this experiment."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "_Vy8uPWXP4m_"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# net is the shared head of the actor and the critic\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "critic = Critic(preprocess_net=net, device=device).to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "actor_critic = ActorCritic(actor=actor, critic=critic)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# optimizer of the actor and the critic\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "Lh2-hwE5Dn9I"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "OiJ2GkT0Qnbr"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "dist = torch.distributions.Categorical\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "policy = PPOPolicy(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    actor=actor,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    critic=critic,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    optim=optim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    dist_fn=dist,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    action_space=env.action_space,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    deterministic_eval=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    action_scaling=False,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "okxfj6IEQ-r8"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "`deterministic_eval=True` means that we want to sample actions during training, but always use the best action during evaluation, so that evaluation involves no randomness."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "n5XAAbuBZarO"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Collector\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "We can set up the collectors now. The train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "ezwz0qerZhQM"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_collector = Collector(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    policy=policy, env=train_envs, buffer=VectorReplayBuffer(20000, len(train_envs))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_collector = Collector(policy=policy, env=test_envs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "ZaoPxOd2hm0b"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "We use `VectorReplayBuffer` here because it works more efficiently with vectorized environments. You can simply think of `VectorReplayBuffer` as a list of ordinary replay buffers."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "qBoE9pLUiC-8"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Trainer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Finally, we can use the trainer to help us set up the training loop."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "i45EDnpxQ8gu",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "b1666b88-0bfa-4340-868e-58611872d988"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "result = OnpolicyTrainer(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    policy=policy,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_collector=train_collector,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    test_collector=test_collector,\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_epoch=10,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    step_per_epoch=50000,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    repeat_per_collect=10,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    episode_per_test=10,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    batch_size=256,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    step_per_collect=2000,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    stop_fn=lambda mean_reward: mean_reward >= 195,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "ckgINHE2iTFR"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Results\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Print the training result."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "tJCPgmiyiaaX",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "40123ae3-3365-4782-9563-46c43812f10f"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(result)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "A-MJ9avMibxN"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "We can also test our trained agent."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "mnMANFcciiAQ",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Let's watch its performance!\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "policy.eval()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "result = test_collector.collect(n_episode=1, render=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "provenance": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "toc_visible": true
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "version": "3.11.5"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 4
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								}