import torch
import pprint
import numpy as np
from typing import Any, List, Union, Iterator, Optional


class Batch:
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
    structure to pass any kind of data to other methods, for example, a
    collector gives a :class:`~tianshou.data.Batch` to the policy for
    learning. Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=4, b=[5, 5], c='2312312')
        >>> data.b
        [5, 5]
        >>> data.b = np.array([3, 4, 5])
        >>> print(data)
        Batch(
            a: 4,
            b: array([3, 4, 5]),
            c: '2312312',
        )

    In short, you can define a :class:`Batch` with any key-value pair. The
    current implementation of Tianshou typically uses 7 reserved keys in
    :class:`~tianshou.data.Batch`:

    * ``obs`` the observation of step :math:`t`;
    * ``act`` the action of step :math:`t`;
    * ``rew`` the reward of step :math:`t`;
    * ``done`` the done flag of step :math:`t`;
    * ``obs_next`` the observation of step :math:`t+1`;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
        function returns 4 values, and the last one is ``info``);
    * ``policy`` the data computed by the policy at step :math:`t`.
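
    For example, a single transition collected from an environment could be
    stored under these keys (the concrete values below are made up purely
    for illustration):
    ::

        >>> data = Batch(obs=np.array([0.0, 1.0]), act=0, rew=1.0, done=False,
        ...              obs_next=np.array([1.0, 2.0]), info={})
        >>> data.rew
        1.0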

    :class:`~tianshou.data.Batch` has other methods, including
    :meth:`~tianshou.data.Batch.__getitem__`,
    :meth:`~tianshou.data.Batch.__len__`,
    :meth:`~tianshou.data.Batch.append`,
    and :meth:`~tianshou.data.Batch.split`:
    ::

        >>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
        >>> # here we test __getitem__
        >>> index = [2, 1]
        >>> data[index].obs
        array([22, 11])
        >>> # here we test __len__
        >>> len(data)
        3
        >>> data.append(data)  # similar to list.append
        >>> data.obs
        array([ 0, 11, 22,  0, 11, 22])
        >>> # split the whole data into multiple small batches
        >>> for d in data.split(size=2, shuffle=False):
        ...     print(d.obs, d.rew)
        [ 0 11] [6 6]
        [22  0] [6 6]
        [11 22] [6 6]
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self._meta = {}
        for k, v in kwargs.items():
            if isinstance(v, (list, np.ndarray)) \
                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
                # a list/array of dicts (except 'info') is flattened: each
                # inner key k_ is stored under '_' + k + '@' + k_, and the
                # original keys are remembered in self._meta[k]
                self._meta[k] = list(v[0].keys())
                for k_ in v[0].keys():
                    k__ = '_' + k + '@' + k_
                    self.__dict__[k__] = np.array([
                        v[i][k_] for i in range(len(v))
                    ])
            elif isinstance(v, dict):
                # a single dict value is flattened in the same way
                self._meta[k] = list(v.keys())
                for k_, v_ in v.items():
                    k__ = '_' + k + '@' + k_
                    self.__dict__[k__] = v_
            else:
                self.__dict__[k] = kwargs[k]

    def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
        """Return self[index]."""
        if isinstance(index, str):
            return self.__getattr__(index)
        b = Batch()
        for k, v in self.__dict__.items():
            if k != '_meta' and v is not None:
                b.__dict__.update(**{k: v[index]})
        b._meta = self._meta
        return b

    def __getattr__(self, key: str) -> Union['Batch', Any]:
        """Return self.key."""
        if key not in self._meta.keys():
            if key not in self.__dict__:
                raise AttributeError(key)
            return self.__dict__[key]
        # the key was flattened in __init__: rebuild a Batch from the
        # stored '_key@subkey' entries
        d = {}
        for k_ in self._meta[key]:
            k__ = '_' + key + '@' + k_
            d[k_] = self.__dict__[k__]
        return Batch(**d)

    def __repr__(self) -> str:
        """Return str(self)."""
        s = self.__class__.__name__ + '(\n'
        flag = False
        for k in sorted(list(self.__dict__) + list(self._meta)):
            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
                                k in self._meta):
                rpl = '\n' + ' ' * (6 + len(k))
                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                s += f'    {k}: {obj},\n'
                flag = True
        if flag:
            s += ')'
        else:
            s = self.__class__.__name__ + '()'
        return s

    def keys(self) -> List[str]:
        """Return self.keys()."""
        return sorted(list(self._meta.keys()) +
                      [k for k in self.__dict__.keys() if k[0] != '_'])

    def values(self) -> List[Any]:
        """Return self.values()."""
        return [self[k] for k in self.keys()]

    def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
        """Return self[k] if k in self else d. d defaults to None."""
        if k in self.__dict__ or k in self._meta:
            return self.__getattr__(k)
        return d

    def to_numpy(self) -> None:
        """Change all torch.Tensor to numpy.ndarray. This is an in-place
        operation.
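
        A minimal sketch (the tensor below is chosen only for illustration):
        ::

            >>> b = Batch(act=torch.zeros(3))
            >>> b.to_numpy()
            >>> isinstance(b.act, np.ndarray)
            True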
"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.cpu().numpy()
elif isinstance(v, Batch):
v.to_numpy()

    def to_torch(self,
                 dtype: Optional[torch.dtype] = None,
                 device: Union[str, int] = 'cpu'
                 ) -> None:
        """Change all numpy.ndarray to torch.Tensor. This is an in-place
        operation.
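
        A minimal sketch of the conversion (the array contents and dtype are
        chosen only for illustration):
        ::

            >>> b = Batch(obs=np.arange(3))
            >>> b.to_torch(dtype=torch.float32)
            >>> b.obs.dtype
            torch.float32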
"""
for k, v in self.__dict__.items():
if isinstance(v, np.ndarray):
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
elif isinstance(v, Batch):
v.to_torch()

    def append(self, batch: 'Batch') -> None:
        """Append a :class:`~tianshou.data.Batch` object to the current
        batch.
        """
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for k, v in batch.__dict__.items():
            if k == '_meta':
                self._meta.update(batch._meta)
                continue
            if v is None:
                continue
            if not hasattr(self, k) or self.__dict__[k] is None:
                self.__dict__[k] = v
            elif isinstance(v, np.ndarray):
                self.__dict__[k] = np.concatenate([self.__dict__[k], v])
            elif isinstance(v, torch.Tensor):
                self.__dict__[k] = torch.cat([self.__dict__[k], v])
            elif isinstance(v, list):
                self.__dict__[k] += v
            else:
                s = 'No support for append with type ' \
                    f'{type(v)} in class Batch.'
                raise TypeError(s)

    def __len__(self) -> int:
        """Return len(self)."""
        return min([len(v) for k, v in self.__dict__.items()
                    if k != '_meta' and v is not None])

    def split(self, size: Optional[int] = None,
              shuffle: bool = True) -> Iterator['Batch']:
        """Split whole data into multiple small batches.

        :param int size: if it is ``None``, it does not split the data batch;
            otherwise it will divide the data batch with the given size.
            Defaults to ``None``.
        :param bool shuffle: randomly shuffle the entire data batch if it is
            ``True``, otherwise keep the original order. Defaults to
            ``True``.
        """
        length = len(self)
        if size is None:
            size = length
        if shuffle:
            indices = np.random.permutation(length)
        else:
            indices = np.arange(length)
        for idx in np.arange(0, length, size):
            yield self[indices[idx:(idx + size)]]
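

# The block below is only an illustrative usage sketch of the class above;
# all values are arbitrary and chosen purely for demonstration.
if __name__ == '__main__':
    batch = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]),
                  policy={'logits': np.zeros((3, 2))})
    # nested dicts are flattened in __init__ and rebuilt on attribute access
    print(batch.policy.logits.shape)    # (3, 2)
    print(batch.keys())                 # ['obs', 'policy', 'rew']
    print(batch.get('act', 'missing'))  # missing
    # in-place conversion between numpy and torch
    batch.to_torch(dtype=torch.float32)
    print(type(batch.rew))              # <class 'torch.Tensor'>
    batch.to_numpy()
    # append the batch to itself, then iterate over mini-batches
    batch.append(batch)
    for mini in batch.split(size=4, shuffle=False):
        print(len(mini))                # prints 4, then 2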