2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "M98bqxdMsTXK"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Collector\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "As its name suggests, the Collector in Tianshou is used to collect training data. More specifically, the Collector controls the interaction between the Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<center>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</center>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "OX5cayLv4Ziu"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Usages\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Collector can be used both for training (data collecting) and evaluation in Tianshou."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "Z6XKbj28u8Ze"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "### Policy evaluation\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "We need to evaluate our trained policy from time to time in DRL experiments. Collector can help us with this.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "First we have to initialize a Collector with a (vectorized) environment and a given policy (agent)."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "editable": true,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "w8t9ubO7u69J",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "slideshow": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "slide_type": ""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "hide-cell",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "remove-output"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import gymnasium as gym\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.data import Collector\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.env import DummyVectorEnv\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.policy import PGPolicy\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.utils.net.common import Net\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from tianshou.utils.net.discrete import Actor\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "from tianshou.data import VectorReplayBuffer"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "env = gym.make(\"CartPole-v1\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# model\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "net = Net(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    env.observation_space.shape,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    hidden_sizes=[\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        16,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    ],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "actor = Actor(net, env.action_space.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "policy = PGPolicy(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    actor=actor,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    optim=optim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    dist_fn=torch.distributions.Categorical,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    action_space=env.action_space,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    action_scaling=False,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "test_collector = Collector(policy, test_envs)"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "wmt8vuwpzQdR"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Now we would like to collect 9 episodes of data to test how our initialized Policy performs."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "9SuT6MClyjyH",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "collect_result = test_collector.collect(n_episode=9)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(collect_result)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "zX9AQY0M0R3C"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Now let us see how a random policy performs."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "UEcs8P8P0RLt",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Reset the collector\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "test_collector.reset()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "collect_result = test_collector.collect(n_episode=9, random=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(collect_result)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "sKQRTiG10ljU"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "It seems that a freshly initialized policy, without any training, performs even worse than a random policy."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "8RKmHIoG1A1k"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "### Data Collecting\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "Data collecting is mostly used during training, when we need to store the collected data in a ReplayBuffer."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "editable": true,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "CB9XB9bF1YPC",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "slideshow": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "slide_type": ""
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tags": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_env_num = 4\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "buffer_size = 100\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_collector = Collector(policy, train_envs, replaybuffer)"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "rWKDazA42IUQ"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Now we can collect 50 steps of data, which will be automatically saved in the replay buffer. You can still choose to collect a certain number of episodes rather than steps. Try it yourself."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "-fUtQOnM2Yi1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(len(replaybuffer))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "collect_result = train_collector.collect(n_step=50)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(len(replaybuffer))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(collect_result)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "EWO4A7plefwM",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "9a6f36d1-2b84-49b0-a03d-a8ebe8acadbf"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for i in range(13):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    print(i, replaybuffer.next(i))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 10:28:24 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "HW8PpWH9fLCo",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "7ca70c50-23b9-4405-9e42-2e5771cd9c78"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "replaybuffer.sample(10)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "8NP7lOBU3-VS"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-11-09 13:36:23 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Further Reading\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "The above collector actually collects 52 steps of data rather than 50, because the 4 training environments are stepped together and 52 is the smallest multiple of 4 that is at least 50 (52 % 4 = 0). There is also an asynchronous collector which allows you to collect exactly 50 steps. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#asynccollector) for details."
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
									
										
										
										
											2023-10-26 16:27:59 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "provenance": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "version": "3.11.5"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 4
							 
						 
					
						
							
								
									
										
										
										
											2023-10-17 13:59:37 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								}