import torch
import pprint
import numpy as np


class Batch(object):
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
    structure to pass any kind of data to other methods; for example, a
    collector gives a :class:`~tianshou.data.Batch` to the policy for
    learning. Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=4, b=[5, 5], c='2312312')
        >>> data.b
        [5, 5]
        >>> data.b = np.array([3, 4, 5])
        >>> print(data)
        Batch(
            a: 4,
            b: array([3, 4, 5]),
            c: '2312312',
        )

    In short, you can define a :class:`Batch` with any key-value pair. The
    current implementation of Tianshou typically uses 7 reserved keys in
    :class:`~tianshou.data.Batch` (illustrated below):

    * ``obs`` the observation of step :math:`t`;
    * ``act`` the action of step :math:`t`;
    * ``rew`` the reward of step :math:`t`;
    * ``done`` the done flag of step :math:`t`;
    * ``obs_next`` the observation of step :math:`t+1`;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
      function returns 4 values, and the last one is ``info``);
    * ``policy`` the data computed by the policy in step :math:`t`;
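
    For illustration, a single transition could be stored with these reserved
    keys as follows (a minimal sketch; the values are arbitrary):
    ::

        >>> data = Batch(obs=np.array([0.0, 1.0]), act=1, rew=0.5,
        ...              done=False, obs_next=np.array([1.0, 2.0]), info={})
        >>> data.act, data.rew
        (1, 0.5)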

    :class:`~tianshou.data.Batch` has other methods, including
    :meth:`~tianshou.data.Batch.__getitem__`,
    :meth:`~tianshou.data.Batch.__len__`,
    :meth:`~tianshou.data.Batch.append`,
    and :meth:`~tianshou.data.Batch.split`:
    ::

        >>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
        >>> # here we test __getitem__
        >>> index = [2, 1]
        >>> data[index].obs
        array([22, 11])

        >>> # here we test __len__
        >>> len(data)
        3

        >>> data.append(data)  # similar to list.append
        >>> data.obs
        array([ 0, 11, 22,  0, 11, 22])

        >>> # split the whole data into multiple small batches
        >>> for d in data.split(size=2, shuffle=False):
        ...     print(d.obs, d.rew)
        [ 0 11] [6 6]
        [22  0] [6 6]
        [11 22] [6 6]
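
    If a value passed to the :class:`Batch` constructor is a dict (or another
    :class:`Batch`), its sub-keys are stored individually and exposed as a
    sub-batch through attribute access. A minimal sketch (the key names below
    are made up):
    ::

        >>> data = Batch(obs={'hidden': np.array([1, 2]),
        ...                   'mask': np.array([0, 1])})
        >>> data.obs.hidden
        array([1, 2])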
    """

    def __init__(self, **kwargs):
        super().__init__()
        self._meta = {}
        for k, v in kwargs.items():
            if isinstance(v, (list, np.ndarray)) \
                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
                # A list/array of dicts (except for the 'info' key) is
                # flattened: each sub-key is stored as '_<key>@<sub-key>' and
                # recorded in self._meta so __getattr__ can rebuild it.
                self._meta[k] = list(v[0].keys())
                for k_ in v[0].keys():
                    k__ = '_' + k + '@' + k_
                    self.__dict__[k__] = np.array([
                        v[i][k_] for i in range(len(v))
                    ])
            elif isinstance(v, (dict, Batch)):
                # A dict (or Batch) value is stored per sub-key the same way.
                self._meta[k] = list(v.keys())
                for k_ in v.keys():
                    k__ = '_' + k + '@' + k_
                    self.__dict__[k__] = v[k_]
            else:
                self.__dict__[k] = kwargs[k]

    def __getitem__(self, index):
        """Return self[index].

        A string index returns the corresponding field, so ``data['obs']`` is
        equivalent to ``data.obs``; any other index is applied to every
        stored field.
        """
        if isinstance(index, str):
            return self.__getattr__(index)
        b = Batch()
        for k in self.__dict__:
            if k != '_meta' and self.__dict__[k] is not None:
                b.__dict__.update(**{k: self.__dict__[k][index]})
        b._meta = self._meta
        return b

    def __getattr__(self, key):
        """Return self.key."""
        if key not in self._meta:
            if key not in self.__dict__:
                raise AttributeError(key)
            return self.__dict__[key]
        # Rebuild a sub-batch from the flattened '_<key>@<sub-key>' entries.
        d = {}
        for k_ in self._meta[key]:
            k__ = '_' + key + '@' + k_
            d[k_] = self.__dict__[k__]
        return Batch(**d)

    def __repr__(self):
        """Return str(self)."""
        s = self.__class__.__name__ + '(\n'
        flag = False
        for k in sorted(list(self.__dict__) + list(self._meta)):
            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
                                k in self._meta):
                # indent multi-line values so they align under their key
                rpl = '\n' + ' ' * (6 + len(k))
                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                s += f'    {k}: {obj},\n'
                flag = True
        if flag:
            s += ')'
        else:
            s = self.__class__.__name__ + '()'
        return s

    def keys(self):
        """Return self.keys()."""
        return sorted([i for i in self.__dict__ if i[0] != '_'] +
                      list(self._meta))

    def get(self, k, d=None):
        """Return self[k] if k in self else d. d defaults to None."""
        if k in self.__dict__ or k in self._meta:
            return self.__getattr__(k)
        return d

    def to_numpy(self):
        """Change all torch.Tensor to numpy.ndarray. This is an in-place
        operation.
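
        A minimal sketch of the conversion (the field name ``a`` is
        arbitrary)::

            >>> import torch
            >>> import numpy as np
            >>> from tianshou.data import Batch
            >>> b = Batch(a=torch.zeros(3))
            >>> b.to_numpy()
            >>> isinstance(b.a, np.ndarray)
            True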
        """
        for k in self.__dict__:
            if isinstance(self.__dict__[k], torch.Tensor):
                self.__dict__[k] = self.__dict__[k].cpu().numpy()

    def append(self, batch):
        """Append a :class:`~tianshou.data.Batch` object to the current
        batch.
        """
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for k in batch.__dict__:
            if k == '_meta':
                self._meta.update(batch._meta)
                continue
            if batch.__dict__[k] is None:
                continue
            if not hasattr(self, k) or self.__dict__[k] is None:
                self.__dict__[k] = batch.__dict__[k]
            elif isinstance(batch.__dict__[k], np.ndarray):
                self.__dict__[k] = np.concatenate([
                    self.__dict__[k], batch.__dict__[k]])
            elif isinstance(batch.__dict__[k], torch.Tensor):
                self.__dict__[k] = torch.cat([
                    self.__dict__[k], batch.__dict__[k]])
            elif isinstance(batch.__dict__[k], list):
                self.__dict__[k] += batch.__dict__[k]
            else:
                s = 'No support for append with type ' \
                    + str(type(batch.__dict__[k])) \
                    + ' in class Batch.'
                raise TypeError(s)

    def __len__(self):
        """Return len(self).

        The length is the minimum length over all stored fields that are
        not ``None``.
        """
        return min([
            len(self.__dict__[k]) for k in self.__dict__
            if k != '_meta' and self.__dict__[k] is not None])

    def split(self, size=None, shuffle=True):
        """Split whole data into multiple small batches.

        :param int size: if it is ``None``, it does not split the data batch;
            otherwise it will divide the data batch with the given size.
            Defaults to ``None``.
        :param bool shuffle: randomly shuffle the entire data batch if it is
            ``True``, otherwise keep the original order. Defaults to
            ``True``.
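
        A minimal sketch of an uneven split (the arrays are arbitrary)::

            >>> import numpy as np
            >>> from tianshou.data import Batch
            >>> data = Batch(obs=np.arange(4), rew=np.zeros(4))
            >>> [len(b) for b in data.split(size=3, shuffle=False)]
            [3, 1]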
        """
        length = len(self)
        if size is None:
            size = length
        temp = 0
        if shuffle:
            index = np.random.permutation(length)
        else:
            index = np.arange(length)
        while temp < length:
            yield self[index[temp:temp + size]]
            temp += size