Fix Batch to numpy compatibility (#92)
* Fix Batch to numpy compatibility. * Fix Batch unit tests. * Fix linter. * Add Batch shape method. * Remove shape and add size. Allow reserving keys using an empty Batch/list. * Fix linter and unit tests. * Batch init using list of Batch. * Add unit tests. * Fix Batch __len__. * Fix unit tests. * Fix slicing. * Add missing slicing unit tests. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
ebc551a25e
commit
49f43e9f1f
@ -39,13 +39,32 @@ def test_batch():
|
||||
'c': np.zeros(1),
|
||||
'd': Batch(e=np.array(3.0))}])
|
||||
assert len(batch2) == 1
|
||||
assert list(batch2[1].keys()) == ['a']
|
||||
assert len(batch2[-2].a.d.keys()) == 0
|
||||
assert len(batch2[-1].keys()) > 0
|
||||
assert batch2[0][0].a.c == 0.0
|
||||
with pytest.raises(IndexError):
|
||||
batch2[-2]
|
||||
with pytest.raises(IndexError):
|
||||
batch2[1]
|
||||
with pytest.raises(TypeError):
|
||||
batch2[0][0]
|
||||
assert isinstance(batch2[0].a.c, np.ndarray)
|
||||
assert isinstance(batch2[0].a.b, np.float64)
|
||||
assert isinstance(batch2[0].a.d.e, np.float64)
|
||||
batch2_from_list = Batch(list(batch2))
|
||||
batch2_from_comp = Batch([e for e in batch2])
|
||||
assert batch2_from_list.a.b == batch2.a.b
|
||||
assert batch2_from_list.a.c == batch2.a.c
|
||||
assert batch2_from_list.a.d.e == batch2.a.d.e
|
||||
assert batch2_from_comp.a.b == batch2.a.b
|
||||
assert batch2_from_comp.a.c == batch2.a.c
|
||||
assert batch2_from_comp.a.d.e == batch2.a.d.e
|
||||
for batch_slice in [
|
||||
batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
|
||||
assert batch_slice.a.b == batch2.a.b
|
||||
assert batch_slice.a.c == batch2.a.c
|
||||
assert batch_slice.a.d.e == batch2.a.d.e
|
||||
batch2_sum = (batch2 + 1.0) * 2
|
||||
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
|
||||
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
|
||||
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
|
||||
|
||||
|
||||
def test_batch_over_batch():
|
||||
@ -146,6 +165,19 @@ def test_batch_from_to_numpy_without_copy():
|
||||
assert c_mem_addr_new == c_mem_addr_orig
|
||||
|
||||
|
||||
def test_batch_numpy_compatibility():
    """Check that ``np.mean`` applied to a Batch yields a per-key Batch."""
    fields = {
        'a': np.array([[1.0, 2.0], [3.0, 4.0]]),
        'b': Batch(),
        'c': np.array([5.0, 6.0]),
    }
    batch = Batch(**fields)
    reduced = np.mean(batch)
    assert isinstance(reduced, Batch)
    assert sorted(reduced.keys()) == ['a', 'b', 'c']
    # The reduced Batch is expected to be unsized.
    with pytest.raises(TypeError):
        len(reduced)
    # Every array field must match the mean taken over its first axis.
    expected_a = np.mean(batch.a, axis=0)
    expected_c = np.mean(batch.c, axis=0)
    assert np.all(reduced.a == expected_a)
    assert reduced.c == expected_c
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch()
|
||||
test_batch_over_batch()
|
||||
|
@ -3,6 +3,7 @@ import copy
|
||||
import pprint
|
||||
import warnings
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||
|
||||
# Disable pickle warning related to torch, since it has been removed
|
||||
@ -75,15 +76,16 @@ class Batch:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_dict: Optional[
|
||||
Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
|
||||
batch_dict: Optional[Union[
|
||||
dict, 'Batch', Tuple[Union[dict, 'Batch']],
|
||||
List[Union[dict, 'Batch']], np.ndarray]] = None,
|
||||
**kwargs) -> None:
|
||||
def _is_batch_set(data: Any) -> bool:
|
||||
if isinstance(data, (list, tuple)):
|
||||
if len(data) > 0 and isinstance(data[0], dict):
|
||||
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
|
||||
return True
|
||||
elif isinstance(data, np.ndarray):
|
||||
if isinstance(data.item(0), dict):
|
||||
if isinstance(data.item(0), (dict, Batch)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -102,7 +104,7 @@ class Batch:
|
||||
self.__dict__[k] = Batch.stack(v)
|
||||
else:
|
||||
self.__dict__[k] = list(v)
|
||||
elif isinstance(batch_dict, dict):
|
||||
elif isinstance(batch_dict, (dict, Batch)):
|
||||
for k, v in batch_dict.items():
|
||||
if isinstance(v, dict) or _is_batch_set(v):
|
||||
self.__dict__[k] = Batch(v)
|
||||
@ -141,22 +143,82 @@ class Batch:
|
||||
return _valid_bounds(length, min(index)) and \
|
||||
_valid_bounds(length, max(index))
|
||||
elif isinstance(index, slice):
|
||||
return _valid_bounds(length, index.start) and \
|
||||
_valid_bounds(length, index.stop - 1)
|
||||
if index.start is not None:
|
||||
start_valid = _valid_bounds(length, index.start)
|
||||
else:
|
||||
start_valid = True
|
||||
if index.stop is not None:
|
||||
stop_valid = _valid_bounds(length, index.stop - 1)
|
||||
else:
|
||||
stop_valid = True
|
||||
return start_valid and stop_valid
|
||||
|
||||
if isinstance(index, str):
|
||||
return self.__getattr__(index)
|
||||
|
||||
if not _valid_bounds(len(self), index):
|
||||
raise IndexError(
|
||||
f"Index {index} out of bounds for Batch of len {len(self)}.")
|
||||
else:
|
||||
b = Batch()
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, Batch):
|
||||
b.__dict__[k] = v[index]
|
||||
if isinstance(v, Batch) and v.size == 0:
|
||||
b.__dict__[k] = Batch()
|
||||
elif isinstance(v, list) and len(v) == 0:
|
||||
b.__dict__[k] = []
|
||||
elif hasattr(v, '__len__') and (not isinstance(
|
||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
||||
if _valid_bounds(len(v), index):
|
||||
b.__dict__[k] = v[index]
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Index {index} out of bounds for {type(v)} of "
|
||||
f"len {len(self)}.")
|
||||
return b
|
||||
|
||||
def __iadd__(self, val: Union['Batch', Number]):
    """Add ``val`` to every field of this Batch in place and return self.

    :param val: either a number, which is added to every field, or another
        Batch, whose fields are added to the fields of ``self`` stored
        under the same key.
    :return: ``self``, after the update.
    :raises TypeError: if ``val`` is neither a Batch nor a number.

    Fields equal to ``None`` are left untouched and list fields are updated
    element by element. When ``val`` is a Batch, fields of ``self`` whose
    key is missing from ``val`` are left unchanged.
    """
    if isinstance(val, Batch):
        other = val.__dict__
        # Only existing keys are reassigned below, so the dict size never
        # changes while it is being iterated.
        for k, r in self.__dict__.items():
            # Pair operands by key, not by position: two Batches may hold
            # the same keys in a different insertion order.
            if r is None or k not in other:
                continue
            v = other[k]
            if isinstance(r, list):
                self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
            else:
                self.__dict__[k] = r + v
        return self
    if isinstance(val, Number):
        for k, r in self.__dict__.items():
            if r is None:
                continue
            if isinstance(r, list):
                self.__dict__[k] = [r_ + val for r_ in r]
            else:
                self.__dict__[k] = r + val
        return self
    raise TypeError("Only addition of Batch or number is supported.")
|
||||
|
||||
def __add__(self, val: Union['Batch', Number]):
|
||||
return copy.deepcopy(self).__iadd__(val)
|
||||
|
||||
def __mul__(self, val: Number):
    """Return a new Batch holding the fields of ``self`` times ``val``.

    :param val: the number every field is multiplied by.
    :return: a new Batch with the same keys as ``self``.

    Fields are handled the same way as in ``__iadd__``: ``None`` fields are
    kept as ``None`` and list fields are scaled element by element, rather
    than being repeated by ``list * int``.
    """
    assert isinstance(val, Number), \
        "Only multiplication by a number is supported."
    result = Batch()
    for k, r in self.__dict__.items():
        if r is None:
            result.__dict__[k] = None
        elif isinstance(r, list):
            result.__dict__[k] = [r_ * val for r_ in r]
        else:
            result.__dict__[k] = r * val
    return result
|
||||
|
||||
def __truediv__(self, val: Number):
    """Return a new Batch holding the fields of ``self`` divided by ``val``.

    :param val: the number every field is divided by.
    :return: a new Batch with the same keys as ``self``.

    Fields are handled the same way as in ``__iadd__``: ``None`` fields are
    kept as ``None`` and list fields are divided element by element.
    """
    assert isinstance(val, Number), \
        "Only division by a number is supported."
    result = Batch()
    for k, r in self.__dict__.items():
        if r is None:
            result.__dict__[k] = None
        elif isinstance(r, list):
            result.__dict__[k] = [r_ / val for r_ in r]
        else:
            result.__dict__[k] = r / val
    return result
|
||||
|
||||
def __getattr__(self, key: str) -> Union['Batch', Any]:
|
||||
"""Return self.key"""
|
||||
if key not in self.__dict__:
|
||||
@ -167,12 +229,11 @@ class Batch:
|
||||
"""Return str(self)."""
|
||||
s = self.__class__.__name__ + '(\n'
|
||||
flag = False
|
||||
for k in sorted(self.__dict__.keys()):
|
||||
if self.__dict__.get(k, None) is not None:
|
||||
rpl = '\n' + ' ' * (6 + len(k))
|
||||
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
|
||||
s += f' {k}: {obj},\n'
|
||||
flag = True
|
||||
for k, v in self.__dict__.items():
|
||||
rpl = '\n' + ' ' * (6 + len(k))
|
||||
obj = pprint.pformat(v).replace('\n', rpl)
|
||||
s += f' {k}: {obj},\n'
|
||||
flag = True
|
||||
if flag:
|
||||
s += ')'
|
||||
else:
|
||||
@ -296,10 +357,33 @@ class Batch:
|
||||
"""Return len(self)."""
|
||||
r = []
|
||||
for v in self.__dict__.values():
|
||||
if hasattr(v, '__len__') and (not isinstance(
|
||||
if isinstance(v, Batch) and v.size == 0:
|
||||
continue
|
||||
elif isinstance(v, list) and len(v) == 0:
|
||||
continue
|
||||
elif hasattr(v, '__len__') and (not isinstance(
|
||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
||||
r.append(len(v))
|
||||
return max(r) if len(r) > 0 else 0
|
||||
else:
|
||||
raise TypeError("Object of type 'Batch' has no len()")
|
||||
if len(r) == 0:
|
||||
raise TypeError("Object of type 'Batch' has no len()")
|
||||
return min(r)
|
||||
|
||||
@property
def size(self) -> int:
    """Return self.size.

    The size is 0 for a Batch without any field. Otherwise it is the
    smallest length found among the sized fields (a nested Batch
    contributes its own ``size``), but never less than 1.
    """
    if not self.__dict__:
        return 0
    lengths = []
    for v in self.__dict__.values():
        if isinstance(v, Batch):
            lengths.append(v.size)
            continue
        if not hasattr(v, '__len__'):
            continue
        # Zero-dimensional arrays/tensors expose __len__ but have no length.
        is_array = isinstance(v, (np.ndarray, torch.Tensor))
        if not is_array or v.ndim > 0:
            lengths.append(len(v))
    smallest = min(lengths) if lengths else 0
    return max(1, smallest)
|
||||
|
||||
def split(self, size: Optional[int] = None,
|
||||
shuffle: bool = True) -> Iterator['Batch']:
|
||||
|
@ -5,7 +5,7 @@ from typing import Any, Tuple, Union, Optional
|
||||
from tianshou.data.batch import Batch
|
||||
|
||||
|
||||
class ReplayBuffer(object):
|
||||
class ReplayBuffer:
|
||||
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
|
||||
interaction between the policy and environment. It stores basically 7 types
|
||||
of data, as mentioned in :class:`~tianshou.data.Batch`, based on
|
||||
@ -96,7 +96,6 @@ class ReplayBuffer(object):
|
||||
|
||||
def __init__(self, size: int, stack_num: Optional[int] = 0,
|
||||
ignore_obs_next: bool = False, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._stack = stack_num
|
||||
self._save_s_ = not ignore_obs_next
|
||||
@ -137,7 +136,7 @@ class ReplayBuffer(object):
|
||||
d[k_] = self.__dict__[k__]
|
||||
else:
|
||||
d[k_] = self.__getattr__(k__)
|
||||
return Batch(**d)
|
||||
return Batch(d)
|
||||
|
||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||
if inst is None:
|
||||
@ -177,10 +176,7 @@ class ReplayBuffer(object):
|
||||
"""Move the data from the given buffer to self."""
|
||||
i = begin = buffer._index % len(buffer)
|
||||
while True:
|
||||
self.add(
|
||||
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
|
||||
buffer.obs_next[i] if self._save_s_ else None,
|
||||
buffer.info[i], buffer.policy[i])
|
||||
self.add(**buffer[i])
|
||||
i = (i + 1) % len(buffer)
|
||||
if i == begin:
|
||||
break
|
||||
@ -272,7 +268,7 @@ class ReplayBuffer(object):
|
||||
else:
|
||||
stack = []
|
||||
many_keys = None
|
||||
for i in range(stack_num):
|
||||
for _ in range(stack_num):
|
||||
if many_keys is not None:
|
||||
for k_ in many_keys:
|
||||
k__ = '_' + key + '@' + k_
|
||||
@ -287,7 +283,7 @@ class ReplayBuffer(object):
|
||||
if many_keys is not None:
|
||||
for k in stack:
|
||||
stack[k] = np.stack(stack[k], axis=1)
|
||||
stack = Batch(**stack)
|
||||
stack = Batch(stack)
|
||||
else:
|
||||
stack = np.stack(stack, axis=1)
|
||||
return stack
|
||||
@ -303,7 +299,7 @@ class ReplayBuffer(object):
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.info[index],
|
||||
info=self.get(index, 'info'),
|
||||
policy=self.get(index, 'policy'),
|
||||
)
|
||||
|
||||
@ -440,7 +436,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.info[index],
|
||||
info=self.get(index, 'info'),
|
||||
weight=self.weight[index],
|
||||
policy=self.get(index, 'policy'),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user