Use lower-level API to reduce overhead. (#97)

* Use lower-level API to reduce overhead.

* Further improvements.

* Buffer _add_to_buffer improvement.

* Do not use _data field to store Batch data to avoid overhead. Add back _meta field in Buffer.

* Restore metadata attribute to store batch in Buffer.

* Move out nested methods.

* Use try/except instead of explicit checks for efficiency.

* Remove unused branches for efficiency.

* Use np.array over list when possible for efficiency.

* Final performance improvement.

* Add unit tests for Batch size method.

* Add missing stack unit tests.

* Enforce Buffer initialization to zero.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-06-26 12:37:50 +02:00 committed by GitHub
parent 5ac9f9b144
commit 70aa7bf93e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 202 additions and 207 deletions

View File

@ -39,12 +39,17 @@ def test_batch():
'c': np.zeros(1), 'c': np.zeros(1),
'd': Batch(e=np.array(3.0))}]) 'd': Batch(e=np.array(3.0))}])
assert len(batch2) == 1 assert len(batch2) == 1
assert Batch().size == 0
assert batch2.size == 1
with pytest.raises(IndexError): with pytest.raises(IndexError):
batch2[-2] batch2[-2]
with pytest.raises(IndexError): with pytest.raises(IndexError):
batch2[1] batch2[1]
assert batch2[0].size == 1
with pytest.raises(TypeError): with pytest.raises(TypeError):
batch2[0][0] batch2[0][0]
with pytest.raises(TypeError):
len(batch2[0])
assert isinstance(batch2[0].a.c, np.ndarray) assert isinstance(batch2[0].a.c, np.ndarray)
assert isinstance(batch2[0].a.b, np.float64) assert isinstance(batch2[0].a.b, np.float64)
assert isinstance(batch2[0].a.d.e, np.float64) assert isinstance(batch2[0].a.d.e, np.float64)
@ -72,7 +77,7 @@ def test_batch():
assert batch3.a.d.e[0] == 4.0 assert batch3.a.d.e[0] == 4.0
batch3.a.d[0] = Batch(f=5.0) batch3.a.d[0] = Batch(f=5.0)
assert batch3.a.d.f[0] == 5.0 assert batch3.a.d.f[0] == 5.0
with pytest.raises(ValueError): with pytest.raises(KeyError):
batch3.a.d[0] = Batch(f=5.0, g=0.0) batch3.a.d[0] = Batch(f=5.0, g=0.0)
@ -112,10 +117,15 @@ def test_batch_cat_and_stack():
b12_stack = Batch.stack((b1, b2)) b12_stack = Batch.stack((b1, b2))
assert isinstance(b12_stack.a.d.e, np.ndarray) assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2 assert b12_stack.a.d.e.ndim == 2
b3 = Batch(a=np.zeros((3, 4))) b3 = Batch(a=np.zeros((3, 4)),
b4 = Batch(a=np.ones((3, 4))) b=torch.ones((2, 5)),
c=Batch(d=[[1], [2]]))
b4 = Batch(a=np.ones((3, 4)),
b=torch.ones((2, 5)),
c=Batch(d=[[0], [3]]))
b34_stack = Batch.stack((b3, b4), axis=1) b34_stack = Batch.stack((b3, b4), axis=1)
assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
def test_batch_over_batch_to_torch(): def test_batch_over_batch_to_torch():

View File

@ -27,6 +27,16 @@ def test_replaybuffer(size=10, bufsize=20):
assert len(buf) == len(buf2) assert len(buf) == len(buf2)
assert buf2[0].obs == buf[5].obs assert buf2[0].obs == buf[5].obs
assert buf2[-1].obs == buf[4].obs assert buf2[-1].obs == buf[4].obs
b = ReplayBuffer(size=10)
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
assert b.obs[0] == 1
assert b.done[0] == 'str'
assert np.all(b.obs[1:] == 0)
assert np.all(b.done[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
assert np.all(np.isnan(b.info.b.c[1:]))
def test_ignore_obs_next(size=10): def test_ignore_obs_next(size=10):

View File

@ -13,6 +13,35 @@ warnings.filterwarnings(
"ignore", message="pickle support for Storage will be removed in 1.5.") "ignore", message="pickle support for Storage will be removed in 1.5.")
def _is_batch_set(data: Any) -> bool:
if isinstance(data, (list, tuple)):
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
return True
elif isinstance(data, np.ndarray):
if isinstance(data.item(0), (dict, Batch)):
return True
return False
def _valid_bounds(length: int, index: Union[
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
if isinstance(index, (int, np.integer)):
return -length <= index and index < length
elif isinstance(index, (list, np.ndarray)):
return _valid_bounds(length, np.min(index)) and \
_valid_bounds(length, np.max(index))
elif isinstance(index, slice):
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
class Batch: class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a structure to pass any kind of data to other methods, for example, a
@ -75,46 +104,30 @@ class Batch:
[11 22] [6 6] [11 22] [6 6]
""" """
def __new__(cls, *args, **kwargs) -> 'Batch':
self = super().__new__(cls)
self.__dict__['_data'] = {}
return self
def __init__(self, def __init__(self,
batch_dict: Optional[Union[ batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']], dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None, List[Union[dict, 'Batch']], np.ndarray]] = None,
**kwargs) -> None: **kwargs) -> None:
def _is_batch_set(data: Any) -> bool:
if isinstance(data, (list, tuple)):
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
return True
elif isinstance(data, np.ndarray):
if isinstance(data.item(0), (dict, Batch)):
return True
return False
if isinstance(batch_dict, np.ndarray) and batch_dict.ndim == 0:
batch_dict = batch_dict[()]
if _is_batch_set(batch_dict): if _is_batch_set(batch_dict):
for k, v in zip(batch_dict[0].keys(), for k, v in zip(batch_dict[0].keys(),
zip(*[e.values() for e in batch_dict])): zip(*[e.values() for e in batch_dict])):
if isinstance(v[0], dict) or _is_batch_set(v[0]): if isinstance(v[0], dict) or _is_batch_set(v[0]):
self[k] = Batch(v) self.__dict__[k] = Batch(v)
elif isinstance(v[0], (np.generic, np.ndarray)): elif isinstance(v[0], (np.generic, np.ndarray)):
self[k] = np.stack(v, axis=0) self.__dict__[k] = np.stack(v, axis=0)
elif isinstance(v[0], torch.Tensor): elif isinstance(v[0], torch.Tensor):
self[k] = torch.stack(v, dim=0) self.__dict__[k] = torch.stack(v, dim=0)
elif isinstance(v[0], Batch): elif isinstance(v[0], Batch):
self[k] = Batch.stack(v) self.__dict__[k] = Batch.stack(v)
else: else:
self[k] = np.array(v) # fall back to np.object self.__dict__[k] = np.array(v)
elif isinstance(batch_dict, (dict, Batch)): elif isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items(): for k, v in batch_dict.items():
if isinstance(v, dict) or _is_batch_set(v): if isinstance(v, dict) or _is_batch_set(v):
self[k] = Batch(v) self.__dict__[k] = Batch(v)
else: else:
self[k] = v self.__dict__[k] = v
if len(kwargs) > 0: if len(kwargs) > 0:
self.__init__(kwargs) self.__init__(kwargs)
@ -123,8 +136,7 @@ class Batch:
for both efficiency and simplicity. for both efficiency and simplicity.
""" """
state = {} state = {}
for k in self.keys(): for k, v in self.items():
v = self[k]
if isinstance(v, Batch): if isinstance(v, Batch):
v = v.__getstate__() v = v.__getstate__()
state[k] = v state[k] = v
@ -140,26 +152,8 @@ class Batch:
def __getitem__(self, index: Union[ def __getitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch': str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
"""Return self[index].""" """Return self[index]."""
def _valid_bounds(length: int, index: Union[
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
if isinstance(index, (int, np.integer)):
return -length <= index and index < length
elif isinstance(index, (list, np.ndarray)):
return _valid_bounds(length, np.min(index)) and \
_valid_bounds(length, np.max(index))
elif isinstance(index, slice):
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
if isinstance(index, str): if isinstance(index, str):
return getattr(self, index) return self.__dict__[index]
if not _valid_bounds(len(self), index): if not _valid_bounds(len(self), index):
raise IndexError( raise IndexError(
@ -167,61 +161,57 @@ class Batch:
else: else:
b = Batch() b = Batch()
for k, v in self.items(): for k, v in self.items():
if isinstance(v, Batch) and v.size == 0: if isinstance(v, Batch) and len(v.__dict__) == 0:
b[k] = Batch() b.__dict__[k] = Batch()
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
if isinstance(index, (int, np.integer)) or \
(isinstance(index, np.ndarray) and
index.ndim == 0) or not isinstance(v, list):
b[k] = v[index]
else: else:
b[k] = [v[i] for i in index] b.__dict__[k] = v[index]
return b return b
def __setitem__(self, index: Union[ def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]], str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None: value: Any) -> None:
if isinstance(index, str): if isinstance(index, str):
return setattr(self, index, value) self.__dict__[index] = value
if value is None: return
value = Batch()
if not isinstance(value, (dict, Batch)): if not isinstance(value, (dict, Batch)):
raise TypeError("Batch does not supported value type " raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.") f"{type(value)} for item assignment.")
if not set(value.keys()).issubset(self.keys()): if not set(value.keys()).issubset(self.__dict__.keys()):
raise ValueError( raise KeyError(
"Creating keys is not supported by item assignment.") "Creating keys is not supported by item assignment.")
for key in self.keys(): for key, val in self.items():
if isinstance(self[key], Batch): try:
default = Batch() self.__dict__[key][index] = value[key]
elif isinstance(self[key], np.ndarray) and \ except KeyError:
self[key].dtype == np.integer: if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
elif isinstance(val, np.ndarray) and \
val.dtype == np.integer:
# Fallback for np.array of integer, # Fallback for np.array of integer,
# since neither None or nan is supported. # since neither None or nan is supported.
default = 0 self.__dict__[key][index] = 0
else: else:
default = None self.__dict__[key][index] = None
self[key][index] = value.get(key, default)
def __iadd__(self, val: Union['Batch', Number]): def __iadd__(self, val: Union['Batch', Number]):
if isinstance(val, Batch): if isinstance(val, Batch):
for k, r, v in zip(self.keys(), self.values(), val.values()): for (k, r), v in zip(self.__dict__.items(),
val.__dict__.values()):
if r is None: if r is None:
self[k] = r continue
elif isinstance(r, list): elif isinstance(r, list):
self[k] = [r_ + v_ for r_, v_ in zip(r, v)] self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
else: else:
self[k] = r + v self.__dict__[k] += v
return self return self
elif isinstance(val, Number): elif isinstance(val, Number):
for k, r in zip(self.keys(), self.values()): for k, r in self.items():
if r is None: if r is None:
self[k] = r continue
elif isinstance(r, list): elif isinstance(r, list):
self[k] = [r_ + val for r_ in r] self.__dict__[k] = [r_ + val for r_ in r]
else: else:
self[k] = r + val self.__dict__[k] += val
return self return self
else: else:
raise TypeError("Only addition of Batch or number is supported.") raise TypeError("Only addition of Batch or number is supported.")
@ -229,37 +219,25 @@ class Batch:
def __add__(self, val: Union['Batch', Number]): def __add__(self, val: Union['Batch', Number]):
return copy.deepcopy(self).__iadd__(val) return copy.deepcopy(self).__iadd__(val)
def __mul__(self, val: Number): def __imul__(self, val: Number):
assert isinstance(val, Number), \ assert isinstance(val, Number), \
"Only multiplication by a number is supported." "Only multiplication by a number is supported."
result = self.__class__() for k in self.__dict__.keys():
for k, r in zip(self.keys(), self.values()): self.__dict__[k] *= val
result[k] = r * val return self
return result
def __truediv__(self, val: Number): def __mul__(self, val: Number):
return copy.deepcopy(self).__imul__(val)
def __itruediv__(self, val: Number):
assert isinstance(val, Number), \ assert isinstance(val, Number), \
"Only division by a number is supported." "Only division by a number is supported."
result = self.__class__() for k in self.__dict__.keys():
for k, r in zip(self.keys(), self.values()): self.__dict__[k] /= val
result[k] = r / val return self
return result
def __getattr__(self, key: str) -> Union['Batch', Any]: def __truediv__(self, val: Number):
"""Return self.key""" return copy.deepcopy(self).__itruediv__(val)
if key in self.__dict__.keys():
return self.__dict__[key]
elif key in self._data.keys():
return self._data[key]
raise AttributeError(key)
def __setattr__(self, key, value):
if key in self._data.keys():
self._data[key] = value
elif key in self.__dict__.keys():
self.__dict__[key] = value
else:
self._data[key] = value
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return str(self).""" """Return str(self)."""
@ -278,21 +256,19 @@ class Batch:
def keys(self) -> List[str]: def keys(self) -> List[str]:
"""Return self.keys().""" """Return self.keys()."""
return self._data.keys() return self.__dict__.keys()
def values(self) -> List[Any]: def values(self) -> List[Any]:
"""Return self.values().""" """Return self.values()."""
return self._data.values() return self.__dict__.values()
def items(self) -> Any: def items(self) -> List[Tuple[str, Any]]:
"""Return self.items().""" """Return self.items()."""
return self._data.items() return self.__dict__.items()
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
"""Return self[k] if k in self else d. d defaults to None.""" """Return self[k] if k in self else d. d defaults to None."""
if k in self.keys(): return self.__dict__.get(k, d)
return self[k]
return d
def to_numpy(self) -> None: def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray. This is an in-place """Change all torch.Tensor to numpy.ndarray. This is an in-place
@ -300,7 +276,7 @@ class Batch:
""" """
for k, v in self.items(): for k, v in self.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
self[k] = v.detach().cpu().numpy() self.__dict__[k] = v.detach().cpu().numpy()
elif isinstance(v, Batch): elif isinstance(v, Batch):
v.to_numpy() v.to_numpy()
@ -319,7 +295,7 @@ class Batch:
v = torch.from_numpy(v).to(device) v = torch.from_numpy(v).to(device)
if dtype is not None: if dtype is not None:
v = v.type(dtype) v = v.type(dtype)
self[k] = v self.__dict__[k] = v
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype: if dtype is not None and v.dtype != dtype:
must_update_tensor = True must_update_tensor = True
@ -333,7 +309,7 @@ class Batch:
if must_update_tensor: if must_update_tensor:
if dtype is not None: if dtype is not None:
v = v.type(dtype) v = v.type(dtype)
self[k] = v.to(device) self.__dict__[k] = v.to(device)
elif isinstance(v, Batch): elif isinstance(v, Batch):
v.to_torch(dtype, device) v.to_torch(dtype, device)
@ -351,16 +327,16 @@ class Batch:
for k, v in batch.items(): for k, v in batch.items():
if v is None: if v is None:
continue continue
if not hasattr(self, k) or self[k] is None: if not hasattr(self, k) or self.__dict__[k] is None:
self[k] = copy.deepcopy(v) self.__dict__[k] = copy.deepcopy(v)
elif isinstance(v, np.ndarray) and v.ndim > 0: elif isinstance(v, np.ndarray) and v.ndim > 0:
self[k] = np.concatenate([self[k], v]) self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor): elif isinstance(v, torch.Tensor):
self[k] = torch.cat([self[k], v]) self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, list): elif isinstance(v, list):
self[k] = self[k] + copy.deepcopy(v) self.__dict__[k] += copy.deepcopy(v)
elif isinstance(v, Batch): elif isinstance(v, Batch):
self[k].cat_(v) self.__dict__[k].cat_(v)
else: else:
s = 'No support for method "cat" with type '\ s = 'No support for method "cat" with type '\
f'{type(v)} in class Batch.' f'{type(v)} in class Batch.'
@ -394,11 +370,11 @@ class Batch:
for k, v in zip(batches[0].keys(), for k, v in zip(batches[0].keys(),
zip(*[e.values() for e in batches])): zip(*[e.values() for e in batches])):
if isinstance(v[0], (np.generic, np.ndarray, list)): if isinstance(v[0], (np.generic, np.ndarray, list)):
batch[k] = np.stack(v, axis) batch.__dict__[k] = np.stack(v, axis)
elif isinstance(v[0], torch.Tensor): elif isinstance(v[0], torch.Tensor):
batch[k] = torch.stack(v, axis) batch.__dict__[k] = torch.stack(v, axis)
elif isinstance(v[0], Batch): elif isinstance(v[0], Batch):
batch[k] = Batch.stack(v, axis) batch.__dict__[k] = Batch.stack(v, axis)
else: else:
s = 'No support for method "stack" with type '\ s = 'No support for method "stack" with type '\
f'{type(v[0])} in class Batch and axis != 0.' f'{type(v[0])} in class Batch and axis != 0.'
@ -408,10 +384,8 @@ class Batch:
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self).""" """Return len(self)."""
r = [] r = []
for v in self.values(): for v in self.__dict__.values():
if isinstance(v, Batch) and v.size == 0: if isinstance(v, Batch) and len(v.__dict__) == 0:
continue
elif isinstance(v, list) and len(v) == 0:
continue continue
elif hasattr(v, '__len__') and (not isinstance( elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0): v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
@ -425,11 +399,11 @@ class Batch:
@property @property
def size(self) -> int: def size(self) -> int:
"""Return self.size.""" """Return self.size."""
if len(self.keys()) == 0: if len(self.__dict__.keys()) == 0:
return 0 return 0
else: else:
r = [] r = []
for v in self.values(): for v in self.__dict__.values():
if isinstance(v, Batch): if isinstance(v, Batch):
r.append(v.size) r.append(v.size)
elif hasattr(v, '__len__') and (not isinstance( elif hasattr(v, '__len__') and (not isinstance(

View File

@ -5,7 +5,23 @@ from typing import Any, Tuple, Union, Optional
from .batch import Batch from .batch import Batch
class ReplayBuffer(Batch): def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray):
return np.full(shape=(size, *inst.shape),
fill_value=None if inst.dtype == np.inexact else 0,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
else: # fall back to np.object
return np.array([None for _ in range(size)])
class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from """:class:`~tianshou.data.ReplayBuffer` stores data generated from
interaction between the policy and environment. It stores basically 7 types interaction between the policy and environment. It stores basically 7 types
of data, as mentioned in :class:`~tianshou.data.Batch`, based on of data, as mentioned in :class:`~tianshou.data.Batch`, based on
@ -93,50 +109,46 @@ class ReplayBuffer(Batch):
[ 7. 7. 7. 8.] [ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]] [ 7. 7. 8. 9.]]
""" """
def __init__(self, size: int, stack_num: Optional[int] = 0, def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: bool = False, **kwargs) -> None: ignore_obs_next: bool = False, **kwargs) -> None:
super().__init__() super().__init__()
self.__dict__['_maxsize'] = size self._maxsize = size
self.__dict__['_stack'] = stack_num self._stack = stack_num
self.__dict__['_save_s_'] = not ignore_obs_next self._save_s_ = not ignore_obs_next
self.__dict__['_index'] = 0 self._index = 0
self.__dict__['_size'] = 0 self._size = 0
self._meta = Batch()
self.reset() self.reset()
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self).""" """Return len(self)."""
return self._size return self._size
def _add_to_buffer(self, name: str, inst: Any) -> None: def __repr__(self) -> str:
def _create_value(inst: Any) -> Union['Batch', np.ndarray]: return self.__class__.__name__ + self._meta.__repr__()[5:]
if isinstance(inst, np.ndarray):
return np.zeros(
(self._maxsize, *inst.shape), dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
return Batch([Batch(inst) for _ in range(self._maxsize)])
elif isinstance(inst, (np.generic, Number)):
return np.zeros(
(self._maxsize,), dtype=np.asarray(inst).dtype)
else: # fall back to np.object
return np.array([None for _ in range(self._maxsize)])
if inst is None: def __getattr__(self, key: str) -> Union['Batch', Any]:
inst = Batch() """Return self.key"""
if name not in self.keys(): return self._meta.__dict__[key]
self[name] = _create_value(inst)
def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
value = self._meta.__dict__[name]
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
if isinstance(inst, np.ndarray) and \ if isinstance(inst, np.ndarray) and \
self[name].shape[1:] != inst.shape: value.shape[1:] != inst.shape:
raise ValueError( raise ValueError(
"Cannot add data to a buffer with different shape, " "Cannot add data to a buffer with different shape, key: "
f"key: {name}, expect shape: {self[name].shape[1:]}" f"{name}, expect shape: {value.shape[1:]}"
f", given shape: {inst.shape}.") f", given shape: {inst.shape}.")
if isinstance(self[name], Batch): try:
field_keys = self[name].keys() value[self._index] = inst
for key, val in inst.items(): except KeyError:
if key not in field_keys: for key in set(inst.keys()).difference(value.__dict__.keys()):
self[name][key] = _create_value(val) value.__dict__[key] = _create_value(inst[key], self._maxsize)
self[name][self._index] = inst value[self._index] = inst
def update(self, buffer: 'ReplayBuffer') -> None: def update(self, buffer: 'ReplayBuffer') -> None:
"""Move the data from the given buffer to self.""" """Move the data from the given buffer to self."""
@ -148,11 +160,11 @@ class ReplayBuffer(Batch):
break break
def add(self, def add(self,
obs: Union[dict, np.ndarray], obs: Union[dict, Batch, np.ndarray],
act: Union[np.ndarray, float], act: Union[np.ndarray, float],
rew: float, rew: float,
done: bool, done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None, obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
info: dict = {}, info: dict = {},
policy: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {},
**kwargs) -> None: **kwargs) -> None:
@ -164,6 +176,8 @@ class ReplayBuffer(Batch):
self._add_to_buffer('rew', rew) self._add_to_buffer('rew', rew)
self._add_to_buffer('done', done) self._add_to_buffer('done', done)
if self._save_s_: if self._save_s_:
if obs_next is None:
obs_next = Batch()
self._add_to_buffer('obs_next', obs_next) self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info) self._add_to_buffer('info', info)
self._add_to_buffer('policy', policy) self._add_to_buffer('policy', policy)
@ -210,6 +224,7 @@ class ReplayBuffer(Batch):
else self._size - indice.stop if indice.stop < 0 else self._size - indice.stop if indice.stop < 0
else indice.stop, else indice.stop,
1 if indice.step is None else indice.step) 1 if indice.step is None else indice.step)
else:
indice = np.array(indice, copy=True) indice = np.array(indice, copy=True)
# set last frame done to True # set last frame done to True
last_index = (self._index - 1 + self._size) % self._size last_index = (self._index - 1 + self._size) % self._size
@ -218,21 +233,9 @@ class ReplayBuffer(Batch):
indice += 1 - self.done[indice].astype(np.int) indice += 1 - self.done[indice].astype(np.int)
indice[indice == self._size] = 0 indice[indice == self._size] = 0
key = 'obs' key = 'obs'
if stack_num == 0: val = self._meta.__dict__[key]
self.done[last_index] = last_done try:
val = self[key] if stack_num > 0:
if isinstance(val, Batch) and val.size == 0:
return val
else:
if isinstance(indice, (int, np.integer)) or \
(isinstance(indice, np.ndarray) and
indice.ndim == 0) or not isinstance(val, list):
return val[indice]
else:
return [val[i] for i in indice]
else:
val = self[key]
if not isinstance(val, Batch) or val.size > 0:
stack = [] stack = []
for _ in range(stack_num): for _ in range(stack_num):
stack = [val[indice]] + stack stack = [val[indice]] + stack
@ -241,11 +244,13 @@ class ReplayBuffer(Batch):
indice = np.asarray( indice = np.asarray(
pre_indice + self.done[pre_indice].astype(np.int)) pre_indice + self.done[pre_indice].astype(np.int))
indice[indice == self._size] = 0 indice[indice == self._size] = 0
if isinstance(stack[0], Batch): if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim) stack = Batch.stack(stack, axis=indice.ndim)
else: else:
stack = np.stack(stack, axis=indice.ndim) stack = np.stack(stack, axis=indice.ndim)
else: else:
stack = val[indice]
except TypeError:
stack = Batch() stack = Batch()
self.done[last_index] = last_done self.done[last_index] = last_done
return stack return stack
@ -255,17 +260,15 @@ class ReplayBuffer(Batch):
"""Return a data batch: self[index]. If stack_num is set to be > 0, """Return a data batch: self[index]. If stack_num is set to be > 0,
return the stacked obs and obs_next with shape [batch, len, ...]. return the stacked obs and obs_next with shape [batch, len, ...].
""" """
if isinstance(index, str):
return getattr(self, index)
return Batch( return Batch(
obs=self.get(index, 'obs'), obs=self.get(index, 'obs'),
act=self.get(index, 'act', stack_num=0), act=self.act[index],
# act_=self.get(index, 'act'), # stacked action, for RNN # act_=self.get(index, 'act'), # stacked action, for RNN
rew=self.get(index, 'rew', stack_num=0), rew=self.rew[index],
done=self.get(index, 'done', stack_num=0), done=self.done[index],
obs_next=self.get(index, 'obs_next'), obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info', stack_num=0), info=self.get(index, 'info', stack_num=0),
policy=self.get(index, 'policy'), policy=self.get(index, 'policy')
) )
@ -288,15 +291,15 @@ class ListReplayBuffer(ReplayBuffer):
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
if inst is None: if inst is None:
return return
if self._data.get(name, None) is None: if self._meta.__dict__.get(name, None) is None:
self._data[name] = [] self._meta.__dict__[name] = []
self._data[name].append(inst) self._meta.__dict__[name].append(inst)
def reset(self) -> None: def reset(self) -> None:
self._index = self._size = 0 self._index = self._size = 0
for k in list(self._data): for k in list(self._meta.__dict__.keys()):
if isinstance(self._data[k], list): if isinstance(self._meta.__dict__[k], list):
self._data[k] = [] self._meta.__dict__[k] = []
class PrioritizedReplayBuffer(ReplayBuffer): class PrioritizedReplayBuffer(ReplayBuffer):
@ -322,10 +325,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._alpha = alpha self._alpha = alpha
self._beta = beta self._beta = beta
self._weight_sum = 0.0 self._weight_sum = 0.0
self.weight = np.zeros(size, dtype=np.float64)
self._amortization_freq = 50 self._amortization_freq = 50
self._amortization_counter = 0 self._amortization_counter = 0
self._replace = replace self._replace = replace
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
def add(self, def add(self,
obs: Union[dict, np.ndarray], obs: Union[dict, np.ndarray],
@ -338,9 +341,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
weight: float = 1.0, weight: float = 1.0,
**kwargs) -> None: **kwargs) -> None:
"""Add a batch of data into replay buffer.""" """Add a batch of data into replay buffer."""
# we have to sacrifice some convenience for speed
self._weight_sum += np.abs(weight) ** self._alpha - \ self._weight_sum += np.abs(weight) ** self._alpha - \
self.weight[self._index] self._meta.__dict__['weight'][self._index]
# we have to sacrifice some convenience for speed :(
self._add_to_buffer('weight', np.abs(weight) ** self._alpha) self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
super().add(obs, act, rew, done, obs_next, info, policy) super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum() self._check_weight_sum()
@ -414,18 +417,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
- self.weight[indice].sum() - self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha) self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
def __getitem__(self, index: Union[str, slice, np.ndarray]) -> Batch: def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
if isinstance(index, str):
return getattr(self, index)
return Batch( return Batch(
obs=self.get(index, 'obs'), obs=self.get(index, 'obs'),
act=self.get(index, 'act', stack_num=0), act=self.act[index],
# act_=self.get(index, 'act'), # stacked action, for RNN # act_=self.get(index, 'act'), # stacked action, for RNN
rew=self.get(index, 'rew', stack_num=0), rew=self.rew[index],
done=self.get(index, 'done', stack_num=0), done=self.done[index],
obs_next=self.get(index, 'obs_next'), obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'), info=self.get(index, 'info'),
weight=self.get(index, 'weight', stack_num=0), weight=self.weight[index],
policy=self.get(index, 'policy'), policy=self.get(index, 'policy'),
) )