write tutorials to specify the standard of Batch (#142)

* add doc for len exceptions

* doc move; unify is_scalar_value function

* remove some issubclass check

* bugfix for shape of Batch(a=1)

* keep moving doc

* keep writing batch tutorial

* draft version of Batch tutorial done

* improving doc

* keep improving doc

* batch tutorial done

* rename _is_number

* rename _is_scalar

* shape property do not raise exception

* restore some doc string

* grammarly [ci skip]

* grammarly + fix warning of building docs

* polish docs

* trim and re-arrange batch tutorial

* go straight to the point

* minor fix for batch doc

* add shape / len in basic usage

* keep improving tutorial

* unify _to_array_with_correct_type to remove duplicate code

* delegate type convertion to Batch.__init__

* further delegate type convertion to Batch.__init__

* bugfix for setattr

* add a _parse_value function

* remove dummy function call

* polish docs

Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
youkaichao 2020-07-19 15:20:35 +08:00 committed by Trinkle23897
parent 3a08e27ed4
commit fe5555d2a1
9 changed files with 655 additions and 300 deletions

BIN
docs/_static/images/aggregation.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 208 KiB

BIN
docs/_static/images/batch_reserve.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

BIN
docs/_static/images/batch_tree.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

View File

@ -70,6 +70,7 @@ Tianshou is still under development, you can also check out the documents in sta
tutorials/dqn
tutorials/concepts
tutorials/batch
tutorials/trick
tutorials/cheatsheet

491
docs/tutorials/batch.rst Normal file
View File

@ -0,0 +1,491 @@
.. _batch_concept:
Understand Batch
================
:class:`~tianshou.data.Batch` is the internal data structure extensively used in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users correctly understand the concept and the behavior of ``Batch`` so that users can make the best of Tianshou.
The tutorial has three parts. We first explain the concept of hierarchical named tensors, and introduce basic usage of ``Batch``, followed by advanced topics of ``Batch``.
Hierarchical Named Tensors
---------------------------
.. sidebar:: The structure of a Batch shown by a tree
.. Figure:: ../_static/images/batch_tree.png
"Hierarchical named tensors" refers to a set of tensors where their names form a hierarchy. Suppose there are four tensors ``[t1, t2, t3, t4]`` with names ``[name1, name2, name3, name4]``, where ``name1`` and ``name2`` belong to the same namespace ``name0``, then the full name of tensor ``t1`` is ``name0.name1``. That is, the hierarchy lies in the names of tensors.
We can describe the structure of hierarchical named tensors using a tree in the right. There is always a "virtual root" node to represent the whole object; internal nodes are keys (names), and leaf nodes are values (scalars or tensors).
Hierarchical named tensors are needed because we have to deal with the heterogeneity of reinforcement learning problems. The abstraction of RL is very simple, just::
state, reward, done = env.step(action)
``reward`` and ``done`` are simple, they are mostly scalar values. However, the ``state`` and ``action`` vary with environments. For example, ``state`` can be simply a vector, a tensor, or a camera input combined with sensory input. In the last case, it is natural to store them as hierarchical named tensors. This hierarchy can go beyond ``state`` and ``action``: we can store ``state``, ``action``, ``reward``, and ``done`` together as hierarchical named tensors.
Note that, storing hierarchical named tensors is as easy as creating nested dictionary objects:
::
{
'done': done,
'reward': reward,
'state': {
'camera': camera,
'sensory': sensory
}
'action': {
'direct': direct,
'point_3d': point_3d,
'force': force,
}
}
The real problem is how to **manipulate them**, such as adding new transition tuples into replay buffer and dealing with their heterogeneity. ``Batch`` is designed to easily create, store, and manipulate these hierarchical named tensors.
Basic Usages
------------
Here we cover some basic usages of ``Batch``, describing what ``Batch`` contains, how to construct ``Batch`` objects and how to manipulate them.
What Does Batch Contain
^^^^^^^^^^^^^^^^^^^^^^^
The content of ``Batch`` objects can be defined by the following rules.
1. A ``Batch`` object can be an empty ``Batch()``, or have at least one key-value pairs. ``Batch()`` can be used to reserve keys, too. See :ref:`key_reservations` for this advanced usage.
2. The keys are always strings (they are names of corresponding values).
3. The values can be scalars, tensors, or Batch objects. The recurse definition makes it possible to form a hierarchy of batches.
4. Tensors are the most important values. In short, tensors are n-dimensional arrays of the same data type. We support two types of tensors: `PyTorch <https://pytorch.org/>`_ tensor type ``torch.Tensor`` and `NumPy <https://numpy.org/>`_ tensor type ``np.ndarray``.
5. Scalars are also valid values. A scalar is a single boolean, number, or object. They can be python scalar (``False``, ``1``, ``2.3``, ``None``, ``'hello'``) or NumPy scalar (``np.bool_(True)``, ``np.int32(1)``, ``np.float64(2.3)``). They just shouldn't be mixed up with Batch/dict/tensors.
.. note::
``Batch`` cannot store ``dict`` objects, because internally ``Batch`` uses ``dict`` to store data. During construction, ``dict`` objects will be automatically converted to ``Batch`` objects.
The data types of tensors are bool and numbers (any size of int and float as long as they are supported by NumPy or PyTorch). Besides, NumPy supports ndarray of objects and we take advantage of this feature to store non-number objects in ``Batch``. If one wants to store data that are neither boolean nor numbers (such as strings and sets), they can store the data in ``np.ndarray`` with the ``np.object`` data type. This way, ``Batch`` can store any type of python objects.
Construction of Batch
^^^^^^^^^^^^^^^^^^^^^
There are two ways to construct a ``Batch`` object: from a ``dict``, or using ``kwargs``. Below are some code snippets.
.. raw:: html
<details>
<summary>Construct Batch from dict</summary>
.. code-block:: python
>>> # directly passing a dict object (possibly nested) is ok
>>> data = Batch({'a': 4, 'b': [5, 5], 'c': '2312312'})
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
)
>>> # a list of dict objects (possibly nested) will be automatically stacked
>>> data = Batch([{'a': 0.0, 'b': "hello"}, {'a': 1.0, 'b': "world"}])
>>> print(data)
Batch(
a: array([0., 1.]),
b: array(['hello', 'world'], dtype=object),
)
.. raw:: html
</details><br>
.. raw:: html
<details>
<summary>Construct Batch from kwargs</summary>
.. code-block:: python
>>> # construct a Batch with keyword arguments
>>> data = Batch(a=[4, 4], b=[5, 5], c=[None, None])
>>> print(data)
Batch(
a: array([4, 4]),
b: array([5, 5]),
c: array([None, None], dtype=object),
)
>>> # combining keyword arguments and batch_dict works fine
>>> data = Batch({'a':[4, 4], 'b':[5, 5]}, c=[None, None]) # the first argument is a dict, and 'c' is a keyword argument
>>> print(data)
Batch(
a: array([4, 4]),
b: array([5, 5]),
c: array([None, None], dtype=object),
)
>>> arr = np.zeros((3, 4))
>>> # By default, Batch only keeps the reference to the data, but it also supports data copying
>>> data = Batch(arr=arr, copy=True) # data.arr now is a copy of 'arr'
.. raw:: html
</details><br>
Data Manipulation With Batch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Users can access the internal data by ``b.key`` or ``b[key]``, where ``b.key`` finds the sub-tree with ``key`` as the root node. If the result is a sub-tree with non-empty keys, the key-reference can be chained, i.e. ``b.key.key1.key2.key3``. When it reaches a leaf node, users get the data (scalars/tensors) stored in that ``Batch`` object.
.. raw:: html
<details>
<summary>Access data stored in Batch</summary>
.. code-block:: python
>>> data = Batch(a=4, b=[5, 5])
>>> print(data.b)
[5 5]
>>> # obj.key is equivalent to obj["key"]
>>> print(data["a"])
4
>>> # iterating over data items like a dict is supported
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
>>> # obj.keys() and obj.values() work just like dict.keys() and dict.values()
>>> for key in data.keys():
>>> print(f"{key}")
a
b
>>> # obj.update() behaves like dict.update()
>>> # this is the same as data.c = 1; data.c = 2; data.e = 3;
>>> data.update(c=1, d=2, e=3)
>>> print(data)
Batch(
a: 4,
b: array([5, 5]),
c: 1,
d: 2,
e: 3,
)
.. raw:: html
</details><br>
.. note::
If ``data`` is a ``dict`` object, ``for x in data`` iterates over keys in the dict. However, it has a different meaning for ``Batch`` objects: ``for x in data`` iterates over ``data[0], data[1], ..., data[-1]``. An example is given below.
``Batch`` also partially reproduces the NumPy ndarray APIs. It supports advanced slicing, such as ``batch[:, i]`` so long as the slice is valid. Broadcast mechanism of NumPy works for ``Batch``, too.
.. raw:: html
<details>
<summary>Length, shape, indexing, and slicing of Batch</summary>
.. code-block:: python
>>> # initialize Batch with tensors
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5], [1, -2]])
>>> # if values have the same length/shape, that length/shape is used for this Batch
>>> # else, check the advanced topic for details
>>> print(len(data))
2
>>> print(data.shape)
[2, 2]
>>> # access the first item of all the stored tensors, while keeping the structure of Batch
>>> print(data[0])
Batch(
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> # iterates over ``data[0], data[1], ..., data[-1]``
>>> for sample in data:
>>> print(sample.a)
[0. 2.]
[1. 3.]
>>> # Advanced slicing works just fine
>>> # Arithmetic operations are passed to each value in the Batch, with broadcast enabled
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)
>>> # amazingly, you can directly apply np.mean to a Batch object
>>> print(np.mean(data))
Batch(
a: 1.5,
b: -0.25,
)
>>> # directly converted to a list is also available
>>> list(data)
[Batch(
a: array([0., 3.]),
b: array([ 5, -4]),
),
Batch(
a: array([1., 4.]),
b: array([ 1, -1]),
)]
.. raw:: html
</details><br>
Stacking and concatenating multiple ``Batch`` instances, or split an instance into multiple batches, they are all easy and intuitive in Tianshou. For now, we stick to the aggregation (stack/concatenate) of homogeneous (same structure) batches. Stack/Concatenation of heterogeneous batches are discussed in :ref:`aggregation`.
.. raw:: html
<details>
<summary>Stack / Concatenate / Split of Batches</summary>
.. code-block:: python
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> # split supports random shuffling
>>> data_split = list(data.split(1, shuffle=False))
>>> print(list(data.split(1, shuffle=False)))
[Batch(
b: array([5]),
a: array([[0., 2.]]),
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
.. raw:: html
</details><br>
Advanced Topics
---------------
From here on, this tutorial focuses on advanced topics of ``Batch``, including key reservation, length/shape, and aggregation of heterogeneous batches.
.. _key_reservations:
Key Reservations
^^^^^^^^^^^^^^^^
.. sidebar:: The structure of a Batch with reserved keys
.. Figure:: ../_static/images/batch_reserve.png
In many cases, we know in the first place what keys we have, but we do not know the shape of values until we run the environment. To deal with this, Tianshou supports key reservations: **reserve a key and use a placeholder value**.
The usage is easy: just use ``Batch()`` to be the value of reserved keys.
.. code-block:: python
a = Batch(b=Batch()) # 'b' is a reserved key
# this is called hierarchical key reservation
a = Batch(b=Batch(c=Batch()), d=Batch()) # 'c' and 'd' are reserved key
# the structure of this last Batch is shown in the right figure
a = Batch(key1=tensor1, key2=tensor2, key3=Batch(key4=Batch(), key5=Batch()))
Still, we can use a tree (in the right) to show the structure of ``Batch`` objects with reserved keys, where reserved keys are special internal nodes that do not have attached leaf nodes.
.. note::
Reserved keys mean that in the future there will eventually be values attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understand the behavior of ``Batch`` when dealing with heterogeneous Batches.
The introduction of reserved keys gives rise to the need to check if a key is reserved. Tianshou provides ``Batch.is_empty`` to achieve this.
.. raw:: html
<details>
<summary>Examples of Batch.is_empty</summary>
.. code-block:: python
>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
.. raw:: html
</details><br>
The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recurse emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
.. note::
Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
Length and Shape
^^^^^^^^^^^^^^^^
The most common usage of ``Batch`` is to store a Batch of data. The term "Batch" comes from the deep learning community to denote a mini-batch of sampled data from the whole dataset. In this regard, "Batch" typically means a collection of tensors whose first dimensions are the same. Then the length of a ``Batch`` object is simply the batch-size.
If all the leaf nodes in a ``Batch`` object are tensors, but they have different lengths, they can be readily stored in ``Batch``. However, for ``Batch`` of this kind, the ``len(obj)`` seems a bit ambiguous. Currently, Tianshou returns the length of the shortest tensor, but we strongly recommend that users do not use the ``len(obj)`` operator on ``Batch`` objects with tensors of different lengths.
.. raw:: html
<details>
<summary>Examples of len and obj.shape for Batch objects</summary>
.. code-block:: python
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
.. raw:: html
</details><br>
.. note::
Following the convention of scientific computation, scalars have no length. If there is any scalar leaf node in a ``Batch`` object, an exception will occur when users call ``len(obj)``.
Besides, values of reserved keys are undetermined, so they have no length, neither. Or, to be specific, values of reserved keys have lengths of **any**. When there is a mix of tensors and reserved keys, the latter will be ignored in ``len(obj)`` and the minimum length of tensors is returned. When there is not any tensor in the ``Batch`` object, Tianshou raises an exception, too.
The ``obj.shape`` attribute of ``Batch`` behaves somewhat similar to ``len(obj)``:
1. If all the leaf nodes in a ``Batch`` object are tensors with the same shape, that shape is returned.
2. If all the leaf nodes in a ``Batch`` object are tensors but they have different shapes, the minimum length of each dimension is returned.
3. If there is any scalar value in a ``Batch`` object, ``obj.shape`` returns ``[]``.
4. The shape of reserved keys is undetermined, too. We treat their shape as ``[]``.
.. _aggregation:
Aggregation of Heterogeneous Batches
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In this section, we talk about aggregation operators (stack/concatenate) on heterogeneous ``Batch`` objects.
The following picture will give you an intuitive understanding of this behavior. It shows two examples of aggregation operators with heterogeneous ``Batch``. The shapes of tensors are annotated in the leaf nodes.
.. image:: ../_static/images/aggregation.png
We only consider the heterogeneity in the structure of ``Batch`` objects. The aggregation operators are eventually done by NumPy/PyTorch operators (``np.stack``, ``np.concatenate``, ``torch.stack``, ``torch.cat``). Heterogeneity in values can fail these operators (such as stacking ``np.ndarray`` with ``torch.Tensor``, or stacking tensors with different shapes) and an exception will be raised.
The behavior is natural: for keys that are not shared across all batches, batches that do not have these keys will be padded by zeros (or ``None`` if the data type is ``np.object``). It can be written in the following scripts:
::
>>> # examples of stack: a is missing key `b`, and b is missing key `a`
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
>>> # None or 0 is padded with appropriate shape
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
>>> # examples of cat: a is missing key `b`, and b is missing key `a`
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
However, there are some cases when batches are too heterogeneous that they cannot be aggregated:
::
>>> a = Batch(a=np.zeros([4, 4]))
>>> b = Batch(a=Batch(b=Batch()))
>>> # this will raise an exception
>>> c = Batch.stack([a, b])
Then how to determine if batches can be aggregated? Let's rethink the purpose of reserved keys. What is the advantage of ``a1=Batch(b=Batch())`` over ``a2=Batch()``? The only difference is that ``a1.b`` returns ``Batch()`` but ``a2.b`` raises an exception. That's to say, **we reserve keys for attribute reference**.
We say a key chain ``k=[key1, key2, ..., keyn]`` applies to ``b`` if the expression ``b.key1.key2.{...}.keyn`` is valid, and the result is ``b[k]``.
For a set of ``Batch`` objects denoted as :math:`S`, they can be aggregated if there exists a ``Batch`` object ``b`` satisfying the following rules:
1. Key chain applicability: For any object ``bi`` in :math:`S`, and any key chain ``k``, if ``bi[k]`` is valid, then ``b[k]`` is valid.
2. Type consistency: If ``bi[k]`` is not ``Batch()`` (the last key in the key chain is not a reserved key), then the type of ``b[k]`` should be the same as ``bi[k]`` (both should be scalar/tensor/non-empty Batch values).
The ``Batch`` object ``b`` satisfying these rules with the minimum number of keys determines the structure of aggregating :math:`S`. The values are relatively easy to define: for any key chain ``k`` that applies to ``b``, ``b[k]`` is the stack/concatenation of ``[bi[k] for bi in S]`` (if ``k`` does not apply to ``bi``, the appropriate size of zeros or ``None`` are filled automatically). If ``bi[k]`` are all ``Batch()``, then the aggregation result is also an empty ``Batch()``.
Miscellaneous Notes
^^^^^^^^^^^^^^^^^^^
1. ``Batch`` is serializable and therefore Pickle compatible. ``Batch`` objects can be saved to disk and later restored by the python ``pickle`` module. This pickle compatibility is especially important for distributed sampling from environments.
.. raw:: html
<details>
<summary>Batch.to_torch and Batch.to_numpy</summary>
::
>>> data = Batch(a=np.zeros((3, 4)))
>>> data.to_torch(dtype=torch.float32, device='cpu')
>>> print(data.a)
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> # data.to_numpy is also available
>>> data.to_numpy()
.. raw:: html
</details><br>
2. It is often the case that the observations returned from the environment are NumPy ndarrays but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors.
3. ``obj.stack_([a, b])`` is the same as ``Batch.stack([obj, a, b])``, and ``obj.cat_([a, b])`` is the same as ``Batch.cat([obj, a, b])``. Considering the frequent requirement of concatenating two ``Batch`` objects, Tianshou also supports ``obj.cat_(a)`` to be an alias of ``obj.cat_([a])``.
4. ``Batch.cat`` and ``Batch.cat_`` does not support ``axis`` argument as ``np.concatenate`` and ``torch.cat`` currently.
5. ``Batch.stack`` and ``Batch.stack_`` support the ``axis`` argument so that one can stack batches besides the first dimension. But be cautious, if there are keys that are not shared across all batches, ``stack`` with ``axis != 0`` is undefined, and will cause an exception currently.

View File

@ -14,16 +14,42 @@ Here is a more detailed description, where ``Env`` is the environment and ``Mode
:align: center
:height: 300
Data Batch
----------
Batch
-----
.. automodule:: tianshou.data.Batch
:members:
:noindex:
Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a collector gives a :class:`~tianshou.data.Batch` to policy for learning. Let's take a look at this script:
::
>>> import torch, numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
d: array(['a', '-2', '-3'], dtype=object),
)
>>> data = Batch(obs={'index': np.zeros((2, 3))}, act=torch.zeros((2, 2)))
>>> data[:, 1] += 6
>>> print(data[-1])
Batch(
obs: Batch(
index: array([0., 6., 0.]),
),
act: tensor([0., 6.]),
)
Data Buffer
-----------
In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair, and perform some common operations over it.
:ref:`batch_concept` is a dedicated tutorial for :class:`~tianshou.data.Batch`. We strongly recommend every user to read it so as to correctly understand and use :class:`~tianshou.data.Batch`.
Buffer
------
.. automodule:: tianshou.data.ReplayBuffer
:members:

View File

@ -62,6 +62,7 @@ def test_batch():
'd': Batch(e=np.array(3.0))}])
assert len(batch2) == 1
assert Batch().shape == []
assert Batch(a=1).shape == []
assert batch2.shape[0] == 1
with pytest.raises(IndexError):
batch2[-2]

View File

@ -29,6 +29,43 @@ def _is_batch_set(data: Any) -> bool:
return False
def _is_scalar(value: Any) -> bool:
# check if the value is a scalar
# 1. python bool object, number object: isinstance(value, Number)
# 2. numpy scalar: isinstance(value, np.generic)
# 3. python object rather than dict / Batch / tensor
# the check of dict / Batch is omitted because this only checks a value.
# a dict / Batch will eventually check their values
value = np.asanyarray(value)
return value.size == 1 and not value.shape
def _is_number(value: Any) -> bool:
# isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc.
# isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc.
# isinstance(value, np.bool_) checks np.bool_(True), etc.
is_number = isinstance(value, Number)
is_number = is_number or isinstance(value, np.number)
is_number = is_number or isinstance(value, np.bool_)
return is_number
def _to_array_with_correct_type(v: Any) -> np.ndarray:
# convert the value to np.ndarray
# convert to np.object data type if neither bool nor number
v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
if v.dtype == np.object and not v.shape:
# scalar ndarray with np.object data type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
# array([{}, array({}, dtype=object)], dtype=object)
v = v.item(0)
return v
def _create_value(inst: Any, size: int, stack=True) -> Union[
'Batch', np.ndarray, torch.Tensor]:
"""
@ -37,14 +74,11 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
of (10, 3, 5), otherwise (10, 5)
"""
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
is_scalar = \
isinstance(inst, Number) or \
issubclass(inst.__class__, np.generic) or \
(has_shape and not inst.shape)
is_scalar = _is_scalar(inst)
if not stack and is_scalar:
# here we do not consider scalar types, following the
# behavior of numpy which does not support concatenation
# of zero-dimensional arrays (scalars)
# here we do not consider scalar types, following the behavior of numpy
# which does not support concatenation of zero-dimensional arrays
# (scalars)
raise TypeError(f"cannot concatenate with {inst} which is scalar")
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
@ -78,223 +112,36 @@ def _assert_type_keys(keys):
f"keys should all be string, but got {keys}"
def _parse_value(v: Any):
if isinstance(v, (list, tuple, np.ndarray)):
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
v = torch.stack(v)
return v
v_ = _to_array_with_correct_type(v)
if v_.dtype == np.object and _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# normal data list (main case)
# or actually a data list with objects
v = v_
elif isinstance(v, dict):
v = Batch(v)
elif isinstance(v, (Batch, torch.Tensor)):
pass
else:
# scalar case, convert to ndarray
v = _to_array_with_correct_type(v)
return v
class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
Here is the usage:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
)
In short, you can define a :class:`Batch` with any key-value pair.
For Numpy arrays, only data types with ``np.object``, bool, and number are
supported. For strings or other data types, however, they can be held in
``np.object`` arrays.
The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
For convenience, :class:`~tianshou.data.Batch` supports the mechanism of
key reservation: one can specify a key without any value, which serves as
a placeholder for the Batch object. For example, you know there will be a
key named ``obs``, but do not know the value until the simulator runs. Then
you can reserve the key ``obs``. This is done by setting the value to
``Batch()``.
For a Batch object, we call it "incomplete" if: (i) it is ``Batch()``; (ii)
it has reserved keys; (iii) any of its sub-Batch is incomplete. Otherwise,
the Batch object is finalized.
Key reservation mechanism is convenient, but also causes some problem in
aggregation operators like ``stack`` or ``cat`` of Batch objects. We say
that Batch objects are compatible for aggregation with three cases:
1. finalized Batch objects are compatible if and only if their exists a \
way to extend keys so that their structures are exactly the same.
2. incomplete Batch objects and other finalized objects are compatible if \
their exists a way to extend keys so that incomplete Batch objects can \
have the same structure as finalized objects.
3. incomplete Batch objects themselevs are compatible if their exists a \
way to extend keys so that their structure can be the same.
In a word, incomplete Batch objects have a set of possible structures
in the future, but finalized Batch object only have a finalized structure.
Batch objects are compatible if and only if they share at least one
commonly possible structure by extending keys.
:class:`~tianshou.data.Batch` object can be initialized by a wide variety
of arguments, ranging from the key/value pairs or dictionary, to list and
Numpy arrays of :class:`dict` or Batch instances where each element is
considered as an individual sample and get stacked together:
::
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
a: Batch(
b: array([0.0, 'info'], dtype=object),
),
)
:class:`~tianshou.data.Batch` has the same API as a native Python
:class:`dict`. In this regard, one can access stored data using string key,
or iterate over stored data:
::
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
:class:`~tianshou.data.Batch` also partially reproduces the Numpy API for
arrays. It also supports the advanced slicing method, such as batch[:, i],
if the index is valid. You can access or iterate over the individual
samples, if any:
::
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> for sample in data:
>>> print(sample.a)
[0. 2.]
>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)
Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> print(np.mean(data))
Batch(
b: 0.0,
a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: array([5]),
a: array([[0., 2.]]),
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
Note that stacking of inconsistent data is also supported. In which case,
``None`` is added in list or :class:`np.ndarray` of objects, 0 otherwise.
::
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
Method ``empty_`` sets elements to 0 or ``None`` for ``np.object``.
::
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([None, 'st']),
d: array([0., 0.]),
),
)
:meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the shape and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception.
::
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
Convenience helpers are available to convert in-place the stored data into
Numpy arrays or Torch tensors.
Finally, note that :class:`~tianshou.data.Batch` is serializable and
therefore Pickle compatible. This is especially important for distributed
sampling.
For a detailed description, please refer to :ref:`batch_concept`.
"""
def __init__(self,
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
@ -307,48 +154,15 @@ class Batch:
if isinstance(batch_dict, (dict, Batch)):
_assert_type_keys(batch_dict.keys())
for k, v in batch_dict.items():
if isinstance(v, (list, tuple, np.ndarray)):
v_ = None
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.stack(v)
continue
else:
v_ = np.asanyarray(v)
if v_.dtype != np.object:
v = v_ # normal data list, this is the main case
if not issubclass(v.dtype.type,
(np.bool_, np.number)):
v = v.astype(np.object)
else:
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# this is actually a data list with objects
v = v_
self.__dict__[k] = v
elif isinstance(v, dict):
self.__dict__[k] = Batch(v)
else:
self.__dict__[k] = v
self.__dict__[k] = _parse_value(v)
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy)
def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
if isinstance(value, list):
if _is_batch_set(value):
value = Batch(value)
else:
value = np.array(value)
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
elif isinstance(value, dict) or isinstance(value, np.ndarray) \
and value.dtype == np.object and _is_batch_set(value):
value = Batch(value)
self.__dict__[key] = value
"""self.key = value"""
self.__dict__[key] = _parse_value(value)
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized for both
@ -389,20 +203,13 @@ class Batch:
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(value, (list, tuple)):
value = np.asanyarray(value)
if isinstance(value, np.ndarray):
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
if isinstance(index, str):
self.__dict__[index] = value
self.__dict__[index] = _parse_value(value)
return
if not isinstance(value, (dict, Batch)):
if _is_batch_set(value):
value = Batch(value)
else:
raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.")
value = _parse_value(value)
if isinstance(value, (np.ndarray, torch.Tensor)):
raise ValueError("Batch does not supported tensor assignment."
" Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()):
raise KeyError(
"Creating keys is not supported by item assignment.")
@ -431,7 +238,7 @@ class Batch:
else:
self.__dict__[k] += v
return self
elif isinstance(other, (Number, np.number)):
elif _is_number(other):
for k, r in self.items():
if isinstance(r, Batch) and r.is_empty():
continue
@ -448,7 +255,7 @@ class Batch:
def __imul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
assert _is_number(val), \
"Only multiplication by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
@ -462,7 +269,7 @@ class Batch:
def __itruediv__(self, val: Union[Number, np.number]):
"""Algebraic division with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
assert _is_number(val), \
"Only division by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
@ -524,14 +331,7 @@ class Batch:
device = torch.device(device)
for k, v in self.items():
if isinstance(v, (np.number, np.bool_, Number, np.ndarray)):
if isinstance(v, (np.number, np.bool_, Number)):
v = np.asanyarray(v)
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
elif isinstance(v, torch.Tensor):
if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
device.index is not None and \
@ -541,6 +341,14 @@ class Batch:
self.__dict__[k] = v.to(device)
elif isinstance(v, Batch):
v.to_torch(dtype, device)
else:
# ndarray or scalar
if not isinstance(v, np.ndarray):
v = np.asanyarray(v)
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
def __cat(self,
batches: Union['Batch', List[Union[dict, 'Batch']]],
@ -586,8 +394,7 @@ class Batch:
# cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
# will fail here
v = np.concatenate(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
v = _to_array_with_correct_type(v)
self.__dict__[k] = v
keys_total = set.union(*[set(b.keys()) for b in batches])
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
@ -691,8 +498,7 @@ class Batch:
self.__dict__[k] = torch.stack(v, axis)
else:
v = np.stack(v, axis)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
v = _to_array_with_correct_type(v)
self.__dict__[k] = v
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
@ -742,7 +548,6 @@ class Batch:
(2, 4, 5)
.. note::
If there are keys that are not shared across all batches, ``stack``
with ``axis != 0`` is undefined, and will cause an exception.
"""
@ -756,6 +561,26 @@ class Batch:
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
``None`` filled. If ``index`` is specified, it will only reset the
specific indexed-data.
::
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([None, 'st']),
d: array([0., 0.]),
),
)
"""
for k, v in self.items():
if v is None:
@ -772,7 +597,7 @@ class Batch:
else: # scalar value
warnings.warn('You are calling Batch.empty on a NumPy scalar, '
'which may cause undefined behaviors.')
if isinstance(v, (np.number, np.bool_, Number)):
if _is_number(v):
self.__dict__[k] = v.__class__(0)
else:
self.__dict__[k] = None
@ -813,6 +638,8 @@ class Batch:
else:
raise TypeError(f"Object {v} in {self} has no len()")
if len(r) == 0:
# empty batch has the shape of any, like the tensorflow '?' shape.
# So it has no length.
raise TypeError(f"Object {self} has no len()")
return min(r)
@ -827,7 +654,7 @@ class Batch:
``cat``, while the latter is a scalar and cannot be used in ``cat``.
Another usage is in ``__len__``, where we have to skip checking the
length of recursely empty Batch.
length of recursively empty Batch.
::
>>> Batch().is_empty()
@ -857,10 +684,9 @@ class Batch:
data_shape = []
for v in self.__dict__.values():
try:
data_shape.append(v.shape)
data_shape.append(list(v.shape))
except AttributeError:
raise TypeError("No support for 'shape' method with "
f"type {type(v)} in class Batch.")
data_shape.append([])
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]

View File

@ -6,9 +6,19 @@ from tianshou.data.batch import Batch, _create_value
class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
interaction between the policy and environment. It stores basically 7 types
of data, as mentioned in :class:`~tianshou.data.Batch`, based on
``numpy.ndarray``. Here is the usage:
interaction between the policy and environment. The current implementation
of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
The following code snippet illustrates its usage:
::
>>> import numpy as np
@ -16,13 +26,13 @@ class ReplayBuffer:
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})