Buffer refactoring to support batch over batch reliably (#93)

* Fix support of batch over batch for Buffer.

* Do not use internal __dict__ attribute to store batch data since it breaks inheritance.

* Various fixes.

* Improve robustness of Batch/Buffer by avoiding direct attribute assignment. Buffer refactoring.

* Add axis optional argument to Batch stack method.

* Add item assignment to Batch class.

* Fix list support for Buffer.

* Convert list to np.array by default for efficiency.

* Add missing unit test for Batch. Fix unit tests.

* Batch item assignment is now robust to key order.

* Do not use getattr/setattr explicitly, for simplicity.

* More flexible __setitem__.

* Fixes

* Remove broadcasting at Batch level since it is unreliable.

* Forbid item assignment for inconsistent batches.

* Implement broadcasting at Buffer level.

* Add more unit tests for Batch item assignment (a short usage sketch follows this list).
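
A minimal usage sketch of the new Batch surface described above, mirroring the tests added in this PR (the names b, b3 and b4 are illustrative only):

    import numpy as np
    import pytest
    from tianshou.data import Batch

    b = Batch(a={'c': np.zeros(1),
                 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    b.a.d[0] = {'e': 4.0}            # item assignment from a dict
    assert b.a.d.e[0] == 4.0
    b.a.d[0] = Batch(f=5.0)          # or from another Batch
    assert b.a.d.f[0] == 5.0
    with pytest.raises(ValueError):  # creating new keys by item assignment is forbidden
        b.a.d[0] = Batch(f=5.0, g=0.0)

    # stack along an arbitrary axis
    b3, b4 = Batch(a=np.zeros((3, 4))), Batch(a=np.ones((3, 4)))
    assert Batch.stack((b3, b4), axis=1).a.shape == (3, 2, 4)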

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Alexis DUBURCQ 2020-06-25 14:39:30 +02:00 committed by GitHub
parent 506cc97ba5
commit 3086b5c31d
3 changed files with 233 additions and 195 deletions

View File

@@ -65,6 +65,15 @@ def test_batch():
     assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
     assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
     assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
+    batch3 = Batch(a={
+        'c': np.zeros(1),
+        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
+    batch3.a.d[0] = {'e': 4.0}
+    assert batch3.a.d.e[0] == 4.0
+    batch3.a.d[0] = Batch(f=5.0)
+    assert batch3.a.d.f[0] == 5.0
+    with pytest.raises(ValueError):
+        batch3.a.d[0] = Batch(f=5.0, g=0.0)


 def test_batch_over_batch():
@@ -93,16 +102,20 @@ def test_batch_over_batch():

 def test_batch_cat_and_stack():
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
-    b_cat_out = Batch.cat((b1, b2))
-    b_cat_in = copy.deepcopy(b1)
-    b_cat_in.cat_(b2)
-    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
-    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
-    assert isinstance(b_cat_in.a.d.e, np.ndarray)
-    assert b_cat_in.a.d.e.ndim == 1
-    b_stack = Batch.stack((b1, b2))
-    assert isinstance(b_stack.a.d.e, np.ndarray)
-    assert b_stack.a.d.e.ndim == 2
+    b12_cat_out = Batch.cat((b1, b2))
+    b12_cat_in = copy.deepcopy(b1)
+    b12_cat_in.cat_(b2)
+    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
+    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
+    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
+    assert b12_cat_in.a.d.e.ndim == 1
+    b12_stack = Batch.stack((b1, b2))
+    assert isinstance(b12_stack.a.d.e, np.ndarray)
+    assert b12_stack.a.d.e.ndim == 2
+    b3 = Batch(a=np.zeros((3, 4)))
+    b4 = Batch(a=np.ones((3, 4)))
+    b34_stack = Batch.stack((b3, b4), axis=1)
+    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))


 def test_batch_over_batch_to_torch():

View File

@@ -75,6 +75,11 @@ class Batch:
         [11 22] [6 6]
     """

+    def __new__(cls, *args, **kwargs) -> 'Batch':
+        self = super().__new__(cls)
+        self.__dict__['_data'] = {}
+        return self
+
     def __init__(self,
                  batch_dict: Optional[Union[
                      dict, 'Batch', Tuple[Union[dict, 'Batch']],
@@ -95,21 +100,21 @@ class Batch:
             for k, v in zip(batch_dict[0].keys(),
                             zip(*[e.values() for e in batch_dict])):
                 if isinstance(v[0], dict) or _is_batch_set(v[0]):
-                    self.__dict__[k] = Batch(v)
+                    self[k] = Batch(v)
                 elif isinstance(v[0], (np.generic, np.ndarray)):
-                    self.__dict__[k] = np.stack(v, axis=0)
+                    self[k] = np.stack(v, axis=0)
                 elif isinstance(v[0], torch.Tensor):
-                    self.__dict__[k] = torch.stack(v, dim=0)
+                    self[k] = torch.stack(v, dim=0)
                 elif isinstance(v[0], Batch):
-                    self.__dict__[k] = Batch.stack(v)
+                    self[k] = Batch.stack(v)
                 else:
-                    self.__dict__[k] = list(v)
+                    self[k] = np.array(v)  # fall back to np.object
         elif isinstance(batch_dict, (dict, Batch)):
             for k, v in batch_dict.items():
                 if isinstance(v, dict) or _is_batch_set(v):
-                    self.__dict__[k] = Batch(v)
+                    self[k] = Batch(v)
                 else:
-                    self.__dict__[k] = v
+                    self[k] = v
         if len(kwargs) > 0:
             self.__init__(kwargs)
@@ -140,8 +145,8 @@ class Batch:
             if isinstance(index, (int, np.integer)):
                 return -length <= index and index < length
             elif isinstance(index, (list, np.ndarray)):
-                return _valid_bounds(length, min(index)) and \
-                    _valid_bounds(length, max(index))
+                return _valid_bounds(length, np.min(index)) and \
+                    _valid_bounds(length, np.max(index))
             elif isinstance(index, slice):
                 if index.start is not None:
                     start_valid = _valid_bounds(length, index.start)
@@ -154,48 +159,75 @@ class Batch:
                 return start_valid and stop_valid
         if isinstance(index, str):
-            return self.__getattr__(index)
+            return getattr(self, index)
         if not _valid_bounds(len(self), index):
             raise IndexError(
                 f"Index {index} out of bounds for Batch of len {len(self)}.")
         else:
             b = Batch()
-            for k, v in self.__dict__.items():
+            for k, v in self.items():
                 if isinstance(v, Batch) and v.size == 0:
-                    b.__dict__[k] = Batch()
-                elif isinstance(v, list) and len(v) == 0:
-                    b.__dict__[k] = []
+                    b[k] = Batch()
                 elif hasattr(v, '__len__') and (not isinstance(
                         v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                     if _valid_bounds(len(v), index):
-                        b.__dict__[k] = v[index]
+                        if isinstance(index, (int, np.integer)) or \
+                            (isinstance(index, np.ndarray) and
+                                index.ndim == 0) or \
+                                not isinstance(v, list):
+                            b[k] = v[index]
+                        else:
+                            b[k] = [v[i] for i in index]
                     else:
                         raise IndexError(
                             f"Index {index} out of bounds for {type(v)} of "
                             f"len {len(self)}.")
             return b

+    def __setitem__(self, index: Union[
+            str, slice, int, np.integer, np.ndarray, List[int]],
+            value: Any) -> None:
+        if isinstance(index, str):
+            return setattr(self, index, value)
+        if value is None:
+            value = Batch()
+        if not isinstance(value, (dict, Batch)):
+            raise TypeError("Batch does not supported value type "
+                            f"{type(value)} for item assignment.")
+        if not set(value.keys()).issubset(self.keys()):
+            raise ValueError(
+                "Creating keys is not supported by item assignment.")
+        for key in self.keys():
+            if isinstance(self[key], Batch):
+                default = Batch()
+            elif isinstance(self[key], np.ndarray) and \
+                    self[key].dtype == np.integer:
+                # Fallback for np.array of integer,
+                # since neither None or nan is supported.
+                default = 0
+            else:
+                default = None
+            self[key][index] = value.get(key, default)
+
     def __iadd__(self, val: Union['Batch', Number]):
         if isinstance(val, Batch):
-            for k, r, v in zip(self.__dict__.keys(),
-                               self.__dict__.values(),
-                               val.__dict__.values()):
+            for k, r, v in zip(self.keys(), self.values(), val.values()):
                 if r is None:
-                    self.__dict__[k] = r
+                    self[k] = r
                 elif isinstance(r, list):
-                    self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
+                    self[k] = [r_ + v_ for r_, v_ in zip(r, v)]
                 else:
-                    self.__dict__[k] = r + v
+                    self[k] = r + v
             return self
         elif isinstance(val, Number):
-            for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
+            for k, r in zip(self.keys(), self.values()):
                 if r is None:
-                    self.__dict__[k] = r
+                    self[k] = r
                 elif isinstance(r, list):
-                    self.__dict__[k] = [r_ + val for r_ in r]
+                    self[k] = [r_ + val for r_ in r]
                 else:
-                    self.__dict__[k] = r + val
+                    self[k] = r + val
             return self
         else:
             raise TypeError("Only addition of Batch or number is supported.")
@@ -206,30 +238,40 @@ class Batch:
     def __mul__(self, val: Number):
         assert isinstance(val, Number), \
             "Only multiplication by a number is supported."
-        result = Batch()
-        for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
-            result.__dict__[k] = r * val
+        result = self.__class__()
+        for k, r in zip(self.keys(), self.values()):
+            result[k] = r * val
         return result

     def __truediv__(self, val: Number):
         assert isinstance(val, Number), \
             "Only division by a number is supported."
-        result = Batch()
-        for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
-            result.__dict__[k] = r / val
+        result = self.__class__()
+        for k, r in zip(self.keys(), self.values()):
+            result[k] = r / val
         return result

     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
-        if key not in self.__dict__:
-            raise AttributeError(key)
-        return self.__dict__[key]
+        if key in self.__dict__.keys():
+            return self.__dict__[key]
+        elif key in self._data.keys():
+            return self._data[key]
+        raise AttributeError(key)
+
+    def __setattr__(self, key, value):
+        if key in self._data.keys():
+            self._data[key] = value
+        elif key in self.__dict__.keys():
+            self.__dict__[key] = value
+        else:
+            self._data[key] = value

     def __repr__(self) -> str:
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k, v in self.__dict__.items():
+        for k, v in self.items():
             rpl = '\n' + ' ' * (6 + len(k))
             obj = pprint.pformat(v).replace('\n', rpl)
             s += f'    {k}: {obj},\n'
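
The arithmetic operators above (+=, *, /) recurse over the stored keys, so they apply element-wise through nested batches as well. A tiny illustrative sketch (b and c are placeholder names):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=Batch(b=np.array([1.0, 2.0])))
    b += 1.0          # in place, key by key
    c = b * 2 / 4     # out of place, returns a new Batch
    assert np.allclose(c.a.b, [1.0, 1.5])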
@@ -242,29 +284,29 @@ class Batch:
     def keys(self) -> List[str]:
         """Return self.keys()."""
-        return self.__dict__.keys()
+        return self._data.keys()

     def values(self) -> List[Any]:
         """Return self.values()."""
-        return self.__dict__.values()
+        return self._data.values()

     def items(self) -> Any:
         """Return self.items()."""
-        return self.__dict__.items()
+        return self._data.items()

     def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
         """Return self[k] if k in self else d. d defaults to None."""
-        if k in self.__dict__:
-            return self.__getattr__(k)
+        if k in self.keys():
+            return self[k]
         return d

     def to_numpy(self) -> None:
-        """Change all torch.Tensor to numpy.ndarray. This is an inplace
+        """Change all torch.Tensor to numpy.ndarray. This is an in-place
         operation.
         """
-        for k, v in self.__dict__.items():
+        for k, v in self.items():
             if isinstance(v, torch.Tensor):
-                self.__dict__[k] = v.detach().cpu().numpy()
+                self[k] = v.detach().cpu().numpy()
             elif isinstance(v, Batch):
                 v.to_numpy()
@@ -272,18 +314,18 @@ class Batch:
                  dtype: Optional[torch.dtype] = None,
                  device: Union[str, int, torch.device] = 'cpu'
                  ) -> None:
-        """Change all numpy.ndarray to torch.Tensor. This is an inplace
+        """Change all numpy.ndarray to torch.Tensor. This is an in-place
         operation.
         """
         if not isinstance(device, torch.device):
             device = torch.device(device)

-        for k, v in self.__dict__.items():
+        for k, v in self.items():
             if isinstance(v, (np.generic, np.ndarray)):
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
-                self.__dict__[k] = v
+                self[k] = v
             if isinstance(v, torch.Tensor):
                 if dtype is not None and v.dtype != dtype:
                     must_update_tensor = True
@@ -297,7 +339,7 @@ class Batch:
                 if must_update_tensor:
                     if dtype is not None:
                         v = v.type(dtype)
-                    self.__dict__[k] = v.to(device)
+                    self[k] = v.to(device)
             elif isinstance(v, Batch):
                 v.to_torch(dtype, device)
@@ -312,51 +354,67 @@ class Batch:
         """
         assert isinstance(batch, Batch), \
             'Only Batch is allowed to be concatenated in-place!'
-        for k, v in batch.__dict__.items():
+        for k, v in batch.items():
             if v is None:
                 continue
-            if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = copy.deepcopy(v)
+            if not hasattr(self, k) or self[k] is None:
+                self[k] = copy.deepcopy(v)
             elif isinstance(v, np.ndarray) and v.ndim > 0:
-                self.__dict__[k] = np.concatenate([self.__dict__[k], v])
+                self[k] = np.concatenate([self[k], v])
             elif isinstance(v, torch.Tensor):
-                self.__dict__[k] = torch.cat([self.__dict__[k], v])
+                self[k] = torch.cat([self[k], v])
             elif isinstance(v, list):
-                self.__dict__[k] += copy.deepcopy(v)
+                self[k] = self[k] + copy.deepcopy(v)
             elif isinstance(v, Batch):
-                self.__dict__[k].cat_(v)
+                self[k].cat_(v)
             else:
                 s = 'No support for method "cat" with type '\
                     f'{type(v)} in class Batch.'
                 raise TypeError(s)

-    @staticmethod
-    def cat(batches: List['Batch']) -> None:
+    @classmethod
+    def cat(cls, batches: List['Batch']) -> 'Batch':
         """Concatenate a :class:`~tianshou.data.Batch` object into a
         single new batch.
         """
         assert isinstance(batches, (tuple, list)), \
             'Only list of Batch instances is allowed to be '\
             'concatenated out-of-place!'
-        batch = Batch()
+        batch = cls()
         for batch_ in batches:
             batch.cat_(batch_)
         return batch

-    @staticmethod
-    def stack(batches: List['Batch']):
+    @classmethod
+    def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch':
         """Stack a :class:`~tianshou.data.Batch` object into a
         single new batch.
         """
         assert isinstance(batches, (tuple, list)), \
             'Only list of Batch instances is allowed to be '\
             'stacked out-of-place!'
-        return Batch(np.array([batch.__dict__ for batch in batches]))
+        if axis == 0:
+            return cls(batches)
+        else:
+            batch = Batch()
+            for k, v in zip(batches[0].keys(),
+                            zip(*[e.values() for e in batches])):
+                if isinstance(v[0], (np.generic, np.ndarray, list)):
+                    batch[k] = np.stack(v, axis)
+                elif isinstance(v[0], torch.Tensor):
+                    batch[k] = torch.stack(v, axis)
+                elif isinstance(v[0], Batch):
+                    batch[k] = Batch.stack(v, axis)
+                else:
+                    s = 'No support for method "stack" with type '\
+                        f'{type(v[0])} in class Batch and axis != 0.'
+                    raise TypeError(s)
+            return batch

     def __len__(self) -> int:
         """Return len(self)."""
         r = []
-        for v in self.__dict__.values():
+        for v in self.values():
             if isinstance(v, Batch) and v.size == 0:
                 continue
             elif isinstance(v, list) and len(v) == 0:
@@ -373,11 +431,11 @@ class Batch:
     @property
     def size(self) -> int:
         """Return self.size."""
-        if len(self.__dict__) == 0:
+        if len(self.keys()) == 0:
             return 0
         else:
             r = []
-            for v in self.__dict__.values():
+            for v in self.values():
                 if isinstance(v, Batch):
                     r.append(v.size)
                 elif hasattr(v, '__len__') and (not isinstance(
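
The batch.py changes above route all data keys through a private _data dict (set up in __new__) instead of the instance __dict__. A standalone sketch of why this helps inheritance (simplified, not tianshou code; MiniBatch and MiniBuffer are hypothetical names):

    class MiniBatch:
        def __new__(cls, *args, **kwargs):
            obj = super().__new__(cls)
            obj.__dict__['_data'] = {}          # data keys live here, not in __dict__
            return obj

        def __getattr__(self, key):
            # only called when normal attribute lookup fails, i.e. for data keys
            try:
                return self.__dict__['_data'][key]
            except KeyError:
                raise AttributeError(key)

        def __setattr__(self, key, value):
            if key in self.__dict__:            # genuine instance attribute
                self.__dict__[key] = value
            else:                               # everything else is batch data
                self._data[key] = value

    class MiniBuffer(MiniBatch):
        def __init__(self, size):
            self.__dict__['_maxsize'] = size    # private attribute, not a data key

    buf = MiniBuffer(8)
    buf.obs = [1, 2, 3]                         # routed into _data
    assert buf._maxsize == 8 and buf.obs == [1, 2, 3]
    assert 'obs' in buf._data and 'obs' not in buf.__dict__

Because the data no longer lives in __dict__, a subclass such as ReplayBuffer can keep its bookkeeping attributes (_maxsize, _index, ...) as ordinary attributes without them being picked up as batch keys, which is the inheritance problem mentioned in the commit message.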

View File

@@ -1,11 +1,11 @@
-import pprint
 import numpy as np
+from numbers import Number
 from typing import Any, Tuple, Union, Optional
-from tianshou.data.batch import Batch
+from .batch import Batch


-class ReplayBuffer:
+class ReplayBuffer(Batch):
     """:class:`~tianshou.data.ReplayBuffer` stores data generated from
     interaction between the policy and environment. It stores basically 7 types
     of data, as mentioned in :class:`~tianshou.data.Batch`, based on
@@ -96,81 +96,47 @@ class ReplayBuffer:
     def __init__(self, size: int, stack_num: Optional[int] = 0,
                  ignore_obs_next: bool = False, **kwargs) -> None:
-        self._maxsize = size
-        self._stack = stack_num
-        self._save_s_ = not ignore_obs_next
-        self._meta = {}
+        super().__init__()
+        self.__dict__['_maxsize'] = size
+        self.__dict__['_stack'] = stack_num
+        self.__dict__['_save_s_'] = not ignore_obs_next
+        self.__dict__['_index'] = 0
+        self.__dict__['_size'] = 0
         self.reset()

     def __len__(self) -> int:
         """Return len(self)."""
         return self._size

-    def __repr__(self) -> str:
-        """Return str(self)."""
-        s = self.__class__.__name__ + '(\n'
-        flag = False
-        for k in sorted(list(self.__dict__) + list(self._meta)):
-            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
-                                k in self._meta):
-                rpl = '\n' + ' ' * (6 + len(k))
-                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
-                s += f'    {k}: {obj},\n'
-                flag = True
-        if flag:
-            s += ')'
-        else:
-            s = self.__class__.__name__ + '()'
-        return s
-
-    def __getattr__(self, key: str) -> Union[Batch, np.ndarray]:
-        """Return self.key"""
-        if key not in self._meta:
-            if key not in self.__dict__:
-                raise AttributeError(key)
-            return self.__dict__[key]
-        d = {}
-        for k_ in self._meta[key]:
-            k__ = '_' + key + '@' + k_
-            if k__ in self.__dict__:
-                d[k_] = self.__dict__[k__]
-            else:
-                d[k_] = self.__getattr__(k__)
-        return Batch(d)
-
     def _add_to_buffer(self, name: str, inst: Any) -> None:
-        if inst is None:
-            if getattr(self, name, None) is None:
-                self.__dict__[name] = None
-            return
-        if name in self._meta:
-            for k in inst.keys():
-                self._add_to_buffer('_' + name + '@' + k, inst[k])
-            return
-        if self.__dict__.get(name, None) is None:
-            if isinstance(inst, np.ndarray):
-                self.__dict__[name] = np.zeros(
-                    (self._maxsize, *inst.shape), dtype=inst.dtype)
-            elif isinstance(inst, (dict, Batch)):
-                if self._meta.get(name, None) is None:
-                    self._meta[name] = list(inst.keys())
-                for k in inst.keys():
-                    k_ = '_' + name + '@' + k
-                    self._add_to_buffer(k_, inst[k])
-            elif np.isscalar(inst):
-                self.__dict__[name] = np.zeros(
-                    (self._maxsize,), dtype=np.asarray(inst).dtype)
-            else:  # fall back to np.object
-                self.__dict__[name] = np.array(
-                    [None for _ in range(self._maxsize)])
+        def _create_value(inst: Any) -> Union['Batch', np.ndarray]:
+            if isinstance(inst, np.ndarray):
+                return np.zeros(
+                    (self._maxsize, *inst.shape), dtype=inst.dtype)
+            elif isinstance(inst, (dict, Batch)):
+                return Batch([Batch(inst) for _ in range(self._maxsize)])
+            elif isinstance(inst, (np.generic, Number)):
+                return np.zeros(
+                    (self._maxsize,), dtype=np.asarray(inst).dtype)
+            else:  # fall back to np.object
+                return np.array([None for _ in range(self._maxsize)])
+
+        if inst is None:
+            inst = Batch()
+        if name not in self.keys():
+            self[name] = _create_value(inst)
         if isinstance(inst, np.ndarray) and \
-                self.__dict__[name].shape[1:] != inst.shape:
+                self[name].shape[1:] != inst.shape:
             raise ValueError(
                 "Cannot add data to a buffer with different shape, "
-                f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
-                f"given shape: {inst.shape}.")
-        if name not in self._meta:
-            self.__dict__[name][self._index] = inst
+                f"key: {name}, expect shape: {self[name].shape[1:]}"
+                f", given shape: {inst.shape}.")
+        if isinstance(self[name], Batch):
+            field_keys = self[name].keys()
+            for key, val in inst.items():
+                if key not in field_keys:
+                    self[name][key] = _create_value(val)
+        self[name][self._index] = inst

     def update(self, buffer: 'ReplayBuffer') -> None:
         """Move the data from the given buffer to self."""
@@ -209,7 +175,8 @@ class ReplayBuffer:
     def reset(self) -> None:
         """Clear all the data in replay buffer."""
-        self._index = self._size = 0
+        self._index = 0
+        self._size = 0

     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
         """Get a random sample from buffer with size equal to batch_size. \
@@ -226,7 +193,7 @@ class ReplayBuffer:
         ])
         return self[indice], indice

-    def get(self, indice: Union[slice, np.ndarray], key: str,
+    def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
             stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
         where s is self.key, t is indice. The stack_num (here equals to 4) is
@@ -234,20 +201,16 @@ class ReplayBuffer:
         """
         if stack_num is None:
             stack_num = self._stack
-        if not isinstance(indice, np.ndarray):
-            if np.isscalar(indice):
-                indice = np.array(indice)
-            elif isinstance(indice, slice):
-                indice = np.arange(
-                    0 if indice.start is None
-                    else self._size - indice.start if indice.start < 0
-                    else indice.start,
-                    self._size if indice.stop is None
-                    else self._size - indice.stop if indice.stop < 0
-                    else indice.stop,
-                    1 if indice.step is None else indice.step)
-        else:
-            indice = np.array(indice)
+        if isinstance(indice, slice):
+            indice = np.arange(
+                0 if indice.start is None
+                else self._size - indice.start if indice.start < 0
+                else indice.start,
+                self._size if indice.stop is None
+                else self._size - indice.stop if indice.stop < 0
+                else indice.stop,
+                1 if indice.step is None else indice.step)
+        indice = np.array(indice, copy=True)
         # set last frame done to True
         last_index = (self._index - 1 + self._size) % self._size
         last_done, self.done[last_index] = self.done[last_index], True
@@ -257,49 +220,51 @@ class ReplayBuffer:
             key = 'obs'
         if stack_num == 0:
             self.done[last_index] = last_done
-            if key in self._meta:
-                return {k: self.__dict__['_' + key + '@' + k][indice]
-                        for k in self._meta[key]}
-            else:
-                return self.__dict__[key][indice]
-        if key in self._meta:
-            many_keys = self._meta[key]
-            stack = {k: [] for k in self._meta[key]}
-        else:
-            stack = []
-            many_keys = None
-        for _ in range(stack_num):
-            if many_keys is not None:
-                for k_ in many_keys:
-                    k__ = '_' + key + '@' + k_
-                    stack[k_] = [self.__dict__[k__][indice]] + stack[k_]
-            else:
-                stack = [self.__dict__[key][indice]] + stack
-            pre_indice = indice - 1
-            pre_indice[pre_indice == -1] = self._size - 1
-            indice = pre_indice + self.done[pre_indice].astype(np.int)
-            indice[indice == self._size] = 0
-        self.done[last_index] = last_done
-        if many_keys is not None:
-            for k in stack:
-                stack[k] = np.stack(stack[k], axis=1)
-            stack = Batch(stack)
-        else:
-            stack = np.stack(stack, axis=1)
-        return stack
+            val = self[key]
+            if isinstance(val, Batch) and val.size == 0:
+                return val
+            else:
+                if isinstance(indice, (int, np.integer)) or \
+                    (isinstance(indice, np.ndarray) and
+                        indice.ndim == 0) or not isinstance(val, list):
+                    return val[indice]
+                else:
+                    return [val[i] for i in indice]
+        else:
+            val = self[key]
+            if not isinstance(val, Batch) or val.size > 0:
+                stack = []
+                for _ in range(stack_num):
+                    stack = [val[indice]] + stack
+                    pre_indice = np.asarray(indice - 1)
+                    pre_indice[pre_indice == -1] = self._size - 1
+                    indice = np.asarray(
+                        pre_indice + self.done[pre_indice].astype(np.int))
+                    indice[indice == self._size] = 0
+                if isinstance(stack[0], Batch):
+                    stack = Batch.stack(stack, axis=indice.ndim)
+                else:
+                    stack = np.stack(stack, axis=indice.ndim)
+            else:
+                stack = Batch()
+            self.done[last_index] = last_done
+            return stack

-    def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
+    def __getitem__(self, index: Union[
+            slice, int, np.integer, np.ndarray]) -> Batch:
         """Return a data batch: self[index]. If stack_num is set to be > 0,
         return the stacked obs and obs_next with shape [batch, len, ...].
         """
+        if isinstance(index, str):
+            return getattr(self, index)
         return Batch(
             obs=self.get(index, 'obs'),
-            act=self.act[index],
+            act=self.get(index, 'act', stack_num=0),
             # act_=self.get(index, 'act'),  # stacked action, for RNN
-            rew=self.rew[index],
-            done=self.done[index],
+            rew=self.get(index, 'rew', stack_num=0),
+            done=self.get(index, 'done', stack_num=0),
             obs_next=self.get(index, 'obs_next'),
-            info=self.get(index, 'info'),
+            info=self.get(index, 'info', stack_num=0),
             policy=self.get(index, 'policy'),
         )
@@ -323,15 +288,15 @@ class ListReplayBuffer(ReplayBuffer):
                        inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
         if inst is None:
             return
-        if self.__dict__.get(name, None) is None:
-            self.__dict__[name] = []
-        self.__dict__[name].append(inst)
+        if self._data.get(name, None) is None:
+            self._data[name] = []
+        self._data[name].append(inst)

     def reset(self) -> None:
         self._index = self._size = 0
-        for k in list(self.__dict__):
-            if isinstance(self.__dict__[k], list):
-                self.__dict__[k] = []
+        for k in list(self._data):
+            if isinstance(self._data[k], list):
+                self._data[k] = []


 class PrioritizedReplayBuffer(ReplayBuffer):
@@ -449,16 +414,18 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             - self.weight[indice].sum()
         self.weight[indice] = np.power(np.abs(new_weight), self._alpha)

-    def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
+    def __getitem__(self, index: Union[str, slice, np.ndarray]) -> Batch:
+        if isinstance(index, str):
+            return getattr(self, index)
         return Batch(
             obs=self.get(index, 'obs'),
-            act=self.act[index],
+            act=self.get(index, 'act', stack_num=0),
             # act_=self.get(index, 'act'),  # stacked action, for RNN
-            rew=self.rew[index],
-            done=self.done[index],
+            rew=self.get(index, 'rew', stack_num=0),
+            done=self.get(index, 'done', stack_num=0),
             obs_next=self.get(index, 'obs_next'),
             info=self.get(index, 'info'),
-            weight=self.weight[index],
+            weight=self.get(index, 'weight', stack_num=0),
             policy=self.get(index, 'policy'),
         )
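
A hedged sketch of the stack_num behaviour that get()/__getitem__ above implement: with stack_num > 0, sampled observations come back with shape [batch, len, ...]. The add() keyword names (obs/act/rew/done) are assumed from the stored keys and are not part of this diff:

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=20, stack_num=4)
    for i in range(5):
        buf.add(obs=np.array([i]), act=0, rew=0.0, done=(i == 4))
    batch, indice = buf.sample(batch_size=2)
    print(batch.obs.shape)  # expected: (2, 4, 1), i.e. [batch, stack_num, obs_dim]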