Minor refactor for Batch class. (#61)

* Minor refactor for Batch class.

* Fix.

* Add back key sorting.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Alexis DUBURCQ authored 2020-05-29 11:56:46 +02:00, committed by GitHub
parent be9ce44290
commit b5093ecb56


@@ -4,7 +4,7 @@ import numpy as np
 from typing import Any, List, Union, Iterator, Optional
 
 
-class Batch(object):
+class Batch:
     """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
     structure to pass any kind of data to other methods, for example, a
     collector gives a :class:`~tianshou.data.Batch` to policy for learning.
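Dropping the explicit `object` base is purely stylistic: in Python 3 every class is a new-style class, so the two forms behave identically. A minimal check:

    # Python 3: both classes end their MRO at object.
    class WithBase(object):
        pass

    class WithoutBase:
        pass

    assert WithBase.__mro__[-1] is object
    assert WithoutBase.__mro__[-1] is object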
@@ -80,9 +80,9 @@ class Batch(object):
                     ])
             elif isinstance(v, dict):
                 self._meta[k] = list(v.keys())
-                for k_ in v.keys():
+                for k_, v_ in v.items():
                     k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = v[k_]
+                    self.__dict__[k__] = v_
             else:
                 self.__dict__[k] = kwargs[k]
 
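For context, a minimal sketch (mine, not part of the diff) of what the dict branch above does: a dict-valued keyword argument is flattened into mangled private attributes, and its key list is recorded in `_meta`:

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs={'pos': np.zeros(3), 'vel': np.ones(3)})
    # Following the code above:
    #   b._meta['obs']         == ['pos', 'vel']
    #   b.__dict__['_obs@pos'] holds the zeros array
    #   b.__dict__['_obs@vel'] holds the ones array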
@@ -91,15 +91,15 @@ class Batch(object):
         if isinstance(index, str):
             return self.__getattr__(index)
         b = Batch()
-        for k in self.__dict__:
-            if k != '_meta' and self.__dict__[k] is not None:
-                b.__dict__.update(**{k: self.__dict__[k][index]})
+        for k, v in self.__dict__.items():
+            if k != '_meta' and v is not None:
+                b.__dict__.update(**{k: v[index]})
         b._meta = self._meta
         return b
 
     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
-        if key not in self._meta:
+        if key not in self._meta.keys():
             if key not in self.__dict__:
                 raise AttributeError(key)
             return self.__dict__[key]
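A short usage sketch of the two methods touched here (illustrative, with the same imports as the sketch above): non-string indexing builds a new Batch by indexing every non-`_meta` field, while string indexing and attribute access resolve through `__dict__`:

    b = Batch(obs=np.arange(5), rew=np.arange(5) * 2.0)
    row = b[0]        # new Batch: every non-_meta field indexed by 0
    sub = b[[1, 3]]   # fancy indexing is forwarded to each field
    rews = b.rew      # plain keys resolve straight from __dict__
    same = b['rew']   # string indices are routed through __getattr__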
@@ -128,8 +128,8 @@ class Batch(object):
     def keys(self) -> List[str]:
         """Return self.keys()."""
-        return sorted([
-            i for i in self.__dict__ if i[0] != '_'] + list(self._meta))
+        return sorted(list(self._meta.keys()) +
+                      [k for k in self.__dict__.keys() if k[0] != '_'])
 
     def values(self) -> List[Any]:
         """Return self.values()."""
@@ -145,40 +145,36 @@ class Batch(object):
         """Change all torch.Tensor to numpy.ndarray. This is an inplace
         operation.
         """
-        for k in self.__dict__:
-            if isinstance(self.__dict__[k], torch.Tensor):
-                self.__dict__[k] = self.__dict__[k].cpu().numpy()
+        for k, v in self.__dict__.items():
+            if isinstance(v, torch.Tensor):
+                self.__dict__[k] = v.cpu().numpy()
 
     def append(self, batch: 'Batch') -> None:
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
-        for k in batch.__dict__:
+        for k, v in batch.__dict__.items():
             if k == '_meta':
                 self._meta.update(batch._meta)
                 continue
-            if batch.__dict__[k] is None:
+            if v is None:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = batch.__dict__[k]
-            elif isinstance(batch.__dict__[k], np.ndarray):
-                self.__dict__[k] = np.concatenate([
-                    self.__dict__[k], batch.__dict__[k]])
-            elif isinstance(batch.__dict__[k], torch.Tensor):
-                self.__dict__[k] = torch.cat([
-                    self.__dict__[k], batch.__dict__[k]])
-            elif isinstance(batch.__dict__[k], list):
-                self.__dict__[k] += batch.__dict__[k]
+                self.__dict__[k] = v
+            elif isinstance(v, np.ndarray):
+                self.__dict__[k] = np.concatenate([self.__dict__[k], v])
+            elif isinstance(v, torch.Tensor):
+                self.__dict__[k] = torch.cat([self.__dict__[k], v])
+            elif isinstance(v, list):
+                self.__dict__[k] += v
             else:
-                s = 'No support for append with type' \
-                    + str(type(batch.__dict__[k])) \
-                    + 'in class Batch.'
+                s = f'No support for append with type \
+                    {type(v)} in class Batch.'
                 raise TypeError(s)
 
     def __len__(self) -> int:
         """Return len(self)."""
-        return min([
-            len(self.__dict__[k]) for k in self.__dict__
-            if k != '_meta' and self.__dict__[k] is not None])
+        return min([len(v) for k, v in self.__dict__.items()
+                    if k != '_meta' and v is not None])
 
     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:
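A sketch of `append` and `__len__` as rewritten above: same-named `np.ndarray` fields are concatenated along the first axis, and the length of a Batch is the minimum length over its populated fields:

    b1 = Batch(obs=np.zeros(3))
    b2 = Batch(obs=np.ones(2))
    b1.append(b2)        # obs fields joined via np.concatenate
    assert len(b1) == 5  # min over non-None, non-_meta fields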
@@ -193,11 +189,9 @@ class Batch(object):
         length = len(self)
         if size is None:
             size = length
-        temp = 0
         if shuffle:
-            index = np.random.permutation(length)
+            indices = np.random.permutation(length)
         else:
-            index = np.arange(length)
-        while temp < length:
-            yield self[index[temp:temp + size]]
-            temp += size
+            indices = np.arange(length)
+        for idx in np.arange(0, length, size):
+            yield self[indices[idx:(idx + size)]]