433 lines
16 KiB
Python
Raw Normal View History

2020-03-14 21:48:31 +08:00
import torch
import copy
2020-04-28 20:56:02 +08:00
import pprint
2020-06-20 22:23:12 +08:00
import warnings
2020-03-13 17:49:22 +08:00
import numpy as np
from numbers import Number
2020-06-23 17:37:26 +02:00
from typing import Any, List, Tuple, Union, Iterator, Optional
2020-03-13 17:49:22 +08:00
# Silence torch's "pickle support for Storage will be removed in 1.5."
# warning at import time. The warning itself has already been removed on the
# torch master branch; see Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
    "ignore", message="pickle support for Storage will be removed in 1.5.")
2020-03-13 17:49:22 +08:00
def _is_batch_set(data: Any) -> bool:
if isinstance(data, (list, tuple)):
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
return True
elif isinstance(data, np.ndarray):
if isinstance(data.item(0), (dict, Batch)):
return True
return False
def _valid_bounds(length: int, index: Union[
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
if isinstance(index, (int, np.integer)):
return -length <= index and index < length
elif isinstance(index, (list, np.ndarray)):
return _valid_bounds(length, np.min(index)) and \
_valid_bounds(length, np.max(index))
elif isinstance(index, slice):
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
class Batch:
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
    structure to pass any kind of data to other methods, for example, a
    collector gives a :class:`~tianshou.data.Batch` to policy for learning.
    Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=4, b=[5, 5], c='2312312')
        >>> data.b
        [5, 5]
        >>> data.b = np.array([3, 4, 5])
        >>> print(data)
        Batch(
            a: 4,
            b: array([3, 4, 5]),
            c: '2312312',
        )

    In short, you can define a :class:`Batch` with any key-value pair. The
    current implementation of Tianshou typically uses 7 reserved keys in
    :class:`~tianshou.data.Batch`:

    * ``obs`` the observation of step :math:`t` ;
    * ``act`` the action of step :math:`t` ;
    * ``rew`` the reward of step :math:`t` ;
    * ``done`` the done flag of step :math:`t` ;
    * ``obs_next`` the observation of step :math:`t+1` ;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the
      ``env.step()`` function returns 4 arguments, and the last one is
      ``info``);
    * ``policy`` the data computed by policy in step :math:`t`;

    :class:`~tianshou.data.Batch` has other methods, including
    :meth:`~tianshou.data.Batch.__getitem__`,
    :meth:`~tianshou.data.Batch.__len__`,
    :meth:`~tianshou.data.Batch.append`,
    and :meth:`~tianshou.data.Batch.split`:
    ::

        >>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
        >>> # here we test __getitem__
        >>> index = [2, 1]
        >>> data[index].obs
        array([22, 11])
        >>> # here we test __len__
        >>> len(data)
        3
        >>> data.append(data)  # similar to list.append
        >>> data.obs
        array([0, 11, 22, 0, 11, 22])
        >>> # split whole data into multiple small batch
        >>> for d in data.split(size=2, shuffle=False):
        ...     print(d.obs, d.rew)
        [ 0 11] [6 6]
        [22  0] [6 6]
        [11 22] [6 6]
    """

    def __init__(self,
                 batch_dict: Optional[Union[
                     dict, 'Batch', Tuple[Union[dict, 'Batch']],
                     List[Union[dict, 'Batch']], np.ndarray]] = None,
                 **kwargs) -> None:
        """Build a batch from a mapping, a collection of mappings, or kwargs.

        :param batch_dict: one of

            * ``None``: start empty;
            * a ``dict`` / ``Batch``: its items are stored as they are (not
              copied), except that ``dict`` values and batch sets are
              wrapped into nested ``Batch`` objects;
            * a batch set (non-empty list / tuple / ``np.ndarray`` of
              ``dict`` / ``Batch``): the elements are stacked field by
              field along a new first axis, using the keys of the first
              element.
        :param kwargs: extra key-value pairs, handled like a ``dict`` and
            applied after ``batch_dict``.
        :raises KeyError: if an element of a batch set lacks a key of the
            first element.
        """
        if isinstance(batch_dict, (dict, Batch)):
            for k, v in batch_dict.items():
                if isinstance(v, dict) or _is_batch_set(v):
                    self.__dict__[k] = Batch(v)
                else:
                    self.__dict__[k] = v
        elif batch_dict is not None and _is_batch_set(batch_dict):
            # Gather every field by key rather than by position so that
            # elements whose keys come in a different order still line up.
            rows = [dict(e.items()) for e in batch_dict]
            for k in rows[0]:
                v = tuple(row[k] for row in rows)
                if isinstance(v[0], dict) or _is_batch_set(v[0]):
                    self.__dict__[k] = Batch(v)
                elif isinstance(v[0], (np.generic, np.ndarray)):
                    self.__dict__[k] = np.stack(v, axis=0)
                elif isinstance(v[0], torch.Tensor):
                    self.__dict__[k] = torch.stack(v, dim=0)
                elif isinstance(v[0], Batch):
                    self.__dict__[k] = Batch.stack(v)
                else:
                    self.__dict__[k] = np.array(v)
        if kwargs:
            self.__init__(kwargs)

    def __getstate__(self):
        """Pickling interface. Only the actual data are serialized
        for both efficiency and simplicity.

        :return: a plain ``dict``; nested batches are converted to nested
            dicts recursively.
        """
        state = {}
        for k, v in self.items():
            state[k] = v.__getstate__() if isinstance(v, Batch) else v
        return state

    def __setstate__(self, state):
        """Unpickling interface. At this point, self is an empty Batch
        instance that has not been initialized, so it can safely be
        initialized by the pickle state.
        """
        # Pass the state positionally: expanding it with ``**`` would clash
        # with the parameter names of __init__ for keys such as
        # ``batch_dict`` or ``self``.
        self.__init__(state)

    def __getitem__(self, index: Union[
            str, slice, int, np.integer, np.ndarray, List[int]]
    ) -> Union['Batch', Any]:
        """Return self[index].

        A ``str`` index returns the value stored under that key. Any other
        index is applied to every stored value and the results are returned
        in a new ``Batch``; empty nested batches stay empty.

        :raises KeyError: if ``index`` is a ``str`` that is not a key.
        :raises IndexError: if ``index`` is out of bounds for ``len(self)``.
        """
        if isinstance(index, str):
            return self.__dict__[index]
        if not _valid_bounds(len(self), index):
            raise IndexError(
                f"Index {index} out of bounds for Batch of len {len(self)}.")
        b = Batch()
        for k, v in self.items():
            if isinstance(v, Batch) and len(v.__dict__) == 0:
                b.__dict__[k] = Batch()
            else:
                b.__dict__[k] = v[index]
        return b

    def __setitem__(self, index: Union[
            str, slice, int, np.integer, np.ndarray, List[int]],
            value: Any) -> None:
        """Assign value to self[index].

        A ``str`` index stores ``value`` under that key. Otherwise ``value``
        must be a ``dict`` / ``Batch`` whose keys already exist in this
        batch; each stored field is updated at ``index``. Fields missing
        from ``value`` are reset at ``index``: nested batches recursively,
        integer arrays to ``0``, everything else to ``None``.

        :raises TypeError: if ``value`` is neither a ``dict`` nor a
            ``Batch`` (for a non-``str`` index).
        :raises KeyError: if ``value`` has a key this batch does not have.
        """
        if isinstance(index, str):
            self.__dict__[index] = value
            return
        if not isinstance(value, (dict, Batch)):
            raise TypeError("Batch does not supported value type "
                            f"{type(value)} for item assignment.")
        incoming = dict(value.items())
        if not set(incoming).issubset(self.__dict__.keys()):
            raise KeyError(
                "Creating keys is not supported by item assignment.")
        for key, val in self.items():
            if key in incoming:
                # Tested explicitly rather than with ``except KeyError`` so
                # that a KeyError raised by a nested assignment propagates.
                val[index] = incoming[key]
            elif isinstance(val, Batch):
                val[index] = Batch()
            elif isinstance(val, np.ndarray) and \
                    np.issubdtype(val.dtype, np.integer):
                # Fallback for np.array of integer,
                # since neither None nor nan is supported.
                val[index] = 0
            else:
                val[index] = None

    def __iadd__(self, val: Union['Batch', Number]):
        """Add ``val`` to this batch in place and return ``self``.

        With a ``Batch``, fields are added key by key; keys that ``val``
        does not have are left untouched. With a number, it is added to
        every field. ``None`` fields are skipped and lists are added
        element-wise.

        :raises TypeError: if ``val`` is neither a ``Batch`` nor a number.
        """
        if isinstance(val, Batch):
            other = val.__dict__
            for k, r in self.items():
                # Pair by key, not by position: the two batches may store
                # the same keys in a different order.
                if r is None or k not in other:
                    continue
                v = other[k]
                if isinstance(r, list):
                    self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
                else:
                    self.__dict__[k] += v
            return self
        elif isinstance(val, Number):
            for k, r in self.items():
                if r is None:
                    continue
                elif isinstance(r, list):
                    self.__dict__[k] = [r_ + val for r_ in r]
                else:
                    self.__dict__[k] += val
            return self
        else:
            raise TypeError("Only addition of Batch or number is supported.")

    def __add__(self, val: Union['Batch', Number]):
        """Return a deep copy of this batch with ``val`` added."""
        return copy.deepcopy(self).__iadd__(val)

    def __imul__(self, val: Number):
        """Multiply every field by the number ``val`` in place.

        ``None`` fields are skipped and lists are scaled element-wise, in
        line with :meth:`__iadd__`.
        """
        assert isinstance(val, Number), \
            "Only multiplication by a number is supported."
        for k, r in self.items():
            if r is None:
                continue
            elif isinstance(r, list):
                self.__dict__[k] = [r_ * val for r_ in r]
            else:
                self.__dict__[k] *= val
        return self

    def __mul__(self, val: Number):
        """Return a deep copy of this batch multiplied by ``val``."""
        return copy.deepcopy(self).__imul__(val)

    def __itruediv__(self, val: Number):
        """Divide every field by the number ``val`` in place.

        ``None`` fields are skipped and lists are divided element-wise, in
        line with :meth:`__iadd__`.
        """
        assert isinstance(val, Number), \
            "Only division by a number is supported."
        for k, r in self.items():
            if r is None:
                continue
            elif isinstance(r, list):
                self.__dict__[k] = [r_ / val for r_ in r]
            else:
                self.__dict__[k] /= val
        return self

    def __truediv__(self, val: Number):
        """Return a deep copy of this batch divided by ``val``."""
        return copy.deepcopy(self).__itruediv__(val)

    def __repr__(self) -> str:
        """Return str(self)."""
        name = self.__class__.__name__
        if not self.__dict__:
            return name + '()'
        lines = [name + '(']
        for k, v in self.items():
            # Continuation lines of a multi-line value are aligned under its
            # first character: 4 (indent) + len(k) + 2 (": ").
            rpl = '\n' + ' ' * (6 + len(k))
            obj = pprint.pformat(v).replace('\n', rpl)
            lines.append(f'    {k}: {obj},')
        lines.append(')')
        return '\n'.join(lines)

    def keys(self) -> List[str]:
        """Return self.keys() (a live view of the stored keys)."""
        return self.__dict__.keys()

    def values(self) -> List[Any]:
        """Return self.values() (a live view of the stored values)."""
        return self.__dict__.values()

    def items(self) -> List[Tuple[str, Any]]:
        """Return self.items() (a live view of the stored pairs)."""
        return self.__dict__.items()

    def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
        """Return self[k] if k in self else d. d defaults to None."""
        return self.__dict__.get(k, d)

    def to_numpy(self) -> None:
        """Change all torch.Tensor to numpy.ndarray. This is an in-place
        operation.
        """
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self.__dict__[k] = v.detach().cpu().numpy()
            elif isinstance(v, Batch):
                v.to_numpy()

    def to_torch(self,
                 dtype: Optional[torch.dtype] = None,
                 device: Union[str, int, torch.device] = 'cpu'
                 ) -> None:
        """Change all numpy.ndarray to torch.Tensor. This is an in-place
        operation.

        Existing tensors are cast / moved only when their dtype or device
        differs from the requested one; nested batches are converted
        recursively; other values are left alone.

        :param dtype: target dtype; ``None`` keeps each value's own dtype.
        :param device: target device, ``'cpu'`` by default.
        """
        if not isinstance(device, torch.device):
            device = torch.device(device)
        for k, v in self.items():
            if isinstance(v, Batch):
                v.to_torch(dtype, device)
                continue
            if isinstance(v, (np.generic, np.ndarray)):
                # torch.from_numpy only accepts ndarray, so numpy scalars
                # (np.generic) are wrapped into a 0-d array first.
                v = torch.from_numpy(np.asanyarray(v))
            elif not isinstance(v, torch.Tensor):
                continue
            needs_cast = dtype is not None and v.dtype != dtype
            needs_move = v.device.type != device.type or (
                device.index is not None and
                device.index != v.device.index)
            if needs_cast or needs_move:
                if dtype is not None:
                    v = v.type(dtype)
                v = v.to(device)
            self.__dict__[k] = v

    def append(self, batch: 'Batch') -> None:
        """Deprecated alias of :meth:`~tianshou.data.Batch.cat_`."""
        warnings.warn('Method append will be removed soon, please use '
                      ':meth:`~tianshou.data.Batch.cat`')
        return self.cat_(batch)

    def cat_(self, batch: 'Batch') -> None:
        """Concatenate a :class:`~tianshou.data.Batch` object to current
        batch.

        ``None`` values in ``batch`` are ignored. A key this batch does not
        hold yet (or holds as ``None``) receives a deep copy of the incoming
        value; otherwise arrays, tensors and lists are concatenated and
        nested batches are concatenated recursively.

        :raises TypeError: if a value has an unsupported type.
        """
        assert isinstance(batch, Batch), \
            'Only Batch is allowed to be concatenated in-place!'
        for k, v in batch.items():
            if v is None:
                continue
            # Look the key up in __dict__ rather than with hasattr(): a key
            # named like a method or property (e.g. "cat", "size") would
            # otherwise be mistaken for existing data.
            current = self.__dict__.get(k)
            if current is None:
                self.__dict__[k] = copy.deepcopy(v)
            elif isinstance(v, np.ndarray) and v.ndim > 0:
                self.__dict__[k] = np.concatenate([current, v])
            elif isinstance(v, torch.Tensor):
                self.__dict__[k] = torch.cat([current, v])
            elif isinstance(v, list):
                self.__dict__[k] += copy.deepcopy(v)
            elif isinstance(v, Batch):
                current.cat_(v)
            else:
                s = 'No support for method "cat" with type '\
                    f'{type(v)} in class Batch.'
                raise TypeError(s)

    @classmethod
    def cat(cls, batches: List['Batch']) -> 'Batch':
        """Concatenate a list of :class:`~tianshou.data.Batch` objects into
        a single new batch.
        """
        assert isinstance(batches, (tuple, list)), \
            'Only list of Batch instances is allowed to be '\
            'concatenated out-of-place!'
        batch = cls()
        for batch_ in batches:
            batch.cat_(batch_)
        return batch

    @classmethod
    def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch':
        """Stack a list of :class:`~tianshou.data.Batch` objects into a
        single new batch.

        :param batches: the batches to stack; an empty collection gives an
            empty batch.
        :param int axis: position of the new axis in every field.
        :raises TypeError: if ``axis != 0`` and a field is neither a numpy
            value, a list, a tensor nor a nested batch.
        :raises KeyError: if a batch lacks a key of the first batch.
        """
        assert isinstance(batches, (tuple, list)), \
            'Only list of Batch instances is allowed to be '\
            'stacked out-of-place!'
        if axis == 0 or len(batches) == 0:
            return cls(batches)
        batch = cls()
        # Gather fields by key so differing key orders still line up.
        rows = [dict(e.items()) for e in batches]
        for k in rows[0]:
            v = tuple(row[k] for row in rows)
            if isinstance(v[0], (np.generic, np.ndarray, list)):
                batch.__dict__[k] = np.stack(v, axis)
            elif isinstance(v[0], torch.Tensor):
                batch.__dict__[k] = torch.stack(v, axis)
            elif isinstance(v[0], Batch):
                batch.__dict__[k] = cls.stack(v, axis)
            else:
                s = 'No support for method "stack" with type '\
                    f'{type(v[0])} in class Batch and axis != 0.'
                raise TypeError(s)
        return batch

    def __len__(self) -> int:
        """Return len(self).

        The length is the smallest length among the stored values; empty
        nested batches are ignored.

        :raises TypeError: if a value has no length (scalars, 0-d arrays)
            or if there is no value to measure.
        """
        r = []
        for v in self.__dict__.values():
            if isinstance(v, Batch) and len(v.__dict__) == 0:
                continue
            elif hasattr(v, '__len__') and (not isinstance(
                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                r.append(len(v))
            else:
                raise TypeError("Object of type 'Batch' has no len()")
        if len(r) == 0:
            raise TypeError("Object of type 'Batch' has no len()")
        return min(r)

    @property
    def size(self) -> int:
        """Return self.size.

        Unlike ``len()``, this never raises: an empty batch has size 0,
        values without a length are ignored, and the result is otherwise
        at least 1.
        """
        if len(self.__dict__) == 0:
            return 0
        r = []
        for v in self.__dict__.values():
            if isinstance(v, Batch):
                r.append(v.size)
            elif hasattr(v, '__len__') and (not isinstance(
                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                r.append(len(v))
        return max(1, min(r) if len(r) > 0 else 0)

    def split(self, size: Optional[int] = None,
              shuffle: bool = True) -> Iterator['Batch']:
        """Split whole data into multiple small batch.

        :param int size: if it is ``None``, it does not split the data batch;
            otherwise it will divide the data batch with the given size.
            Default to ``None``.
        :param bool shuffle: randomly shuffle the entire data batch if it is
            ``True``, otherwise remain in the same. Default to ``True``.
        """
        length = len(self)
        if size is None:
            size = length
        if shuffle:
            indices = np.random.permutation(length)
        else:
            indices = np.arange(length)
        for idx in np.arange(0, length, size):
            yield self[indices[idx:(idx + size)]]