Buffer refactoring to support batch over batch reliably (#93)
* Fix support of batch over batch for Buffer. * Do not use internal __dict__ attribute to store batch data since it breaks inheritance. * Various fixes. * Improve robustness of Batch/Buffer by avoiding direct attribute assignment. Buffer refactoring. * Add axis optional argument to Batch stack method. * Add item assignment to Batch class. * Fix list support for Buffer. * Convert list to np.array by default for efficiency. * Add missing unit test for Batch. Fix unit tests. * Batch item assignment is now robust to key order. * Do not use getattr/setattr explicity for simplicity. * More flexible __setitem__. * Fixes * Remove broacasting at Batch level since it is unreliable. * Forbid item assignement for inconsistent batches. * Implement broadcasting at Buffer level. * Add more unit test for Batch item assignment. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
506cc97ba5
commit
3086b5c31d
@ -65,6 +65,15 @@ def test_batch():
|
||||
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
|
||||
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
|
||||
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
|
||||
batch3 = Batch(a={
|
||||
'c': np.zeros(1),
|
||||
'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
|
||||
batch3.a.d[0] = {'e': 4.0}
|
||||
assert batch3.a.d.e[0] == 4.0
|
||||
batch3.a.d[0] = Batch(f=5.0)
|
||||
assert batch3.a.d.f[0] == 5.0
|
||||
with pytest.raises(ValueError):
|
||||
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
||||
|
||||
|
||||
def test_batch_over_batch():
|
||||
@ -93,16 +102,20 @@ def test_batch_over_batch():
|
||||
def test_batch_cat_and_stack():
|
||||
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
||||
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
||||
b_cat_out = Batch.cat((b1, b2))
|
||||
b_cat_in = copy.deepcopy(b1)
|
||||
b_cat_in.cat_(b2)
|
||||
assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
|
||||
assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
|
||||
assert isinstance(b_cat_in.a.d.e, np.ndarray)
|
||||
assert b_cat_in.a.d.e.ndim == 1
|
||||
b_stack = Batch.stack((b1, b2))
|
||||
assert isinstance(b_stack.a.d.e, np.ndarray)
|
||||
assert b_stack.a.d.e.ndim == 2
|
||||
b12_cat_out = Batch.cat((b1, b2))
|
||||
b12_cat_in = copy.deepcopy(b1)
|
||||
b12_cat_in.cat_(b2)
|
||||
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
||||
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
||||
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
|
||||
assert b12_cat_in.a.d.e.ndim == 1
|
||||
b12_stack = Batch.stack((b1, b2))
|
||||
assert isinstance(b12_stack.a.d.e, np.ndarray)
|
||||
assert b12_stack.a.d.e.ndim == 2
|
||||
b3 = Batch(a=np.zeros((3, 4)))
|
||||
b4 = Batch(a=np.ones((3, 4)))
|
||||
b34_stack = Batch.stack((b3, b4), axis=1)
|
||||
assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
|
||||
|
||||
|
||||
def test_batch_over_batch_to_torch():
|
||||
|
@ -75,6 +75,11 @@ class Batch:
|
||||
[11 22] [6 6]
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> 'Batch':
|
||||
self = super().__new__(cls)
|
||||
self.__dict__['_data'] = {}
|
||||
return self
|
||||
|
||||
def __init__(self,
|
||||
batch_dict: Optional[Union[
|
||||
dict, 'Batch', Tuple[Union[dict, 'Batch']],
|
||||
@ -95,21 +100,21 @@ class Batch:
|
||||
for k, v in zip(batch_dict[0].keys(),
|
||||
zip(*[e.values() for e in batch_dict])):
|
||||
if isinstance(v[0], dict) or _is_batch_set(v[0]):
|
||||
self.__dict__[k] = Batch(v)
|
||||
self[k] = Batch(v)
|
||||
elif isinstance(v[0], (np.generic, np.ndarray)):
|
||||
self.__dict__[k] = np.stack(v, axis=0)
|
||||
self[k] = np.stack(v, axis=0)
|
||||
elif isinstance(v[0], torch.Tensor):
|
||||
self.__dict__[k] = torch.stack(v, dim=0)
|
||||
self[k] = torch.stack(v, dim=0)
|
||||
elif isinstance(v[0], Batch):
|
||||
self.__dict__[k] = Batch.stack(v)
|
||||
self[k] = Batch.stack(v)
|
||||
else:
|
||||
self.__dict__[k] = list(v)
|
||||
self[k] = np.array(v) # fall back to np.object
|
||||
elif isinstance(batch_dict, (dict, Batch)):
|
||||
for k, v in batch_dict.items():
|
||||
if isinstance(v, dict) or _is_batch_set(v):
|
||||
self.__dict__[k] = Batch(v)
|
||||
self[k] = Batch(v)
|
||||
else:
|
||||
self.__dict__[k] = v
|
||||
self[k] = v
|
||||
if len(kwargs) > 0:
|
||||
self.__init__(kwargs)
|
||||
|
||||
@ -140,8 +145,8 @@ class Batch:
|
||||
if isinstance(index, (int, np.integer)):
|
||||
return -length <= index and index < length
|
||||
elif isinstance(index, (list, np.ndarray)):
|
||||
return _valid_bounds(length, min(index)) and \
|
||||
_valid_bounds(length, max(index))
|
||||
return _valid_bounds(length, np.min(index)) and \
|
||||
_valid_bounds(length, np.max(index))
|
||||
elif isinstance(index, slice):
|
||||
if index.start is not None:
|
||||
start_valid = _valid_bounds(length, index.start)
|
||||
@ -154,48 +159,75 @@ class Batch:
|
||||
return start_valid and stop_valid
|
||||
|
||||
if isinstance(index, str):
|
||||
return self.__getattr__(index)
|
||||
return getattr(self, index)
|
||||
|
||||
if not _valid_bounds(len(self), index):
|
||||
raise IndexError(
|
||||
f"Index {index} out of bounds for Batch of len {len(self)}.")
|
||||
else:
|
||||
b = Batch()
|
||||
for k, v in self.__dict__.items():
|
||||
for k, v in self.items():
|
||||
if isinstance(v, Batch) and v.size == 0:
|
||||
b.__dict__[k] = Batch()
|
||||
elif isinstance(v, list) and len(v) == 0:
|
||||
b.__dict__[k] = []
|
||||
b[k] = Batch()
|
||||
elif hasattr(v, '__len__') and (not isinstance(
|
||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
||||
if _valid_bounds(len(v), index):
|
||||
b.__dict__[k] = v[index]
|
||||
if isinstance(index, (int, np.integer)) or \
|
||||
(isinstance(index, np.ndarray) and
|
||||
index.ndim == 0) or \
|
||||
not isinstance(v, list):
|
||||
b[k] = v[index]
|
||||
else:
|
||||
b[k] = [v[i] for i in index]
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Index {index} out of bounds for {type(v)} of "
|
||||
f"len {len(self)}.")
|
||||
return b
|
||||
|
||||
def __setitem__(self, index: Union[
|
||||
str, slice, int, np.integer, np.ndarray, List[int]],
|
||||
value: Any) -> None:
|
||||
if isinstance(index, str):
|
||||
return setattr(self, index, value)
|
||||
if value is None:
|
||||
value = Batch()
|
||||
if not isinstance(value, (dict, Batch)):
|
||||
raise TypeError("Batch does not supported value type "
|
||||
f"{type(value)} for item assignment.")
|
||||
if not set(value.keys()).issubset(self.keys()):
|
||||
raise ValueError(
|
||||
"Creating keys is not supported by item assignment.")
|
||||
for key in self.keys():
|
||||
if isinstance(self[key], Batch):
|
||||
default = Batch()
|
||||
elif isinstance(self[key], np.ndarray) and \
|
||||
self[key].dtype == np.integer:
|
||||
# Fallback for np.array of integer,
|
||||
# since neither None or nan is supported.
|
||||
default = 0
|
||||
else:
|
||||
default = None
|
||||
self[key][index] = value.get(key, default)
|
||||
|
||||
def __iadd__(self, val: Union['Batch', Number]):
|
||||
if isinstance(val, Batch):
|
||||
for k, r, v in zip(self.__dict__.keys(),
|
||||
self.__dict__.values(),
|
||||
val.__dict__.values()):
|
||||
for k, r, v in zip(self.keys(), self.values(), val.values()):
|
||||
if r is None:
|
||||
self.__dict__[k] = r
|
||||
self[k] = r
|
||||
elif isinstance(r, list):
|
||||
self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
|
||||
self[k] = [r_ + v_ for r_, v_ in zip(r, v)]
|
||||
else:
|
||||
self.__dict__[k] = r + v
|
||||
self[k] = r + v
|
||||
return self
|
||||
elif isinstance(val, Number):
|
||||
for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
|
||||
for k, r in zip(self.keys(), self.values()):
|
||||
if r is None:
|
||||
self.__dict__[k] = r
|
||||
self[k] = r
|
||||
elif isinstance(r, list):
|
||||
self.__dict__[k] = [r_ + val for r_ in r]
|
||||
self[k] = [r_ + val for r_ in r]
|
||||
else:
|
||||
self.__dict__[k] = r + val
|
||||
self[k] = r + val
|
||||
return self
|
||||
else:
|
||||
raise TypeError("Only addition of Batch or number is supported.")
|
||||
@ -206,30 +238,40 @@ class Batch:
|
||||
def __mul__(self, val: Number):
|
||||
assert isinstance(val, Number), \
|
||||
"Only multiplication by a number is supported."
|
||||
result = Batch()
|
||||
for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
|
||||
result.__dict__[k] = r * val
|
||||
result = self.__class__()
|
||||
for k, r in zip(self.keys(), self.values()):
|
||||
result[k] = r * val
|
||||
return result
|
||||
|
||||
def __truediv__(self, val: Number):
|
||||
assert isinstance(val, Number), \
|
||||
"Only division by a number is supported."
|
||||
result = Batch()
|
||||
for k, r in zip(self.__dict__.keys(), self.__dict__.values()):
|
||||
result.__dict__[k] = r / val
|
||||
result = self.__class__()
|
||||
for k, r in zip(self.keys(), self.values()):
|
||||
result[k] = r / val
|
||||
return result
|
||||
|
||||
def __getattr__(self, key: str) -> Union['Batch', Any]:
|
||||
"""Return self.key"""
|
||||
if key not in self.__dict__:
|
||||
raise AttributeError(key)
|
||||
if key in self.__dict__.keys():
|
||||
return self.__dict__[key]
|
||||
elif key in self._data.keys():
|
||||
return self._data[key]
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self._data.keys():
|
||||
self._data[key] = value
|
||||
elif key in self.__dict__.keys():
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
self._data[key] = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return str(self)."""
|
||||
s = self.__class__.__name__ + '(\n'
|
||||
flag = False
|
||||
for k, v in self.__dict__.items():
|
||||
for k, v in self.items():
|
||||
rpl = '\n' + ' ' * (6 + len(k))
|
||||
obj = pprint.pformat(v).replace('\n', rpl)
|
||||
s += f' {k}: {obj},\n'
|
||||
@ -242,29 +284,29 @@ class Batch:
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
"""Return self.keys()."""
|
||||
return self.__dict__.keys()
|
||||
return self._data.keys()
|
||||
|
||||
def values(self) -> List[Any]:
|
||||
"""Return self.values()."""
|
||||
return self.__dict__.values()
|
||||
return self._data.values()
|
||||
|
||||
def items(self) -> Any:
|
||||
"""Return self.items()."""
|
||||
return self.__dict__.items()
|
||||
return self._data.items()
|
||||
|
||||
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
|
||||
"""Return self[k] if k in self else d. d defaults to None."""
|
||||
if k in self.__dict__:
|
||||
return self.__getattr__(k)
|
||||
if k in self.keys():
|
||||
return self[k]
|
||||
return d
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
|
||||
"""Change all torch.Tensor to numpy.ndarray. This is an in-place
|
||||
operation.
|
||||
"""
|
||||
for k, v in self.__dict__.items():
|
||||
for k, v in self.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self.__dict__[k] = v.detach().cpu().numpy()
|
||||
self[k] = v.detach().cpu().numpy()
|
||||
elif isinstance(v, Batch):
|
||||
v.to_numpy()
|
||||
|
||||
@ -272,18 +314,18 @@ class Batch:
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = 'cpu'
|
||||
) -> None:
|
||||
"""Change all numpy.ndarray to torch.Tensor. This is an inplace
|
||||
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
|
||||
operation.
|
||||
"""
|
||||
if not isinstance(device, torch.device):
|
||||
device = torch.device(device)
|
||||
|
||||
for k, v in self.__dict__.items():
|
||||
for k, v in self.items():
|
||||
if isinstance(v, (np.generic, np.ndarray)):
|
||||
v = torch.from_numpy(v).to(device)
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v
|
||||
self[k] = v
|
||||
if isinstance(v, torch.Tensor):
|
||||
if dtype is not None and v.dtype != dtype:
|
||||
must_update_tensor = True
|
||||
@ -297,7 +339,7 @@ class Batch:
|
||||
if must_update_tensor:
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v.to(device)
|
||||
self[k] = v.to(device)
|
||||
elif isinstance(v, Batch):
|
||||
v.to_torch(dtype, device)
|
||||
|
||||
@ -312,51 +354,67 @@ class Batch:
|
||||
"""
|
||||
assert isinstance(batch, Batch), \
|
||||
'Only Batch is allowed to be concatenated in-place!'
|
||||
for k, v in batch.__dict__.items():
|
||||
for k, v in batch.items():
|
||||
if v is None:
|
||||
continue
|
||||
if not hasattr(self, k) or self.__dict__[k] is None:
|
||||
self.__dict__[k] = copy.deepcopy(v)
|
||||
if not hasattr(self, k) or self[k] is None:
|
||||
self[k] = copy.deepcopy(v)
|
||||
elif isinstance(v, np.ndarray) and v.ndim > 0:
|
||||
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
||||
self[k] = np.concatenate([self[k], v])
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self.__dict__[k] = torch.cat([self.__dict__[k], v])
|
||||
self[k] = torch.cat([self[k], v])
|
||||
elif isinstance(v, list):
|
||||
self.__dict__[k] += copy.deepcopy(v)
|
||||
self[k] = self[k] + copy.deepcopy(v)
|
||||
elif isinstance(v, Batch):
|
||||
self.__dict__[k].cat_(v)
|
||||
self[k].cat_(v)
|
||||
else:
|
||||
s = 'No support for method "cat" with type '\
|
||||
f'{type(v)} in class Batch.'
|
||||
raise TypeError(s)
|
||||
|
||||
@staticmethod
|
||||
def cat(batches: List['Batch']) -> None:
|
||||
@classmethod
|
||||
def cat(cls, batches: List['Batch']) -> 'Batch':
|
||||
"""Concatenate a :class:`~tianshou.data.Batch` object into a
|
||||
single new batch.
|
||||
"""
|
||||
assert isinstance(batches, (tuple, list)), \
|
||||
'Only list of Batch instances is allowed to be '\
|
||||
'concatenated out-of-place!'
|
||||
batch = Batch()
|
||||
batch = cls()
|
||||
for batch_ in batches:
|
||||
batch.cat_(batch_)
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def stack(batches: List['Batch']):
|
||||
@classmethod
|
||||
def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch':
|
||||
"""Stack a :class:`~tianshou.data.Batch` object into a
|
||||
single new batch.
|
||||
"""
|
||||
assert isinstance(batches, (tuple, list)), \
|
||||
'Only list of Batch instances is allowed to be '\
|
||||
'stacked out-of-place!'
|
||||
return Batch(np.array([batch.__dict__ for batch in batches]))
|
||||
if axis == 0:
|
||||
return cls(batches)
|
||||
else:
|
||||
batch = Batch()
|
||||
for k, v in zip(batches[0].keys(),
|
||||
zip(*[e.values() for e in batches])):
|
||||
if isinstance(v[0], (np.generic, np.ndarray, list)):
|
||||
batch[k] = np.stack(v, axis)
|
||||
elif isinstance(v[0], torch.Tensor):
|
||||
batch[k] = torch.stack(v, axis)
|
||||
elif isinstance(v[0], Batch):
|
||||
batch[k] = Batch.stack(v, axis)
|
||||
else:
|
||||
s = 'No support for method "stack" with type '\
|
||||
f'{type(v[0])} in class Batch and axis != 0.'
|
||||
raise TypeError(s)
|
||||
return batch
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self)."""
|
||||
r = []
|
||||
for v in self.__dict__.values():
|
||||
for v in self.values():
|
||||
if isinstance(v, Batch) and v.size == 0:
|
||||
continue
|
||||
elif isinstance(v, list) and len(v) == 0:
|
||||
@ -373,11 +431,11 @@ class Batch:
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Return self.size."""
|
||||
if len(self.__dict__) == 0:
|
||||
if len(self.keys()) == 0:
|
||||
return 0
|
||||
else:
|
||||
r = []
|
||||
for v in self.__dict__.values():
|
||||
for v in self.values():
|
||||
if isinstance(v, Batch):
|
||||
r.append(v.size)
|
||||
elif hasattr(v, '__len__') and (not isinstance(
|
||||
|
@ -1,11 +1,11 @@
|
||||
import pprint
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Any, Tuple, Union, Optional
|
||||
|
||||
from tianshou.data.batch import Batch
|
||||
from .batch import Batch
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
class ReplayBuffer(Batch):
|
||||
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
|
||||
interaction between the policy and environment. It stores basically 7 types
|
||||
of data, as mentioned in :class:`~tianshou.data.Batch`, based on
|
||||
@ -96,81 +96,47 @@ class ReplayBuffer:
|
||||
|
||||
def __init__(self, size: int, stack_num: Optional[int] = 0,
|
||||
ignore_obs_next: bool = False, **kwargs) -> None:
|
||||
self._maxsize = size
|
||||
self._stack = stack_num
|
||||
self._save_s_ = not ignore_obs_next
|
||||
self._meta = {}
|
||||
super().__init__()
|
||||
self.__dict__['_maxsize'] = size
|
||||
self.__dict__['_stack'] = stack_num
|
||||
self.__dict__['_save_s_'] = not ignore_obs_next
|
||||
self.__dict__['_index'] = 0
|
||||
self.__dict__['_size'] = 0
|
||||
self.reset()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self)."""
|
||||
return self._size
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return str(self)."""
|
||||
s = self.__class__.__name__ + '(\n'
|
||||
flag = False
|
||||
for k in sorted(list(self.__dict__) + list(self._meta)):
|
||||
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
|
||||
k in self._meta):
|
||||
rpl = '\n' + ' ' * (6 + len(k))
|
||||
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
|
||||
s += f' {k}: {obj},\n'
|
||||
flag = True
|
||||
if flag:
|
||||
s += ')'
|
||||
else:
|
||||
s = self.__class__.__name__ + '()'
|
||||
return s
|
||||
|
||||
def __getattr__(self, key: str) -> Union[Batch, np.ndarray]:
|
||||
"""Return self.key"""
|
||||
if key not in self._meta:
|
||||
if key not in self.__dict__:
|
||||
raise AttributeError(key)
|
||||
return self.__dict__[key]
|
||||
d = {}
|
||||
for k_ in self._meta[key]:
|
||||
k__ = '_' + key + '@' + k_
|
||||
if k__ in self.__dict__:
|
||||
d[k_] = self.__dict__[k__]
|
||||
else:
|
||||
d[k_] = self.__getattr__(k__)
|
||||
return Batch(d)
|
||||
|
||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||
if inst is None:
|
||||
if getattr(self, name, None) is None:
|
||||
self.__dict__[name] = None
|
||||
return
|
||||
if name in self._meta:
|
||||
for k in inst.keys():
|
||||
self._add_to_buffer('_' + name + '@' + k, inst[k])
|
||||
return
|
||||
if self.__dict__.get(name, None) is None:
|
||||
def _create_value(inst: Any) -> Union['Batch', np.ndarray]:
|
||||
if isinstance(inst, np.ndarray):
|
||||
self.__dict__[name] = np.zeros(
|
||||
return np.zeros(
|
||||
(self._maxsize, *inst.shape), dtype=inst.dtype)
|
||||
elif isinstance(inst, (dict, Batch)):
|
||||
if self._meta.get(name, None) is None:
|
||||
self._meta[name] = list(inst.keys())
|
||||
for k in inst.keys():
|
||||
k_ = '_' + name + '@' + k
|
||||
self._add_to_buffer(k_, inst[k])
|
||||
elif np.isscalar(inst):
|
||||
self.__dict__[name] = np.zeros(
|
||||
return Batch([Batch(inst) for _ in range(self._maxsize)])
|
||||
elif isinstance(inst, (np.generic, Number)):
|
||||
return np.zeros(
|
||||
(self._maxsize,), dtype=np.asarray(inst).dtype)
|
||||
else: # fall back to np.object
|
||||
self.__dict__[name] = np.array(
|
||||
[None for _ in range(self._maxsize)])
|
||||
return np.array([None for _ in range(self._maxsize)])
|
||||
|
||||
if inst is None:
|
||||
inst = Batch()
|
||||
if name not in self.keys():
|
||||
self[name] = _create_value(inst)
|
||||
if isinstance(inst, np.ndarray) and \
|
||||
self.__dict__[name].shape[1:] != inst.shape:
|
||||
self[name].shape[1:] != inst.shape:
|
||||
raise ValueError(
|
||||
"Cannot add data to a buffer with different shape, "
|
||||
f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
|
||||
f"given shape: {inst.shape}.")
|
||||
if name not in self._meta:
|
||||
self.__dict__[name][self._index] = inst
|
||||
f"key: {name}, expect shape: {self[name].shape[1:]}"
|
||||
f", given shape: {inst.shape}.")
|
||||
if isinstance(self[name], Batch):
|
||||
field_keys = self[name].keys()
|
||||
for key, val in inst.items():
|
||||
if key not in field_keys:
|
||||
self[name][key] = _create_value(val)
|
||||
self[name][self._index] = inst
|
||||
|
||||
def update(self, buffer: 'ReplayBuffer') -> None:
|
||||
"""Move the data from the given buffer to self."""
|
||||
@ -209,7 +175,8 @@ class ReplayBuffer:
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all the data in replay buffer."""
|
||||
self._index = self._size = 0
|
||||
self._index = 0
|
||||
self._size = 0
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with size equal to batch_size. \
|
||||
@ -226,7 +193,7 @@ class ReplayBuffer:
|
||||
])
|
||||
return self[indice], indice
|
||||
|
||||
def get(self, indice: Union[slice, np.ndarray], key: str,
|
||||
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
|
||||
stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
|
||||
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
|
||||
where s is self.key, t is indice. The stack_num (here equals to 4) is
|
||||
@ -234,10 +201,7 @@ class ReplayBuffer:
|
||||
"""
|
||||
if stack_num is None:
|
||||
stack_num = self._stack
|
||||
if not isinstance(indice, np.ndarray):
|
||||
if np.isscalar(indice):
|
||||
indice = np.array(indice)
|
||||
elif isinstance(indice, slice):
|
||||
if isinstance(indice, slice):
|
||||
indice = np.arange(
|
||||
0 if indice.start is None
|
||||
else self._size - indice.start if indice.start < 0
|
||||
@ -246,8 +210,7 @@ class ReplayBuffer:
|
||||
else self._size - indice.stop if indice.stop < 0
|
||||
else indice.stop,
|
||||
1 if indice.step is None else indice.step)
|
||||
else:
|
||||
indice = np.array(indice)
|
||||
indice = np.array(indice, copy=True)
|
||||
# set last frame done to True
|
||||
last_index = (self._index - 1 + self._size) % self._size
|
||||
last_done, self.done[last_index] = self.done[last_index], True
|
||||
@ -257,49 +220,51 @@ class ReplayBuffer:
|
||||
key = 'obs'
|
||||
if stack_num == 0:
|
||||
self.done[last_index] = last_done
|
||||
if key in self._meta:
|
||||
return {k: self.__dict__['_' + key + '@' + k][indice]
|
||||
for k in self._meta[key]}
|
||||
val = self[key]
|
||||
if isinstance(val, Batch) and val.size == 0:
|
||||
return val
|
||||
else:
|
||||
return self.__dict__[key][indice]
|
||||
if key in self._meta:
|
||||
many_keys = self._meta[key]
|
||||
stack = {k: [] for k in self._meta[key]}
|
||||
if isinstance(indice, (int, np.integer)) or \
|
||||
(isinstance(indice, np.ndarray) and
|
||||
indice.ndim == 0) or not isinstance(val, list):
|
||||
return val[indice]
|
||||
else:
|
||||
return [val[i] for i in indice]
|
||||
else:
|
||||
val = self[key]
|
||||
if not isinstance(val, Batch) or val.size > 0:
|
||||
stack = []
|
||||
many_keys = None
|
||||
for _ in range(stack_num):
|
||||
if many_keys is not None:
|
||||
for k_ in many_keys:
|
||||
k__ = '_' + key + '@' + k_
|
||||
stack[k_] = [self.__dict__[k__][indice]] + stack[k_]
|
||||
else:
|
||||
stack = [self.__dict__[key][indice]] + stack
|
||||
pre_indice = indice - 1
|
||||
stack = [val[indice]] + stack
|
||||
pre_indice = np.asarray(indice - 1)
|
||||
pre_indice[pre_indice == -1] = self._size - 1
|
||||
indice = pre_indice + self.done[pre_indice].astype(np.int)
|
||||
indice = np.asarray(
|
||||
pre_indice + self.done[pre_indice].astype(np.int))
|
||||
indice[indice == self._size] = 0
|
||||
self.done[last_index] = last_done
|
||||
if many_keys is not None:
|
||||
for k in stack:
|
||||
stack[k] = np.stack(stack[k], axis=1)
|
||||
stack = Batch(stack)
|
||||
if isinstance(stack[0], Batch):
|
||||
stack = Batch.stack(stack, axis=indice.ndim)
|
||||
else:
|
||||
stack = np.stack(stack, axis=1)
|
||||
stack = np.stack(stack, axis=indice.ndim)
|
||||
else:
|
||||
stack = Batch()
|
||||
self.done[last_index] = last_done
|
||||
return stack
|
||||
|
||||
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
|
||||
def __getitem__(self, index: Union[
|
||||
slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
"""Return a data batch: self[index]. If stack_num is set to be > 0,
|
||||
return the stacked obs and obs_next with shape [batch, len, ...].
|
||||
"""
|
||||
if isinstance(index, str):
|
||||
return getattr(self, index)
|
||||
return Batch(
|
||||
obs=self.get(index, 'obs'),
|
||||
act=self.act[index],
|
||||
act=self.get(index, 'act', stack_num=0),
|
||||
# act_=self.get(index, 'act'), # stacked action, for RNN
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
rew=self.get(index, 'rew', stack_num=0),
|
||||
done=self.get(index, 'done', stack_num=0),
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.get(index, 'info'),
|
||||
info=self.get(index, 'info', stack_num=0),
|
||||
policy=self.get(index, 'policy'),
|
||||
)
|
||||
|
||||
@ -323,15 +288,15 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
|
||||
if inst is None:
|
||||
return
|
||||
if self.__dict__.get(name, None) is None:
|
||||
self.__dict__[name] = []
|
||||
self.__dict__[name].append(inst)
|
||||
if self._data.get(name, None) is None:
|
||||
self._data[name] = []
|
||||
self._data[name].append(inst)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._index = self._size = 0
|
||||
for k in list(self.__dict__):
|
||||
if isinstance(self.__dict__[k], list):
|
||||
self.__dict__[k] = []
|
||||
for k in list(self._data):
|
||||
if isinstance(self._data[k], list):
|
||||
self._data[k] = []
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
@ -449,16 +414,18 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
- self.weight[indice].sum()
|
||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||
|
||||
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
|
||||
def __getitem__(self, index: Union[str, slice, np.ndarray]) -> Batch:
|
||||
if isinstance(index, str):
|
||||
return getattr(self, index)
|
||||
return Batch(
|
||||
obs=self.get(index, 'obs'),
|
||||
act=self.act[index],
|
||||
act=self.get(index, 'act', stack_num=0),
|
||||
# act_=self.get(index, 'act'), # stacked action, for RNN
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
rew=self.get(index, 'rew', stack_num=0),
|
||||
done=self.get(index, 'done', stack_num=0),
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.get(index, 'info'),
|
||||
weight=self.weight[index],
|
||||
weight=self.get(index, 'weight', stack_num=0),
|
||||
policy=self.get(index, 'policy'),
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user