support Batch of Batch and fix bugs (#38)

This commit is contained in:
Trinkle23897 2020-04-29 12:14:53 +08:00
parent 8f718d9b13
commit bb2f833d0e
3 changed files with 60 additions and 41 deletions

View File

@ -94,6 +94,17 @@ def test_collector_with_dict_state():
c1 = Collector(policy, envs, ReplayBuffer(size=100))
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
assert equal(c0.buffer[:len(c0.buffer)].obs.index, [
0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4))
c2.collect(n_episode=[0, 0, 0, 10])
batch = c2.sample(10)
print(batch['obs_next']['index'])
if __name__ == '__main__':

View File

@ -76,7 +76,7 @@ class Batch(object):
self.__dict__[k__] = np.array([
v[i][k_] for i in range(len(v))
])
elif isinstance(v, dict):
elif isinstance(v, dict) or isinstance(v, Batch):
self._meta[k] = list(v.keys())
for k_ in v.keys():
k__ = '_' + k + '@' + k_
@ -89,7 +89,7 @@ class Batch(object):
if isinstance(index, str):
return self.__getattr__(index)
b = Batch()
for k in self.__dict__.keys():
for k in self.__dict__:
if k != '_meta' and self.__dict__[k] is not None:
b.__dict__.update(**{k: self.__dict__[k][index]})
b._meta = self._meta
@ -97,44 +97,44 @@ class Batch(object):
def __getattr__(self, key):
"""Return self.key"""
if key not in self._meta.keys():
if key not in self.__dict__.keys():
if key not in self._meta:
if key not in self.__dict__:
raise AttributeError(key)
return self.__dict__[key]
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
return d
return Batch(**d)
def __repr__(self):
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
for k in sorted(list(self.__dict__) + list(self._meta)):
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
k in self._meta.keys()):
k in self._meta):
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')\n'
s += ')'
else:
s = self.__class__.__name__ + '()\n'
s = self.__class__.__name__ + '()'
return s
def keys(self):
"""Return self.keys()."""
return sorted([i for i in self.__dict__.keys() if i[0] != '_'] +
list(self._meta.keys()))
return sorted([i for i in self.__dict__ if i[0] != '_'] +
list(self._meta))
def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__.keys():
for k in batch.__dict__:
if k == '_meta':
self._meta.update(batch.__dict__[k])
self._meta.update(batch._meta)
continue
if batch.__dict__[k] is None:
continue
@ -149,15 +149,15 @@ class Batch(object):
elif isinstance(batch.__dict__[k], list):
self.__dict__[k] += batch.__dict__[k]
else:
s = 'No support for append with type'\
+ str(type(batch.__dict__[k]))\
s = 'No support for append with type' \
+ str(type(batch.__dict__[k])) \
+ 'in class Batch.'
raise TypeError(s)
def __len__(self):
"""Return len(self)."""
return min([
len(self.__dict__[k]) for k in self.__dict__.keys()
len(self.__dict__[k]) for k in self.__dict__
if k != '_meta' and self.__dict__[k] is not None])
def split(self, size=None, shuffle=True):

View File

@ -104,51 +104,50 @@ class ReplayBuffer(object):
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
for k in sorted(list(self.__dict__) + list(self._meta)):
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
k in self._meta.keys()):
k in self._meta):
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')\n'
s += ')'
else:
s = self.__class__.__name__ + '()\n'
s = self.__class__.__name__ + '()'
return s
def __getattr__(self, key):
"""Return self.key"""
if key not in self._meta.keys():
if key not in self.__dict__.keys():
if key not in self._meta:
if key not in self.__dict__:
raise AttributeError(key)
return self.__dict__[key]
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
return d
return Batch(**d)
def _add_to_buffer(self, name, inst):
if inst is None:
if getattr(self, name, None) is None:
self.__dict__[name] = None
return
if name in self._meta.keys():
if name in self._meta:
for k in inst.keys():
self._add_to_buffer('_' + name + '@' + k, inst[k])
return
if self.__dict__.get(name, None) is None:
if isinstance(inst, np.ndarray):
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
elif isinstance(inst, dict):
elif isinstance(inst, dict) or isinstance(inst, Batch):
if name == 'info':
self.__dict__[name] = np.array(
[{} for _ in range(self._maxsize)])
else:
if self._meta.get(name, None) is None:
self._meta[name] = [
'_' + name + '@' + k for k in inst.keys()]
self._meta[name] = list(inst.keys())
for k in inst.keys():
k_ = '_' + name + '@' + k
self._add_to_buffer(k_, inst[k])
@ -160,7 +159,7 @@ class ReplayBuffer(object):
"Cannot add data to a buffer with different shape, "
f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
f"given shape: {inst.shape}.")
if name not in self._meta.keys():
if name not in self._meta:
self.__dict__[name][self._index] = inst
def update(self, buffer):
@ -225,8 +224,12 @@ class ReplayBuffer(object):
indice = np.array(indice)
elif isinstance(indice, slice):
indice = np.arange(
0 if indice.start is None else indice.start,
self._size if indice.stop is None else indice.stop,
0 if indice.start is None
else self._size - indice.start if indice.start < 0
else indice.start,
self._size if indice.stop is None
else self._size - indice.stop if indice.stop < 0
else indice.stop,
1 if indice.step is None else indice.step)
# set last frame done to True
last_index = (self._index - 1 + self._size) % self._size
@ -238,21 +241,21 @@ class ReplayBuffer(object):
if stack_num == 0:
self.done[last_index] = last_done
if key in self._meta:
return {k.split('@')[-1]: self.__dict__[k][indice]
return {k: self.__dict__['_' + key + '@' + k][indice]
for k in self._meta[key]}
else:
return self.__dict__[key][indice]
if key in self._meta:
many_keys = self._meta[key]
stack = {k.split('@')[-1]: [] for k in self._meta[key]}
stack = {k: [] for k in self._meta[key]}
else:
stack = []
many_keys = None
for i in range(stack_num):
if many_keys is not None:
for k_ in many_keys:
k = k_.split('@')[-1]
stack[k] = [self.__dict__[k_][indice]] + stack[k]
k__ = '_' + key + '@' + k_
stack[k_] = [self.__dict__[k__][indice]] + stack[k_]
else:
stack = [self.__dict__[key][indice]] + stack
pre_indice = indice - 1
@ -263,6 +266,7 @@ class ReplayBuffer(object):
if many_keys is not None:
for k in stack:
stack[k] = np.stack(stack[k], axis=1)
stack = Batch(**stack)
else:
stack = np.stack(stack, axis=1)
return stack
@ -305,14 +309,18 @@ class ListReplayBuffer(ReplayBuffer):
def reset(self):
self._index = self._size = 0
for k in list(self.__dict__.keys()):
if not k.startswith('_'):
for k in list(self.__dict__):
if isinstance(self.__dict__[k], list):
self.__dict__[k] = []
class PrioritizedReplayBuffer(ReplayBuffer):
"""Prioritized replay buffer implementation.
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
:param str mode: defaults to ``weight``.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
@ -324,8 +332,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
if mode != 'weight':
raise NotImplementedError
super().__init__(size, **kwargs)
self._alpha = alpha # prioritization exponent
self._beta = beta # importance sample soft coefficient
self._alpha = alpha
self._beta = beta
self._weight_sum = 0.0
self.weight = np.zeros(size, dtype=np.float64)
self._amortization_freq = 50
@ -382,8 +390,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def update_weight(self, indice, new_weight: np.ndarray):
"""Update priority weight by indice in this buffer.
:param indice: indice you want to update weight
:param new_weight: new priority weight you wangt to update
:param np.ndarray indice: indice you want to update weight
:param np.ndarray new_weight: new priority weight you wangt to update
"""
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
- self.weight[indice].sum()
@ -402,7 +410,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
)
def _check_weight_sum(self):
# keep a accurate _weight_sum
# keep an accurate _weight_sum
self._amortization_counter += 1
if self._amortization_counter % self._amortization_freq == 0:
self._weight_sum = np.sum(self.weight)