124 lines
4.3 KiB
Python
Raw Normal View History

2020-03-14 21:48:31 +08:00
import torch
2020-03-13 17:49:22 +08:00
import numpy as np
2020-03-11 09:09:56 +08:00
class Batch(object):
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
    structure to pass any kind of data to other methods, for example, a
    collector gives a :class:`~tianshou.data.Batch` to policy for learning.
    Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=4, b=[5, 5], c='2312312')
        >>> data.b
        [5, 5]
        >>> data.b = np.array([3, 4, 5])
        >>> len(data.b)
        3
        >>> data.b[-1]
        5

    In short, you can define a :class:`Batch` with any key-value pair. The
    current implementation of Tianshou typically uses 6 keys in
    :class:`~tianshou.data.Batch`:

    * ``obs`` the observation of step :math:`t` ;
    * ``act`` the action of step :math:`t` ;
    * ``rew`` the reward of step :math:`t` ;
    * ``done`` the done flag of step :math:`t` ;
    * ``obs_next`` the observation of step :math:`t+1` ;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the
      ``env.step()`` function returns 4 arguments, and the last one is
      ``info``);

    :class:`~tianshou.data.Batch` has other methods, including
    :meth:`~tianshou.data.Batch.__getitem__`,
    :meth:`~tianshou.data.Batch.__len__`,
    :meth:`~tianshou.data.Batch.append`,
    and :meth:`~tianshou.data.Batch.split`:
    ::

        >>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
        >>> # here we test __getitem__
        >>> index = [2, 1]
        >>> data[index].obs
        array([22, 11])
        >>> # here we test __len__
        >>> len(data)
        3
        >>> data.append(data)  # similar to list.append
        >>> data.obs
        array([ 0, 11, 22,  0, 11, 22])
        >>> # split whole data into multiple small batch
        >>> for d in data.split(size=2, permute=False):
        ...     print(d.obs, d.rew)
        [ 0 11] [6 6]
        [22  0] [6 6]
        [11 22] [6 6]
    """

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute of the batch."""
        super().__init__()
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        """Return self[index].

        Every stored value that is not ``None`` is indexed with ``index``
        and collected into a new :class:`Batch`; keys whose value is
        ``None`` are left out of the result.
        """
        selected = Batch()
        for key, value in self.__dict__.items():
            if value is not None:
                selected.__dict__[key] = value[index]
        return selected

    def append(self, batch):
        """Append a :class:`~tianshou.data.Batch` object to current batch.

        Keys whose value in ``batch`` is ``None`` are skipped. A key that is
        missing (or ``None``) in the current batch takes the incoming value
        as is; otherwise the two values are concatenated according to the
        type of the incoming value (``np.ndarray``, ``torch.Tensor`` or
        ``list``).

        :param Batch batch: the batch whose data is appended.

        :raises TypeError: if a value to be concatenated is neither a
            ``np.ndarray``, a ``torch.Tensor`` nor a ``list``.
        """
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for key, incoming in batch.__dict__.items():
            if incoming is None:
                continue
            # Look only at instance data: ``hasattr`` would also match
            # method names such as ``split`` and then fail on the lookup.
            current = self.__dict__.get(key)
            if current is None:
                self.__dict__[key] = incoming
            elif isinstance(incoming, np.ndarray):
                self.__dict__[key] = np.concatenate([current, incoming])
            elif isinstance(incoming, torch.Tensor):
                self.__dict__[key] = torch.cat([current, incoming])
            elif isinstance(incoming, list):
                # Build a new list instead of ``+=``: the stored list may be
                # the very object owned by another batch (see the first
                # branch), and extending it in place would corrupt that one.
                self.__dict__[key] = current + incoming
            else:
                raise TypeError(
                    'No support for append with type {} in class Batch.'
                    .format(type(incoming)))

    def __len__(self):
        """Return len(self).

        The length is the smallest ``len()`` among all stored values that
        are not ``None``; a batch without any such value has length 0.
        Values that do not support ``len()`` make this raise ``TypeError``.
        """
        return min(
            (len(value) for value in self.__dict__.values()
             if value is not None),
            default=0)

    def split(self, size=None, permute=True):
        """Split whole data into multiple small batch.

        :param int size: if it is ``None``, it does not split the data batch;
            otherwise it will divide the data batch with the given size.
            Default to ``None``.
        :param bool permute: randomly shuffle the entire data batch if it is
            ``True``, otherwise remain in the same. Default to ``True``.

        :raises ValueError: if ``size`` is given but smaller than 1 (the
            error surfaces when the generator is first advanced).

        The stored values are indexed with a NumPy integer array, so they
        must support that kind of indexing.
        """
        if size is not None and size < 1:
            # A non-positive step would never advance through the data.
            raise ValueError(
                'size should be a positive integer or None, '
                'but got {}.'.format(size))
        length = len(self)
        if length == 0:
            return
        if size is None:
            size = length
        if permute:
            index = np.random.permutation(length)
        else:
            index = np.arange(length)
        for start in range(0, length, size):
            yield self[index[start:start + size]]