2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "S3-tJZy35Ck_"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Trainer\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the medium between them.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-15 15:50:06 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<center>\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-17 11:33:44 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=../_static/images/structure.svg></img>\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-15 15:50:06 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "</center>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "ifsEQMzZ6mmz"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Usages\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training, respectively. We will use the on-policy trainer as an example and leave the other two for further reading."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "XfsuU2AAE52C"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "### Pseudocode\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-15 15:50:06 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<center>\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-17 11:33:44 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=../_static/images/pseudocode_off_policy.svg></img>\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-15 15:50:06 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "</center>\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "For the on-policy trainer, the main difference is that we clear the buffer after Line 10."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "Hcp_o0CCFz12"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "### Training without trainer\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Now that we have learned how to use the Collector and the Policy, we can write our own training logic.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "editable": true,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "do-xZ-8B7nVH",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "slideshow": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "slide_type": ""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "hide-cell",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "remove-output"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import gymnasium as gym\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.data import Collector, VectorReplayBuffer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.env import DummyVectorEnv\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.policy import PGPolicy\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.utils.net.common import Net\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.utils.net.discrete import Actor\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "from tianshou.trainer import OnpolicyTrainer"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "train_env_num = 4\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "buffer_size = (\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    2000  # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Create the environments, used for training and evaluation\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "env = gym.make(\"CartPole-v1\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Create the Policy instance\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "net = Net(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    env.observation_space.shape,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    hidden_sizes=[\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        16,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    ],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "actor = Actor(net, env.action_space.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "policy = PGPolicy(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    actor=actor,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    optim=optim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    dist_fn=torch.distributions.Categorical,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    action_space=env.action_space,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    action_scaling=False,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Create the replay buffer and the collector\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_collector = Collector(policy, test_envs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_collector = Collector(policy, train_envs, replaybuffer)"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "wiEGiBgQIiFM"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "JMUNPN5SI_kd",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_collector.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_envs.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_collector.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_envs.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "replaybuffer.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for i in range(10):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    evaluation_result = test_collector.collect(n_episode=10)\n",
							 
						 
					
						
							
								
									
										
											 
										
											
												Feature/dataclasses (#996)
This PR adds strict typing to the output of `update` and `learn` in all
policies. This will likely be the last large refactoring PR before the
next release (0.6.0, not 1.0.0), so it requires some attention. Several
difficulties were encountered on the path to that goal:
1. The policy hierarchy is actually "broken" in the sense that the keys
of dicts that were output by `learn` did not follow the same enhancement
(inheritance) pattern as the policies. This is a real problem and should
be addressed in the near future. Generally, several aspects of the
policy design and hierarchy might deserve a dedicated discussion.
2. Each policy needs to be generic in the stats return type, because one
might want to extend it at some point and then also extend the stats.
Even within the source code base this pattern is necessary in many
places.
3. The interaction between learn and update is a bit quirky, we
currently handle it by having update modify special field inside
TrainingStats, whereas all other fields are handled by learn.
4. The IQM module is a policy wrapper and required a
TrainingStatsWrapper. The latter relies on a bunch of black magic.
They were addressed by:
1. Live with the broken hierarchy, which is now made visible by bounds
in generics. We use type: ignore where appropriate.
2. Make all policies generic with bounds following the policy
inheritance hierarchy (which is incorrect, see above). We experimented a
bit with nested TrainingStats classes, but that seemed to add more
complexity and be harder to understand. Unfortunately, mypy thinks that
the code below is wrong, wherefore we have to add `type: ignore` to the
return of each `learn`
```python
T = TypeVar("T", bound=int)
def f() -> T:
  return 3
```
3. See above
4. Write representative tests for the `TrainingStatsWrapper`. Still, the
black magic might cause nasty surprises down the line (I am not proud of
it)...
Closes #933
---------
Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
											 
										 
										
											2023-12-30 11:09:03 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    train_collector.collect(n_step=2000)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # 0 means taking all data stored in train_collector.buffer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_collector.reset_buffer(keep_statistics=True)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "QXBHIBckMs_2"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and the REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still, this gives us an idea of how to start a training loop."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "p-7U_cwgF5Ej"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "### Training with trainer\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "The trainer does almost the same thing. The only difference is that it takes care of many details and is more modular."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "vcvw9J8RNtFE",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-14 19:31:53 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "remove-output"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_collector.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_envs.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_collector.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_envs.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "replaybuffer.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "result = OnpolicyTrainer(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    policy=policy,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_collector=train_collector,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    test_collector=test_collector,\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_epoch=10,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    step_per_epoch=1,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    repeat_per_collect=1,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    episode_per_test=10,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    step_per_collect=2000,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    batch_size=512,\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-14 19:31:53 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    ").run()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(result)"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "_j3aUJZQ7nml"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Further Reading\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### Logger usage\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Tianshou provides experiment loggers that are both TensorBoard- and wandb-compatible. It also has a `BaseLogger` class that allows you to define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "### Learn more about the APIs of Trainers\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "collapsed_sections": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "S3-tJZy35Ck_",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "XfsuU2AAE52C",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "p-7U_cwgF5Ej",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "_j3aUJZQ7nml"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "provenance": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "version": "3.11.5"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 4
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								}