Fix Batch to numpy compatibility (#92)

* Fix Batch to numpy compatibility.

* Fix Batch unit tests.

* Fix linter

* Add Batch shape method.

* Remove shape and add size. Allow reserving keys with an empty batch/list.

* Fix linter and unit tests.

* Batch init using list of Batch.

* Add unit tests.

* Fix Batch __len__.

* Fix unit tests.

* Fix slicing

* Add missing slicing unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Authored by Alexis DUBURCQ on 2020-06-24 15:43:48 +02:00, committed by GitHub.
commit 49f43e9f1f (parent ebc551a25e)
3 changed files with 144 additions and 32 deletions
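
For orientation, a small usage sketch of the Batch behaviors this commit introduces (values and variable names are illustrative; the semantics are the ones exercised by the updated unit tests below):

    import numpy as np
    from tianshou.data import Batch

    batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), c=np.array([5.0, 6.0]))
    scaled = (batch + 1.0) * 2      # element-wise arithmetic with a number
    head = batch[:1]                # slicing keeps the Batch structure
    rebuilt = Batch(list(batch))    # re-init from a list of Batch objects
    # batch[5]                      # out-of-bounds indices now raise IndexError
    assert Batch().size == 0        # the new size property on an empty Batch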


@@ -39,13 +39,32 @@ def test_batch():
                            'c': np.zeros(1),
                            'd': Batch(e=np.array(3.0))}])
     assert len(batch2) == 1
-    assert list(batch2[1].keys()) == ['a']
-    assert len(batch2[-2].a.d.keys()) == 0
-    assert len(batch2[-1].keys()) > 0
-    assert batch2[0][0].a.c == 0.0
+    with pytest.raises(IndexError):
+        batch2[-2]
+    with pytest.raises(IndexError):
+        batch2[1]
+    with pytest.raises(TypeError):
+        batch2[0][0]
     assert isinstance(batch2[0].a.c, np.ndarray)
     assert isinstance(batch2[0].a.b, np.float64)
     assert isinstance(batch2[0].a.d.e, np.float64)
+    batch2_from_list = Batch(list(batch2))
+    batch2_from_comp = Batch([e for e in batch2])
+    assert batch2_from_list.a.b == batch2.a.b
+    assert batch2_from_list.a.c == batch2.a.c
+    assert batch2_from_list.a.d.e == batch2.a.d.e
+    assert batch2_from_comp.a.b == batch2.a.b
+    assert batch2_from_comp.a.c == batch2.a.c
+    assert batch2_from_comp.a.d.e == batch2.a.d.e
+    for batch_slice in [
+            batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
+        assert batch_slice.a.b == batch2.a.b
+        assert batch_slice.a.c == batch2.a.c
+        assert batch_slice.a.d.e == batch2.a.d.e
+    batch2_sum = (batch2 + 1.0) * 2
+    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
+    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
+    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2


 def test_batch_over_batch():
@@ -146,6 +165,19 @@ def test_batch_from_to_numpy_without_copy():
     assert c_mem_addr_new == c_mem_addr_orig


+def test_batch_numpy_compatibility():
+    batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
+                  b=Batch(),
+                  c=np.array([5.0, 6.0]))
+    batch_mean = np.mean(batch)
+    assert isinstance(batch_mean, Batch)
+    assert sorted(batch_mean.keys()) == ['a', 'b', 'c']
+    with pytest.raises(TypeError):
+        len(batch_mean)
+    assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
+    assert batch_mean.c == np.mean(batch.c, axis=0)
+
+
 if __name__ == '__main__':
     test_batch()
     test_batch_over_batch()


@@ -3,6 +3,7 @@ import copy
 import pprint
 import warnings
 import numpy as np
+from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional

 # Disable pickle warning related to torch, since it has been removed
@@ -75,15 +76,16 @@ class Batch:
     """

     def __init__(self,
-                 batch_dict: Optional[
-                     Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
+                 batch_dict: Optional[Union[
+                     dict, 'Batch', Tuple[Union[dict, 'Batch']],
+                     List[Union[dict, 'Batch']], np.ndarray]] = None,
                  **kwargs) -> None:
         def _is_batch_set(data: Any) -> bool:
             if isinstance(data, (list, tuple)):
-                if len(data) > 0 and isinstance(data[0], dict):
+                if len(data) > 0 and isinstance(data[0], (dict, Batch)):
                     return True
             elif isinstance(data, np.ndarray):
-                if isinstance(data.item(0), dict):
+                if isinstance(data.item(0), (dict, Batch)):
                     return True
             return False
@@ -102,7 +104,7 @@ class Batch:
                     self.__dict__[k] = Batch.stack(v)
                 else:
                     self.__dict__[k] = list(v)
-        elif isinstance(batch_dict, dict):
+        elif isinstance(batch_dict, (dict, Batch)):
             for k, v in batch_dict.items():
                 if isinstance(v, dict) or _is_batch_set(v):
                     self.__dict__[k] = Batch(v)
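
With the widened `_is_batch_set` check and the `(dict, Batch)` branch above, a Batch can now be constructed from a sequence of Batch objects, or from another Batch, not only from dicts. A minimal sketch (key names are illustrative):

    import numpy as np
    from tianshou.data import Batch

    b_from_dicts = Batch([{'a': np.zeros(2)}, {'a': np.ones(2)}])        # as before
    b_from_batches = Batch([Batch(a=np.zeros(2)), Batch(a=np.ones(2))])  # now supported
    b_copy = Batch(b_from_dicts)    # an existing Batch (or dict) also works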
@@ -141,22 +143,82 @@ class Batch:
                 return _valid_bounds(length, min(index)) and \
                     _valid_bounds(length, max(index))
             elif isinstance(index, slice):
-                return _valid_bounds(length, index.start) and \
-                    _valid_bounds(length, index.stop - 1)
+                if index.start is not None:
+                    start_valid = _valid_bounds(length, index.start)
+                else:
+                    start_valid = True
+                if index.stop is not None:
+                    stop_valid = _valid_bounds(length, index.stop - 1)
+                else:
+                    stop_valid = True
+                return start_valid and stop_valid

         if isinstance(index, str):
             return self.__getattr__(index)
+        if not _valid_bounds(len(self), index):
+            raise IndexError(
+                f"Index {index} out of bounds for Batch of len {len(self)}.")
         else:
             b = Batch()
             for k, v in self.__dict__.items():
-                if isinstance(v, Batch):
-                    b.__dict__[k] = v[index]
+                if isinstance(v, Batch) and v.size == 0:
+                    b.__dict__[k] = Batch()
+                elif isinstance(v, list) and len(v) == 0:
+                    b.__dict__[k] = []
                 elif hasattr(v, '__len__') and (not isinstance(
                         v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                     if _valid_bounds(len(v), index):
                         b.__dict__[k] = v[index]
+                    else:
+                        raise IndexError(
+                            f"Index {index} out of bounds for {type(v)} of "
+                            f"len {len(self)}.")
             return b

+    def __iadd__(self, val: Union['Batch', Number]):
+        if isinstance(val, Batch):
+            for k, r, v in zip(self.__dict__.keys(),
+                               self.__dict__.values(),
+                               val.__dict__.values()):
+                if r is None:
+                    self.__dict__[k] = r
+                elif isinstance(r, list):
+                    self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
+                else:
+                    self.__dict__[k] = r + v
+            return self
+        elif isinstance(val, Number):
+            for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
+                if r is None:
+                    self.__dict__[k] = r
+                elif isinstance(r, list):
+                    self.__dict__[k] = [r_ + val for r_ in r]
+                else:
+                    self.__dict__[k] = r + val
+            return self
+        else:
+            raise TypeError("Only addition of Batch or number is supported.")
+
+    def __add__(self, val: Union['Batch', Number]):
+        return copy.deepcopy(self).__iadd__(val)
+
+    def __mul__(self, val: Number):
+        assert isinstance(val, Number), \
+            "Only multiplication by a number is supported."
+        result = Batch()
+        for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
+            result.__dict__[k] = r * val
+        return result
+
+    def __truediv__(self, val: Number):
+        assert isinstance(val, Number), \
+            "Only division by a number is supported."
+        result = Batch()
+        for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
+            result.__dict__[k] = r / val
+        return result
+
     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
         if key not in self.__dict__:
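
Taken together, the stricter index validation and the new arithmetic protocol methods above support the patterns below; this is a sketch mirroring the updated unit tests (values are illustrative):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=np.array([1.0, 2.0]), c=np.array([3.0, 4.0]))
    assert len(b[:1]) == 1      # open-ended slices are validated correctly
    try:
        b[5]                    # out-of-bounds index
    except IndexError:
        pass
    b2 = (b + 1.0) * 2          # __add__ / __mul__ with a number
    b2 += b                     # __iadd__ with another Batch, key by key
    b3 = b2 / 2.0               # __truediv__ with a number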
@@ -167,12 +229,11 @@ class Batch:
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in sorted(self.__dict__.keys()):
-            if self.__dict__.get(k, None) is not None:
-                rpl = '\n' + ' ' * (6 + len(k))
-                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
-                s += f'    {k}: {obj},\n'
-                flag = True
+        for k, v in self.__dict__.items():
+            rpl = '\n' + ' ' * (6 + len(k))
+            obj = pprint.pformat(v).replace('\n', rpl)
+            s += f'    {k}: {obj},\n'
+            flag = True
         if flag:
             s += ')'
         else:
@@ -296,10 +357,33 @@ class Batch:
         """Return len(self)."""
         r = []
         for v in self.__dict__.values():
-            if hasattr(v, '__len__') and (not isinstance(
+            if isinstance(v, Batch) and v.size == 0:
+                continue
+            elif isinstance(v, list) and len(v) == 0:
+                continue
+            elif hasattr(v, '__len__') and (not isinstance(
                     v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                 r.append(len(v))
-        return max(r) if len(r) > 0 else 0
+            else:
+                raise TypeError("Object of type 'Batch' has no len()")
+        if len(r) == 0:
+            raise TypeError("Object of type 'Batch' has no len()")
+        return min(r)
+
+    @property
+    def size(self) -> int:
+        """Return self.size."""
+        if len(self.__dict__) == 0:
+            return 0
+        else:
+            r = []
+            for v in self.__dict__.values():
+                if isinstance(v, Batch):
+                    r.append(v.size)
+                elif hasattr(v, '__len__') and (not isinstance(
+                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
+                    r.append(len(v))
+            return max(1, min(r) if len(r) > 0 else 0)

     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:
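
The split between `__len__` and the new `size` property: `len()` reports the minimum length of the array-like entries and now raises `TypeError` when it meets a scalar entry or finds nothing with a length, while `size` counts scalar entries as one element and returns 0 only for a truly empty Batch. A small sketch (key names are illustrative):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=np.zeros((3, 4)), extra=Batch(), logs=[])
    assert len(b) == 3              # an empty sub-batch or list only reserves its key

    scalar_batch = Batch(x=1.0)
    assert scalar_batch.size == 1   # a scalar counts as a single element
    try:
        len(scalar_batch)           # ...but it has no length
    except TypeError:
        pass
    assert Batch().size == 0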


@@ -5,7 +5,7 @@ from typing import Any, Tuple, Union, Optional
 from tianshou.data.batch import Batch


-class ReplayBuffer(object):
+class ReplayBuffer:
     """:class:`~tianshou.data.ReplayBuffer` stores data generated from
     interaction between the policy and environment. It stores basically 7 types
     of data, as mentioned in :class:`~tianshou.data.Batch`, based on
@@ -96,7 +96,6 @@ class ReplayBuffer(object):

     def __init__(self, size: int, stack_num: Optional[int] = 0,
                  ignore_obs_next: bool = False, **kwargs) -> None:
-        super().__init__()
         self._maxsize = size
         self._stack = stack_num
         self._save_s_ = not ignore_obs_next
@@ -137,7 +136,7 @@ class ReplayBuffer(object):
                 d[k_] = self.__dict__[k__]
             else:
                 d[k_] = self.__getattr__(k__)
-        return Batch(**d)
+        return Batch(d)

     def _add_to_buffer(self, name: str, inst: Any) -> None:
         if inst is None:
@@ -177,10 +176,7 @@ class ReplayBuffer(object):
         """Move the data from the given buffer to self."""
         i = begin = buffer._index % len(buffer)
         while True:
-            self.add(
-                buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
-                buffer.obs_next[i] if self._save_s_ else None,
-                buffer.info[i], buffer.policy[i])
+            self.add(**buffer[i])
             i = (i + 1) % len(buffer)
             if i == begin:
                 break
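
Because `buffer[i]` now yields a Batch whose keys line up with the keyword arguments of `add()`, the copy loop above can forward each stored transition directly. A hypothetical usage sketch (the method in this hunk is assumed to be ReplayBuffer.update, and the transition values are placeholders):

    from tianshou.data import ReplayBuffer

    src = ReplayBuffer(size=4)
    dst = ReplayBuffer(size=8)
    src.add(obs=0, act=0, rew=1.0, done=0, obs_next=1, info={})
    dst.update(src)    # re-adds every stored transition via self.add(**buffer[i])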
@@ -272,7 +268,7 @@ class ReplayBuffer(object):
         else:
             stack = []
             many_keys = None
-            for i in range(stack_num):
+            for _ in range(stack_num):
                 if many_keys is not None:
                     for k_ in many_keys:
                         k__ = '_' + key + '@' + k_
@@ -287,7 +283,7 @@ class ReplayBuffer(object):
             if many_keys is not None:
                 for k in stack:
                     stack[k] = np.stack(stack[k], axis=1)
-                stack = Batch(**stack)
+                stack = Batch(stack)
             else:
                 stack = np.stack(stack, axis=1)
         return stack
@@ -303,7 +299,7 @@ class ReplayBuffer(object):
             rew=self.rew[index],
             done=self.done[index],
             obs_next=self.get(index, 'obs_next'),
-            info=self.info[index],
+            info=self.get(index, 'info'),
             policy=self.get(index, 'policy'),
         )
@@ -440,7 +436,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             rew=self.rew[index],
             done=self.done[index],
             obs_next=self.get(index, 'obs_next'),
-            info=self.info[index],
+            info=self.get(index, 'info'),
             weight=self.weight[index],
             policy=self.get(index, 'policy'),
         )