Fix Batch to numpy compatibility (#92)

* Fix Batch to numpy compatibility.

* Fix Batch unit tests.

* Fix linter.

* Add Batch shape method.

* Remove shape and add size. Allow reserving keys using an empty batch/list.

* Fix linter and unit tests.

* Support Batch init from a list of Batch.

* Add unit tests.

* Fix Batch __len__.

* Fix unit tests.

* Fix slicing.

* Add missing slicing unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Alexis DUBURCQ authored on 2020-06-24 15:43:48 +02:00, committed by GitHub
parent ebc551a25e
commit 49f43e9f1f
3 changed files with 144 additions and 32 deletions
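
As a quick orientation before the diffs: the behaviours this commit adds are easiest to see end to end. The sketch below is distilled from the new unit tests in test_batch.py further down; the values are illustrative only.

```python
import numpy as np
from tianshou.data import Batch

# Same layout as the new numpy-compatibility test: an array leaf, a reserved
# (empty) sub-batch, and a 1-D array leaf.
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
              b=Batch(),
              c=np.array([5.0, 6.0]))

# Open-ended slices are now handled (start/stop may be None).
assert len(batch[:1]) == 1

# Elementwise arithmetic with numbers returns a new Batch.
shifted = (batch + 1.0) * 2
assert np.all(shifted.a == (batch.a + 1.0) * 2)

# numpy reductions aggregate key by key and still return a Batch.
batch_mean = np.mean(batch)
assert isinstance(batch_mean, Batch)
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
```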

test_batch.py

@@ -39,13 +39,32 @@ def test_batch():
'c': np.zeros(1),
'd': Batch(e=np.array(3.0))}])
assert len(batch2) == 1
assert list(batch2[1].keys()) == ['a']
assert len(batch2[-2].a.d.keys()) == 0
assert len(batch2[-1].keys()) > 0
assert batch2[0][0].a.c == 0.0
with pytest.raises(IndexError):
batch2[-2]
with pytest.raises(IndexError):
batch2[1]
with pytest.raises(TypeError):
batch2[0][0]
assert isinstance(batch2[0].a.c, np.ndarray)
assert isinstance(batch2[0].a.b, np.float64)
assert isinstance(batch2[0].a.d.e, np.float64)
batch2_from_list = Batch(list(batch2))
batch2_from_comp = Batch([e for e in batch2])
assert batch2_from_list.a.b == batch2.a.b
assert batch2_from_list.a.c == batch2.a.c
assert batch2_from_list.a.d.e == batch2.a.d.e
assert batch2_from_comp.a.b == batch2.a.b
assert batch2_from_comp.a.c == batch2.a.c
assert batch2_from_comp.a.d.e == batch2.a.d.e
for batch_slice in [
batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
assert batch_slice.a.b == batch2.a.b
assert batch_slice.a.c == batch2.a.c
assert batch_slice.a.d.e == batch2.a.d.e
batch2_sum = (batch2 + 1.0) * 2
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
def test_batch_over_batch():
@@ -146,6 +165,19 @@ def test_batch_from_to_numpy_without_copy():
assert c_mem_addr_new == c_mem_addr_orig
def test_batch_numpy_compatibility():
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
b=Batch(),
c=np.array([5.0, 6.0]))
batch_mean = np.mean(batch)
assert isinstance(batch_mean, Batch)
assert sorted(batch_mean.keys()) == ['a', 'b', 'c']
with pytest.raises(TypeError):
len(batch_mean)
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
assert batch_mean.c == np.mean(batch.c, axis=0)
if __name__ == '__main__':
test_batch()
test_batch_over_batch()
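
A note on the `len(batch_mean)` TypeError asserted above: `__len__` now only counts fields that actually carry a length and raises when none do, while the new `size` property (added in batch.py below) is the lenient counterpart. A minimal sketch of the intended contrast, assuming the semantics shown in the diff below:

```python
import numpy as np
from tianshou.data import Batch

empty = Batch()
assert empty.size == 0      # size is always defined...
try:
    len(empty)              # ...but len() raises when no field has a length
except TypeError:
    pass

b = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]))
assert len(b) == 2 and b.size == 2
```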

tianshou/data/batch.py

@@ -3,6 +3,7 @@ import copy
import pprint
import warnings
import numpy as np
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
@@ -75,15 +76,16 @@ class Batch:
"""
def __init__(self,
batch_dict: Optional[
Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None,
**kwargs) -> None:
def _is_batch_set(data: Any) -> bool:
if isinstance(data, (list, tuple)):
if len(data) > 0 and isinstance(data[0], dict):
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
return True
elif isinstance(data, np.ndarray):
if isinstance(data.item(0), dict):
if isinstance(data.item(0), (dict, Batch)):
return True
return False
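
The widened `_is_batch_set` check (dict or Batch elements) is what lets the constructor rebuild a Batch from its own elements, as the new tests do with `Batch(list(batch2))`. A rough sketch, assuming element-wise iteration as used in those tests:

```python
import numpy as np
from tianshou.data import Batch

b = Batch(a=np.array([1.0, 2.0]), c=np.zeros(2))
b_from_list = Batch(list(b))           # list of Batch elements, one per index
b_from_comp = Batch([e for e in b])    # same thing via a comprehension
assert np.all(b_from_list.a == b.a)
assert np.all(b_from_comp.c == b.c)
```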
@@ -102,7 +104,7 @@ class Batch:
self.__dict__[k] = Batch.stack(v)
else:
self.__dict__[k] = list(v)
elif isinstance(batch_dict, dict):
elif isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items():
if isinstance(v, dict) or _is_batch_set(v):
self.__dict__[k] = Batch(v)
@@ -141,22 +143,82 @@ class Batch:
return _valid_bounds(length, min(index)) and \
_valid_bounds(length, max(index))
elif isinstance(index, slice):
return _valid_bounds(length, index.start) and \
_valid_bounds(length, index.stop - 1)
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
if isinstance(index, str):
return self.__getattr__(index)
if not _valid_bounds(len(self), index):
raise IndexError(
f"Index {index} out of bounds for Batch of len {len(self)}.")
else:
b = Batch()
for k, v in self.__dict__.items():
if isinstance(v, Batch):
b.__dict__[k] = v[index]
if isinstance(v, Batch) and v.size == 0:
b.__dict__[k] = Batch()
elif isinstance(v, list) and len(v) == 0:
b.__dict__[k] = []
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
if _valid_bounds(len(v), index):
b.__dict__[k] = v[index]
else:
raise IndexError(
f"Index {index} out of bounds for {type(v)} of "
f"len {len(self)}.")
return b
def __iadd__(self, val: Union['Batch', Number]):
if isinstance(val, Batch):
for k, r, v in zip(self.__dict__.keys(),
self.__dict__.values(),
val.__dict__.values()):
if r is None:
self.__dict__[k] = r
elif isinstance(r, list):
self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
else:
self.__dict__[k] = r + v
return self
elif isinstance(val, Number):
for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
if r is None:
self.__dict__[k] = r
elif isinstance(r, list):
self.__dict__[k] = [r_ + val for r_ in r]
else:
self.__dict__[k] = r + val
return self
else:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, val: Union['Batch', Number]):
return copy.deepcopy(self).__iadd__(val)
def __mul__(self, val: Number):
assert isinstance(val, Number), \
"Only multiplication by a number is supported."
result = Batch()
for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
result.__dict__[k] = r * val
return result
def __truediv__(self, val: Number):
assert isinstance(val, Number), \
"Only division by a number is supported."
result = Batch()
for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
result.__dict__[k] = r / val
return result
def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key not in self.__dict__:
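
The new operators give Batch elementwise arithmetic: addition with a number or with another Batch whose keys line up, and multiplication/division by a number. A hedged sketch of what they are meant to support:

```python
import numpy as np
from tianshou.data import Batch

b = Batch(a=np.array([1.0, 2.0]), c=np.array([3.0, 4.0]))

b_sum = (b + 1.0) * 2          # __add__ deep-copies, then applies the op per key
assert np.all(b_sum.a == (b.a + 1.0) * 2)

b += Batch(a=np.ones(2), c=np.zeros(2))   # Batch + Batch adds key by key
assert np.all(b.a == np.array([2.0, 3.0]))

half = b_sum / 2               # __truediv__ only accepts a number
assert np.all(half.a == np.array([2.0, 3.0]))
```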
@@ -167,12 +229,11 @@ class Batch:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in sorted(self.__dict__.keys()):
if self.__dict__.get(k, None) is not None:
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
for k, v in self.__dict__.items():
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(v).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')'
else:
@@ -296,10 +357,33 @@ class Batch:
"""Return len(self)."""
r = []
for v in self.__dict__.values():
if hasattr(v, '__len__') and (not isinstance(
if isinstance(v, Batch) and v.size == 0:
continue
elif isinstance(v, list) and len(v) == 0:
continue
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
return max(r) if len(r) > 0 else 0
else:
raise TypeError("Object of type 'Batch' has no len()")
if len(r) == 0:
raise TypeError("Object of type 'Batch' has no len()")
return min(r)
@property
def size(self) -> int:
"""Return self.size."""
if len(self.__dict__) == 0:
return 0
else:
r = []
for v in self.__dict__.values():
if isinstance(v, Batch):
r.append(v.size)
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
return max(1, min(r) if len(r) > 0 else 0)
def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']:
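
Putting `__getitem__`, `__len__` and `size` together: an empty sub-batch now acts as a reserved key that survives indexing but is ignored when the length is computed. A minimal sketch (the key names are illustrative):

```python
import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.arange(3), info=Batch())   # 'info' is reserved but still empty
assert len(b) == 3                          # the empty sub-batch is skipped by __len__
sliced = b[:2]
assert len(sliced) == 2
assert isinstance(sliced.info, Batch) and sliced.info.size == 0
```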

tianshou/data/buffer.py

@@ -5,7 +5,7 @@ from typing import Any, Tuple, Union, Optional
from tianshou.data.batch import Batch
class ReplayBuffer(object):
class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
interaction between the policy and environment. It stores basically 7 types
of data, as mentioned in :class:`~tianshou.data.Batch`, based on
@@ -96,7 +96,6 @@ class ReplayBuffer(object):
def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
self._save_s_ = not ignore_obs_next
@@ -137,7 +136,7 @@ class ReplayBuffer(object):
d[k_] = self.__dict__[k__]
else:
d[k_] = self.__getattr__(k__)
return Batch(**d)
return Batch(d)
def _add_to_buffer(self, name: str, inst: Any) -> None:
if inst is None:
@@ -177,10 +176,7 @@ class ReplayBuffer(object):
"""Move the data from the given buffer to self."""
i = begin = buffer._index % len(buffer)
while True:
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
buffer.obs_next[i] if self._save_s_ else None,
buffer.info[i], buffer.policy[i])
self.add(**buffer[i])
i = (i + 1) % len(buffer)
if i == begin:
break
@@ -272,7 +268,7 @@ class ReplayBuffer(object):
else:
stack = []
many_keys = None
for i in range(stack_num):
for _ in range(stack_num):
if many_keys is not None:
for k_ in many_keys:
k__ = '_' + key + '@' + k_
@@ -287,7 +283,7 @@ class ReplayBuffer(object):
if many_keys is not None:
for k in stack:
stack[k] = np.stack(stack[k], axis=1)
stack = Batch(**stack)
stack = Batch(stack)
else:
stack = np.stack(stack, axis=1)
return stack
@@ -303,7 +299,7 @@ class ReplayBuffer(object):
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.info[index],
info=self.get(index, 'info'),
policy=self.get(index, 'policy'),
)
@@ -440,7 +436,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.info[index],
info=self.get(index, 'info'),
weight=self.weight[index],
policy=self.get(index, 'policy'),
)
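
To see the buffer-side changes in use: `__getitem__`/`sample` now build their result with `Batch(d)` and fall back to `self.get(...)` for optional fields such as `info` and `policy`, and `update()` simply re-adds each indexed transition. A hedged usage sketch; the keyword names follow the old `add(...)` call shown above, everything else is illustrative:

```python
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(3):
    buf.add(obs=i, act=i, rew=float(i), done=False, obs_next=i + 1)

batch, indice = buf.sample(2)
assert isinstance(batch, Batch)     # a Batch of sampled transitions

other = ReplayBuffer(size=10)
other.update(buf)                   # moves data by re-adding buf[i] transitions
assert len(other) == len(buf)
```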