Closes: https://github.com/aai-institute/tianshou/issues/1116

### API Extensions
- Batch received a new method: `to_torch_`. #1117

### Breaking Changes
- The method `to_torch` in `data.utils.batch.Batch` is no longer in-place. Instead, a new method `to_torch_` performs the conversion in-place. #1117
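
For reviewers who want the in-place versus copying distinction spelled out, here is a minimal sketch of the trailing-underscore convention this change adopts. `ToyBatch`, `to_float` and `to_float_` are hypothetical names invented for illustration; this is not Tianshou's `Batch` and does not reproduce its implementation. It only mirrors the contract described above: the underscore variant mutates the object, the plain variant returns a converted copy.

```python
import copy


class ToyBatch:
    """Hypothetical stand-in for a Batch-like container (not Tianshou's code)."""

    def __init__(self, **fields):
        self.fields = dict(fields)

    def to_float_(self) -> None:
        """In-place variant: converts every field of this object, returns None."""
        for key, values in self.fields.items():
            self.fields[key] = [float(v) for v in values]

    def to_float(self) -> "ToyBatch":
        """Copying variant: leaves this object untouched, returns a converted copy."""
        result = copy.deepcopy(self)
        result.to_float_()
        return result


batch = ToyBatch(a=[1, 2], b=[3, 4])

converted = batch.to_float()  # new object; `batch` is unchanged
print(type(batch.fields["a"][0]).__name__)      # int
print(type(converted.fields["a"][0]).__name__)  # float

batch.to_float_()  # mutates `batch` itself and returns None
print(type(batch.fields["a"][0]).__name__)      # float
```

By analogy, code that previously relied on `batch.to_torch()` modifying `batch` would presumably need to either call `batch.to_torch_()` or use the returned value.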
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "69y6AHvq1S3f"
   },
   "source": [
    "# Batch\n",
    "In this tutorial, we introduce **Batch**, the most basic data structure in Tianshou. You can think of a Batch as a numpy version of a Python dictionary. It is also similar to PyTorch's tensordict,\n",
    "although with a somewhat different type structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "editable": true,
    "id": "NkfiIe_y2FI-",
    "outputId": "5008275f-8f77-489a-af64-b35af4448589",
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "remove-output",
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from tianshou.data import Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = Batch(a=4, b=[5, 5], c=\"2312312\", d=(\"a\", -2, -3))\n",
    "print(data)\n",
    "print(data.b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S6e6OuXe3UT-"
   },
   "source": [
    "A Batch stores all passed-in data as key-value pairs and automatically converts each value into a numpy array if possible.\n",
    "\n",
    "## Why do we need Batch in Tianshou?\n",
    "The motivation behind the Batch module is simple. In deep reinforcement learning (DRL), you need to handle a lot of dictionary-format data. For instance, most algorithms require you to store state, action, and reward data for every step of interaction with the environment. All of these can be organized as a dictionary, and the\n",
    "Batch class helps Tianshou unify the interfaces of a diverse set of algorithms. In addition, Batch supports advanced indexing, concatenation, splitting, and formatted printing, just like a numpy array, which has proved helpful for developers.\n",
    "<div align=center>\n",
    "<img src=\"https://tianshou.readthedocs.io/en/master/_images/concepts_arch.png\" title=\"The data flow is converted into a Batch in Tianshou\">\n",
    "\n",
    "<a> Data flow is converted into a Batch in Tianshou </a>\n",
    "</div>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_Xenx64M9HhV"
   },
   "source": [
    "## Basic Usages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4YGX_f1Z9Uil"
   },
   "source": [
    "### Initialization\n",
    "A Batch can be constructed directly from a Python dictionary, and all data structures\n",
    "will be converted to numpy arrays if possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Jl3-4BRbp3MM",
    "outputId": "a8b225f6-2893-4716-c694-3c2ff558b7f0"
   },
   "outputs": [],
   "source": [
    "# converted from a python dictionary\n",
    "print(\"========================================\")\n",
    "batch1 = Batch({\"a\": [4, 4], \"b\": (5, 5)})\n",
    "print(batch1)\n",
    "\n",
    "# initialization of batch2 is equivalent to batch1\n",
    "print(\"========================================\")\n",
    "batch2 = Batch(a=[4, 4], b=(5, 5))\n",
    "print(batch2)\n",
    "\n",
    "# the dictionary can be nested, and it will be turned into a nested Batch\n",
    "print(\"========================================\")\n",
    "data = {\n",
    "    \"action\": np.array([1.0, 2.0, 3.0]),\n",
    "    \"reward\": 3.66,\n",
    "    \"obs\": {\n",
    "        \"rgb_obs\": np.zeros((3, 3)),\n",
    "        \"flatten_obs\": np.ones(5),\n",
    "    },\n",
    "}\n",
    "\n",
    "batch3 = Batch(data, extra=\"extra_string\")\n",
    "print(batch3)\n",
    "# batch3.obs is also a Batch\n",
    "print(type(batch3.obs))\n",
    "print(batch3.obs.rgb_obs)\n",
    "\n",
    "# a list of dictionaries/Batches will automatically be concatenated/stacked, which is convenient if you\n",
    "# want to use parallelized environments to collect data.\n",
    "print(\"========================================\")\n",
    "batch4 = Batch([data] * 3)\n",
    "print(batch4)\n",
    "print(batch4.obs.rgb_obs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JCf6bqY3uf5L"
   },
   "source": [
    "### Getting access to data\n",
    "You can conveniently look up or modify the key-value pairs in a Batch, just as if it were a Python dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2TNIY90-vU9b",
    "outputId": "de52ffe9-03c2-45f2-d95a-4071132daa4a"
   },
   "outputs": [],
   "source": [
    "batch1 = Batch({\"a\": [4, 4], \"b\": (5, 5)})\n",
    "print(batch1)\n",
    "\n",
    "# add or delete key-value pairs in batch1\n",
    "print(\"========================================\")\n",
    "batch1.c = Batch(c1=np.arange(3), c2=False)\n",
    "del batch1.a\n",
    "print(batch1)\n",
    "\n",
    "# access value by key\n",
    "print(\"========================================\")\n",
    "assert batch1[\"c\"] is batch1.c\n",
    "print(\"c\" in batch1)\n",
    "\n",
    "# traverse the Batch\n",
    "print(\"========================================\")\n",
    "for key, value in batch1.items():\n",
    "    print(str(key) + \": \" + str(value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bVywStbV9jD2"
   },
   "source": [
    "### Indexing and Slicing\n",
    "If all values in a Batch share the same shape along certain dimensions, the Batch supports array-like indexing and slicing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gKza3OJnzc_D",
    "outputId": "4f240bfe-4a69-4c1b-b40e-983c5c4d0cbc"
   },
   "outputs": [],
   "source": [
    "# Suppose we have collected the data from stepping 4 environments\n",
    "step_outputs = [\n",
    "    {\n",
    "        \"act\": np.random.randint(10),\n",
    "        \"rew\": 0.0,\n",
    "        \"obs\": np.ones((3, 3)),\n",
    "        \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n",
    "        \"terminated\": False,\n",
    "        \"truncated\": False,\n",
    "    }\n",
    "    for _ in range(4)\n",
    "]\n",
    "batch = Batch(step_outputs)\n",
    "print(batch)\n",
    "print(batch.shape)\n",
    "\n",
    "# advanced indexing is supported, e.g. to select data from a given set of environments\n",
    "print(\"========================================\")\n",
    "print(batch[0])\n",
    "print(batch[[0, 3]])\n",
    "\n",
    "# slicing is also supported\n",
    "print(\"========================================\")\n",
    "print(batch[-2:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Aggregation and Splitting\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1vUwQ-Hw9jtu"
   },
   "source": [
    "Again, this works just like a numpy array. Try the example code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f5UkReyn3_kb",
    "outputId": "e7bb3324-7f20-4810-a328-479117efca55"
   },
   "outputs": [],
   "source": [
    "# concat batches with compatible keys\n",
    "# try incompatible keys yourself if you are curious\n",
    "print(\"========================================\")\n",
    "b1 = Batch(a=[{\"b\": np.float64(1.0), \"d\": Batch(e=np.array(3.0))}])\n",
    "b2 = Batch(a=[{\"b\": np.float64(4.0), \"d\": {\"e\": np.array(6.0)}}])\n",
    "b12_cat_out = Batch.cat([b1, b2])\n",
    "print(b1)\n",
    "print(b2)\n",
    "print(b12_cat_out)\n",
    "\n",
    "# stack batches with compatible keys\n",
    "# try incompatible keys yourself if you are curious\n",
    "print(\"========================================\")\n",
    "b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n",
    "b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n",
    "b34_stack = Batch.stack((b3, b4), axis=1)\n",
    "print(b3)\n",
    "print(b4)\n",
    "print(b34_stack)\n",
    "\n",
    "# split the batch into small batches of size 1, shuffling the order of the data\n",
    "print(\"========================================\")\n",
    "print(type(b34_stack.split(1)))\n",
    "print(list(b34_stack.split(1, shuffle=True)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Smc_W1Cx6zRS"
   },
   "source": [
    "### Data type conversion\n",
    "Besides numpy arrays, Batch also supports torch Tensors. The usage is exactly the same. Cool, isn't it?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Y6im_Mtb7Ody",
    "outputId": "898e82c4-b940-4c35-a0f9-dedc4a9bc500"
   },
   "outputs": [],
   "source": [
    "batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))\n",
    "batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))\n",
    "batch_cat = Batch.cat([batch1, batch2, batch1])\n",
    "print(batch_cat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1wfTUVKb6xki"
   },
   "source": [
    "You can convert the data type easily if you no longer want to use hybrid data types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "F7WknVs98DHD",
    "outputId": "cfd0712a-1df3-4208-e6cc-9149840bdc40"
   },
   "outputs": [],
   "source": [
    "batch_cat.to_numpy_()\n",
    "print(batch_cat)\n",
    "batch_cat.to_torch_()\n",
    "print(batch_cat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NTFVle1-9Biz"
   },
   "source": [
    "Batch is also serializable, in case you need to save it to disk or restore it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Lnf17OXv9YRb",
    "outputId": "753753f2-3f66-4d4b-b4ff-d57f9c40d1da"
   },
   "outputs": [],
   "source": [
    "batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n",
    "batch_pk = pickle.loads(pickle.dumps(batch))\n",
    "print(batch_pk)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-vPMiPZ-9kJN"
   },
   "source": [
    "## Further Reading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8Oc1p8ud9kcu"
   },
   "source": [
    "Would you like to learn more advanced uses of Batch? Curious about how data is organized inside a Batch? Check the [documentation](https://tianshou.readthedocs.io/en/master/03_api/data/batch.html) and other [tutorials](https://tianshou.readthedocs.io/en/master/01_tutorials/03_batch.html#) for more details."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}