Minor refactor for Batch class. (#61)
* Minor refactor for Batch class. * Fix. * Add back key sorting. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
be9ce44290
commit
b5093ecb56
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from typing import Any, List, Union, Iterator, Optional
|
from typing import Any, List, Union, Iterator, Optional
|
||||||
|
|
||||||
|
|
||||||
class Batch(object):
|
class Batch:
|
||||||
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
||||||
structure to pass any kind of data to other methods, for example, a
|
structure to pass any kind of data to other methods, for example, a
|
||||||
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
|
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
|
||||||
@ -80,9 +80,9 @@ class Batch(object):
|
|||||||
])
|
])
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
self._meta[k] = list(v.keys())
|
self._meta[k] = list(v.keys())
|
||||||
for k_ in v.keys():
|
for k_, v_ in v.items():
|
||||||
k__ = '_' + k + '@' + k_
|
k__ = '_' + k + '@' + k_
|
||||||
self.__dict__[k__] = v[k_]
|
self.__dict__[k__] = v_
|
||||||
else:
|
else:
|
||||||
self.__dict__[k] = kwargs[k]
|
self.__dict__[k] = kwargs[k]
|
||||||
|
|
||||||
@ -91,15 +91,15 @@ class Batch(object):
|
|||||||
if isinstance(index, str):
|
if isinstance(index, str):
|
||||||
return self.__getattr__(index)
|
return self.__getattr__(index)
|
||||||
b = Batch()
|
b = Batch()
|
||||||
for k in self.__dict__:
|
for k, v in self.__dict__.items():
|
||||||
if k != '_meta' and self.__dict__[k] is not None:
|
if k != '_meta' and v is not None:
|
||||||
b.__dict__.update(**{k: self.__dict__[k][index]})
|
b.__dict__.update(**{k: v[index]})
|
||||||
b._meta = self._meta
|
b._meta = self._meta
|
||||||
return b
|
return b
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Union['Batch', Any]:
|
def __getattr__(self, key: str) -> Union['Batch', Any]:
|
||||||
"""Return self.key"""
|
"""Return self.key"""
|
||||||
if key not in self._meta:
|
if key not in self._meta.keys():
|
||||||
if key not in self.__dict__:
|
if key not in self.__dict__:
|
||||||
raise AttributeError(key)
|
raise AttributeError(key)
|
||||||
return self.__dict__[key]
|
return self.__dict__[key]
|
||||||
@ -128,8 +128,8 @@ class Batch(object):
|
|||||||
|
|
||||||
def keys(self) -> List[str]:
|
def keys(self) -> List[str]:
|
||||||
"""Return self.keys()."""
|
"""Return self.keys()."""
|
||||||
return sorted([
|
return sorted(list(self._meta.keys()) +
|
||||||
i for i in self.__dict__ if i[0] != '_'] + list(self._meta))
|
[k for k in self.__dict__.keys() if k[0] != '_'])
|
||||||
|
|
||||||
def values(self) -> List[Any]:
|
def values(self) -> List[Any]:
|
||||||
"""Return self.values()."""
|
"""Return self.values()."""
|
||||||
@ -145,40 +145,36 @@ class Batch(object):
|
|||||||
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
|
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
|
||||||
operation.
|
operation.
|
||||||
"""
|
"""
|
||||||
for k in self.__dict__:
|
for k, v in self.__dict__.items():
|
||||||
if isinstance(self.__dict__[k], torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
self.__dict__[k] = self.__dict__[k].cpu().numpy()
|
self.__dict__[k] = v.cpu().numpy()
|
||||||
|
|
||||||
def append(self, batch: 'Batch') -> None:
|
def append(self, batch: 'Batch') -> None:
|
||||||
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
||||||
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||||
for k in batch.__dict__:
|
for k, v in batch.__dict__.items():
|
||||||
if k == '_meta':
|
if k == '_meta':
|
||||||
self._meta.update(batch._meta)
|
self._meta.update(batch._meta)
|
||||||
continue
|
continue
|
||||||
if batch.__dict__[k] is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
if not hasattr(self, k) or self.__dict__[k] is None:
|
if not hasattr(self, k) or self.__dict__[k] is None:
|
||||||
self.__dict__[k] = batch.__dict__[k]
|
self.__dict__[k] = v
|
||||||
elif isinstance(batch.__dict__[k], np.ndarray):
|
elif isinstance(v, np.ndarray):
|
||||||
self.__dict__[k] = np.concatenate([
|
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
||||||
self.__dict__[k], batch.__dict__[k]])
|
elif isinstance(v, torch.Tensor):
|
||||||
elif isinstance(batch.__dict__[k], torch.Tensor):
|
self.__dict__[k] = torch.cat([self.__dict__[k], v])
|
||||||
self.__dict__[k] = torch.cat([
|
elif isinstance(v, list):
|
||||||
self.__dict__[k], batch.__dict__[k]])
|
self.__dict__[k] += v
|
||||||
elif isinstance(batch.__dict__[k], list):
|
|
||||||
self.__dict__[k] += batch.__dict__[k]
|
|
||||||
else:
|
else:
|
||||||
s = 'No support for append with type' \
|
s = f'No support for append with type \
|
||||||
+ str(type(batch.__dict__[k])) \
|
{type(v)} in class Batch.'
|
||||||
+ 'in class Batch.'
|
|
||||||
raise TypeError(s)
|
raise TypeError(s)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return len(self)."""
|
"""Return len(self)."""
|
||||||
return min([
|
return min([len(v) for k, v in self.__dict__.items()
|
||||||
len(self.__dict__[k]) for k in self.__dict__
|
if k != '_meta' and v is not None])
|
||||||
if k != '_meta' and self.__dict__[k] is not None])
|
|
||||||
|
|
||||||
def split(self, size: Optional[int] = None,
|
def split(self, size: Optional[int] = None,
|
||||||
shuffle: bool = True) -> Iterator['Batch']:
|
shuffle: bool = True) -> Iterator['Batch']:
|
||||||
@ -193,11 +189,9 @@ class Batch(object):
|
|||||||
length = len(self)
|
length = len(self)
|
||||||
if size is None:
|
if size is None:
|
||||||
size = length
|
size = length
|
||||||
temp = 0
|
|
||||||
if shuffle:
|
if shuffle:
|
||||||
index = np.random.permutation(length)
|
indices = np.random.permutation(length)
|
||||||
else:
|
else:
|
||||||
index = np.arange(length)
|
indices = np.arange(length)
|
||||||
while temp < length:
|
for idx in np.arange(0, length, size):
|
||||||
yield self[index[temp:temp + size]]
|
yield self[indices[idx:(idx + size)]]
|
||||||
temp += size
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user