support Batch of Batch and fix bugs (#38)
This commit is contained in:
parent
8f718d9b13
commit
bb2f833d0e
@ -94,6 +94,17 @@ def test_collector_with_dict_state():
|
||||
c1 = Collector(policy, envs, ReplayBuffer(size=100))
|
||||
c1.collect(n_step=10)
|
||||
c1.collect(n_episode=[2, 1, 1, 2])
|
||||
batch = c1.sample(10)
|
||||
print(batch)
|
||||
c0.buffer.update(c1.buffer)
|
||||
assert equal(c0.buffer[:len(c0.buffer)].obs.index, [
|
||||
0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
|
||||
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
|
||||
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
|
||||
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4))
|
||||
c2.collect(n_episode=[0, 0, 0, 10])
|
||||
batch = c2.sample(10)
|
||||
print(batch['obs_next']['index'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -76,7 +76,7 @@ class Batch(object):
|
||||
self.__dict__[k__] = np.array([
|
||||
v[i][k_] for i in range(len(v))
|
||||
])
|
||||
elif isinstance(v, dict):
|
||||
elif isinstance(v, dict) or isinstance(v, Batch):
|
||||
self._meta[k] = list(v.keys())
|
||||
for k_ in v.keys():
|
||||
k__ = '_' + k + '@' + k_
|
||||
@ -89,7 +89,7 @@ class Batch(object):
|
||||
if isinstance(index, str):
|
||||
return self.__getattr__(index)
|
||||
b = Batch()
|
||||
for k in self.__dict__.keys():
|
||||
for k in self.__dict__:
|
||||
if k != '_meta' and self.__dict__[k] is not None:
|
||||
b.__dict__.update(**{k: self.__dict__[k][index]})
|
||||
b._meta = self._meta
|
||||
@ -97,44 +97,44 @@ class Batch(object):
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""Return self.key"""
|
||||
if key not in self._meta.keys():
|
||||
if key not in self.__dict__.keys():
|
||||
if key not in self._meta:
|
||||
if key not in self.__dict__:
|
||||
raise AttributeError(key)
|
||||
return self.__dict__[key]
|
||||
d = {}
|
||||
for k_ in self._meta[key]:
|
||||
k__ = '_' + key + '@' + k_
|
||||
d[k_] = self.__dict__[k__]
|
||||
return d
|
||||
return Batch(**d)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return str(self)."""
|
||||
s = self.__class__.__name__ + '(\n'
|
||||
flag = False
|
||||
for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
|
||||
for k in sorted(list(self.__dict__) + list(self._meta)):
|
||||
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
|
||||
k in self._meta.keys()):
|
||||
k in self._meta):
|
||||
rpl = '\n' + ' ' * (6 + len(k))
|
||||
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
|
||||
s += f' {k}: {obj},\n'
|
||||
flag = True
|
||||
if flag:
|
||||
s += ')\n'
|
||||
s += ')'
|
||||
else:
|
||||
s = self.__class__.__name__ + '()\n'
|
||||
s = self.__class__.__name__ + '()'
|
||||
return s
|
||||
|
||||
def keys(self):
|
||||
"""Return self.keys()."""
|
||||
return sorted([i for i in self.__dict__.keys() if i[0] != '_'] +
|
||||
list(self._meta.keys()))
|
||||
return sorted([i for i in self.__dict__ if i[0] != '_'] +
|
||||
list(self._meta))
|
||||
|
||||
def append(self, batch):
|
||||
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
||||
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||
for k in batch.__dict__.keys():
|
||||
for k in batch.__dict__:
|
||||
if k == '_meta':
|
||||
self._meta.update(batch.__dict__[k])
|
||||
self._meta.update(batch._meta)
|
||||
continue
|
||||
if batch.__dict__[k] is None:
|
||||
continue
|
||||
@ -149,15 +149,15 @@ class Batch(object):
|
||||
elif isinstance(batch.__dict__[k], list):
|
||||
self.__dict__[k] += batch.__dict__[k]
|
||||
else:
|
||||
s = 'No support for append with type'\
|
||||
+ str(type(batch.__dict__[k]))\
|
||||
s = 'No support for append with type' \
|
||||
+ str(type(batch.__dict__[k])) \
|
||||
+ 'in class Batch.'
|
||||
raise TypeError(s)
|
||||
|
||||
def __len__(self):
|
||||
"""Return len(self)."""
|
||||
return min([
|
||||
len(self.__dict__[k]) for k in self.__dict__.keys()
|
||||
len(self.__dict__[k]) for k in self.__dict__
|
||||
if k != '_meta' and self.__dict__[k] is not None])
|
||||
|
||||
def split(self, size=None, shuffle=True):
|
||||
|
||||
@ -104,51 +104,50 @@ class ReplayBuffer(object):
|
||||
"""Return str(self)."""
|
||||
s = self.__class__.__name__ + '(\n'
|
||||
flag = False
|
||||
for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
|
||||
for k in sorted(list(self.__dict__) + list(self._meta)):
|
||||
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
|
||||
k in self._meta.keys()):
|
||||
k in self._meta):
|
||||
rpl = '\n' + ' ' * (6 + len(k))
|
||||
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
|
||||
s += f' {k}: {obj},\n'
|
||||
flag = True
|
||||
if flag:
|
||||
s += ')\n'
|
||||
s += ')'
|
||||
else:
|
||||
s = self.__class__.__name__ + '()\n'
|
||||
s = self.__class__.__name__ + '()'
|
||||
return s
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""Return self.key"""
|
||||
if key not in self._meta.keys():
|
||||
if key not in self.__dict__.keys():
|
||||
if key not in self._meta:
|
||||
if key not in self.__dict__:
|
||||
raise AttributeError(key)
|
||||
return self.__dict__[key]
|
||||
d = {}
|
||||
for k_ in self._meta[key]:
|
||||
k__ = '_' + key + '@' + k_
|
||||
d[k_] = self.__dict__[k__]
|
||||
return d
|
||||
return Batch(**d)
|
||||
|
||||
def _add_to_buffer(self, name, inst):
|
||||
if inst is None:
|
||||
if getattr(self, name, None) is None:
|
||||
self.__dict__[name] = None
|
||||
return
|
||||
if name in self._meta.keys():
|
||||
if name in self._meta:
|
||||
for k in inst.keys():
|
||||
self._add_to_buffer('_' + name + '@' + k, inst[k])
|
||||
return
|
||||
if self.__dict__.get(name, None) is None:
|
||||
if isinstance(inst, np.ndarray):
|
||||
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
||||
elif isinstance(inst, dict):
|
||||
elif isinstance(inst, dict) or isinstance(inst, Batch):
|
||||
if name == 'info':
|
||||
self.__dict__[name] = np.array(
|
||||
[{} for _ in range(self._maxsize)])
|
||||
else:
|
||||
if self._meta.get(name, None) is None:
|
||||
self._meta[name] = [
|
||||
'_' + name + '@' + k for k in inst.keys()]
|
||||
self._meta[name] = list(inst.keys())
|
||||
for k in inst.keys():
|
||||
k_ = '_' + name + '@' + k
|
||||
self._add_to_buffer(k_, inst[k])
|
||||
@ -160,7 +159,7 @@ class ReplayBuffer(object):
|
||||
"Cannot add data to a buffer with different shape, "
|
||||
f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
|
||||
f"given shape: {inst.shape}.")
|
||||
if name not in self._meta.keys():
|
||||
if name not in self._meta:
|
||||
self.__dict__[name][self._index] = inst
|
||||
|
||||
def update(self, buffer):
|
||||
@ -225,8 +224,12 @@ class ReplayBuffer(object):
|
||||
indice = np.array(indice)
|
||||
elif isinstance(indice, slice):
|
||||
indice = np.arange(
|
||||
0 if indice.start is None else indice.start,
|
||||
self._size if indice.stop is None else indice.stop,
|
||||
0 if indice.start is None
|
||||
else self._size - indice.start if indice.start < 0
|
||||
else indice.start,
|
||||
self._size if indice.stop is None
|
||||
else self._size - indice.stop if indice.stop < 0
|
||||
else indice.stop,
|
||||
1 if indice.step is None else indice.step)
|
||||
# set last frame done to True
|
||||
last_index = (self._index - 1 + self._size) % self._size
|
||||
@ -238,21 +241,21 @@ class ReplayBuffer(object):
|
||||
if stack_num == 0:
|
||||
self.done[last_index] = last_done
|
||||
if key in self._meta:
|
||||
return {k.split('@')[-1]: self.__dict__[k][indice]
|
||||
return {k: self.__dict__['_' + key + '@' + k][indice]
|
||||
for k in self._meta[key]}
|
||||
else:
|
||||
return self.__dict__[key][indice]
|
||||
if key in self._meta:
|
||||
many_keys = self._meta[key]
|
||||
stack = {k.split('@')[-1]: [] for k in self._meta[key]}
|
||||
stack = {k: [] for k in self._meta[key]}
|
||||
else:
|
||||
stack = []
|
||||
many_keys = None
|
||||
for i in range(stack_num):
|
||||
if many_keys is not None:
|
||||
for k_ in many_keys:
|
||||
k = k_.split('@')[-1]
|
||||
stack[k] = [self.__dict__[k_][indice]] + stack[k]
|
||||
k__ = '_' + key + '@' + k_
|
||||
stack[k_] = [self.__dict__[k__][indice]] + stack[k_]
|
||||
else:
|
||||
stack = [self.__dict__[key][indice]] + stack
|
||||
pre_indice = indice - 1
|
||||
@ -263,6 +266,7 @@ class ReplayBuffer(object):
|
||||
if many_keys is not None:
|
||||
for k in stack:
|
||||
stack[k] = np.stack(stack[k], axis=1)
|
||||
stack = Batch(**stack)
|
||||
else:
|
||||
stack = np.stack(stack, axis=1)
|
||||
return stack
|
||||
@ -305,14 +309,18 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
|
||||
def reset(self):
|
||||
self._index = self._size = 0
|
||||
for k in list(self.__dict__.keys()):
|
||||
if not k.startswith('_'):
|
||||
for k in list(self.__dict__):
|
||||
if isinstance(self.__dict__[k], list):
|
||||
self.__dict__[k] = []
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""Prioritized replay buffer implementation.
|
||||
|
||||
:param float alpha: the prioritization exponent.
|
||||
:param float beta: the importance sample soft coefficient.
|
||||
:param str mode: defaults to ``weight``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
|
||||
@ -324,8 +332,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
if mode != 'weight':
|
||||
raise NotImplementedError
|
||||
super().__init__(size, **kwargs)
|
||||
self._alpha = alpha # prioritization exponent
|
||||
self._beta = beta # importance sample soft coefficient
|
||||
self._alpha = alpha
|
||||
self._beta = beta
|
||||
self._weight_sum = 0.0
|
||||
self.weight = np.zeros(size, dtype=np.float64)
|
||||
self._amortization_freq = 50
|
||||
@ -382,8 +390,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def update_weight(self, indice, new_weight: np.ndarray):
|
||||
"""Update priority weight by indice in this buffer.
|
||||
|
||||
:param indice: indice you want to update weight
|
||||
:param new_weight: new priority weight you want to update
|
||||
:param np.ndarray indice: indice you want to update weight
|
||||
:param np.ndarray new_weight: new priority weight you want to update
|
||||
"""
|
||||
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
||||
- self.weight[indice].sum()
|
||||
@ -402,7 +410,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
)
|
||||
|
||||
def _check_weight_sum(self):
|
||||
# keep a accurate _weight_sum
|
||||
# keep an accurate _weight_sum
|
||||
self._amortization_counter += 1
|
||||
if self._amortization_counter % self._amortization_freq == 0:
|
||||
self._weight_sum = np.sum(self.weight)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user