support Batch of Batch and fix bugs (#38)
This commit is contained in:
		
							parent
							
								
									8f718d9b13
								
							
						
					
					
						commit
						bb2f833d0e
					
				@ -94,6 +94,17 @@ def test_collector_with_dict_state():
 | 
			
		||||
    c1 = Collector(policy, envs, ReplayBuffer(size=100))
 | 
			
		||||
    c1.collect(n_step=10)
 | 
			
		||||
    c1.collect(n_episode=[2, 1, 1, 2])
 | 
			
		||||
    batch = c1.sample(10)
 | 
			
		||||
    print(batch)
 | 
			
		||||
    c0.buffer.update(c1.buffer)
 | 
			
		||||
    assert equal(c0.buffer[:len(c0.buffer)].obs.index, [
 | 
			
		||||
        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
 | 
			
		||||
        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
 | 
			
		||||
        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
 | 
			
		||||
    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4))
 | 
			
		||||
    c2.collect(n_episode=[0, 0, 0, 10])
 | 
			
		||||
    batch = c2.sample(10)
 | 
			
		||||
    print(batch['obs_next']['index'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
 | 
			
		||||
@ -76,7 +76,7 @@ class Batch(object):
 | 
			
		||||
                    self.__dict__[k__] = np.array([
 | 
			
		||||
                        v[i][k_] for i in range(len(v))
 | 
			
		||||
                    ])
 | 
			
		||||
            elif isinstance(v, dict):
 | 
			
		||||
            elif isinstance(v, dict) or isinstance(v, Batch):
 | 
			
		||||
                self._meta[k] = list(v.keys())
 | 
			
		||||
                for k_ in v.keys():
 | 
			
		||||
                    k__ = '_' + k + '@' + k_
 | 
			
		||||
@ -89,7 +89,7 @@ class Batch(object):
 | 
			
		||||
        if isinstance(index, str):
 | 
			
		||||
            return self.__getattr__(index)
 | 
			
		||||
        b = Batch()
 | 
			
		||||
        for k in self.__dict__.keys():
 | 
			
		||||
        for k in self.__dict__:
 | 
			
		||||
            if k != '_meta' and self.__dict__[k] is not None:
 | 
			
		||||
                b.__dict__.update(**{k: self.__dict__[k][index]})
 | 
			
		||||
        b._meta = self._meta
 | 
			
		||||
@ -97,44 +97,44 @@ class Batch(object):
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, key):
 | 
			
		||||
        """Return self.key"""
 | 
			
		||||
        if key not in self._meta.keys():
 | 
			
		||||
            if key not in self.__dict__.keys():
 | 
			
		||||
        if key not in self._meta:
 | 
			
		||||
            if key not in self.__dict__:
 | 
			
		||||
                raise AttributeError(key)
 | 
			
		||||
            return self.__dict__[key]
 | 
			
		||||
        d = {}
 | 
			
		||||
        for k_ in self._meta[key]:
 | 
			
		||||
            k__ = '_' + key + '@' + k_
 | 
			
		||||
            d[k_] = self.__dict__[k__]
 | 
			
		||||
        return d
 | 
			
		||||
        return Batch(**d)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        """Return str(self)."""
 | 
			
		||||
        s = self.__class__.__name__ + '(\n'
 | 
			
		||||
        flag = False
 | 
			
		||||
        for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
 | 
			
		||||
        for k in sorted(list(self.__dict__) + list(self._meta)):
 | 
			
		||||
            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
 | 
			
		||||
                                k in self._meta.keys()):
 | 
			
		||||
                                k in self._meta):
 | 
			
		||||
                rpl = '\n' + ' ' * (6 + len(k))
 | 
			
		||||
                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
 | 
			
		||||
                s += f'    {k}: {obj},\n'
 | 
			
		||||
                flag = True
 | 
			
		||||
        if flag:
 | 
			
		||||
            s += ')\n'
 | 
			
		||||
            s += ')'
 | 
			
		||||
        else:
 | 
			
		||||
            s = self.__class__.__name__ + '()\n'
 | 
			
		||||
            s = self.__class__.__name__ + '()'
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def keys(self):
 | 
			
		||||
        """Return self.keys()."""
 | 
			
		||||
        return sorted([i for i in self.__dict__.keys() if i[0] != '_'] +
 | 
			
		||||
                      list(self._meta.keys()))
 | 
			
		||||
        return sorted([i for i in self.__dict__ if i[0] != '_'] +
 | 
			
		||||
                      list(self._meta))
 | 
			
		||||
 | 
			
		||||
    def append(self, batch):
 | 
			
		||||
        """Append a :class:`~tianshou.data.Batch` object to current batch."""
 | 
			
		||||
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
 | 
			
		||||
        for k in batch.__dict__.keys():
 | 
			
		||||
        for k in batch.__dict__:
 | 
			
		||||
            if k == '_meta':
 | 
			
		||||
                self._meta.update(batch.__dict__[k])
 | 
			
		||||
                self._meta.update(batch._meta)
 | 
			
		||||
                continue
 | 
			
		||||
            if batch.__dict__[k] is None:
 | 
			
		||||
                continue
 | 
			
		||||
@ -157,7 +157,7 @@ class Batch(object):
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        """Return len(self)."""
 | 
			
		||||
        return min([
 | 
			
		||||
            len(self.__dict__[k]) for k in self.__dict__.keys()
 | 
			
		||||
            len(self.__dict__[k]) for k in self.__dict__
 | 
			
		||||
            if k != '_meta' and self.__dict__[k] is not None])
 | 
			
		||||
 | 
			
		||||
    def split(self, size=None, shuffle=True):
 | 
			
		||||
 | 
			
		||||
@ -104,51 +104,50 @@ class ReplayBuffer(object):
 | 
			
		||||
        """Return str(self)."""
 | 
			
		||||
        s = self.__class__.__name__ + '(\n'
 | 
			
		||||
        flag = False
 | 
			
		||||
        for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
 | 
			
		||||
        for k in sorted(list(self.__dict__) + list(self._meta)):
 | 
			
		||||
            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
 | 
			
		||||
                                k in self._meta.keys()):
 | 
			
		||||
                                k in self._meta):
 | 
			
		||||
                rpl = '\n' + ' ' * (6 + len(k))
 | 
			
		||||
                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
 | 
			
		||||
                s += f'    {k}: {obj},\n'
 | 
			
		||||
                flag = True
 | 
			
		||||
        if flag:
 | 
			
		||||
            s += ')\n'
 | 
			
		||||
            s += ')'
 | 
			
		||||
        else:
 | 
			
		||||
            s = self.__class__.__name__ + '()\n'
 | 
			
		||||
            s = self.__class__.__name__ + '()'
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, key):
 | 
			
		||||
        """Return self.key"""
 | 
			
		||||
        if key not in self._meta.keys():
 | 
			
		||||
            if key not in self.__dict__.keys():
 | 
			
		||||
        if key not in self._meta:
 | 
			
		||||
            if key not in self.__dict__:
 | 
			
		||||
                raise AttributeError(key)
 | 
			
		||||
            return self.__dict__[key]
 | 
			
		||||
        d = {}
 | 
			
		||||
        for k_ in self._meta[key]:
 | 
			
		||||
            k__ = '_' + key + '@' + k_
 | 
			
		||||
            d[k_] = self.__dict__[k__]
 | 
			
		||||
        return d
 | 
			
		||||
        return Batch(**d)
 | 
			
		||||
 | 
			
		||||
    def _add_to_buffer(self, name, inst):
 | 
			
		||||
        if inst is None:
 | 
			
		||||
            if getattr(self, name, None) is None:
 | 
			
		||||
                self.__dict__[name] = None
 | 
			
		||||
            return
 | 
			
		||||
        if name in self._meta.keys():
 | 
			
		||||
        if name in self._meta:
 | 
			
		||||
            for k in inst.keys():
 | 
			
		||||
                self._add_to_buffer('_' + name + '@' + k, inst[k])
 | 
			
		||||
            return
 | 
			
		||||
        if self.__dict__.get(name, None) is None:
 | 
			
		||||
            if isinstance(inst, np.ndarray):
 | 
			
		||||
                self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
 | 
			
		||||
            elif isinstance(inst, dict):
 | 
			
		||||
            elif isinstance(inst, dict) or isinstance(inst, Batch):
 | 
			
		||||
                if name == 'info':
 | 
			
		||||
                    self.__dict__[name] = np.array(
 | 
			
		||||
                        [{} for _ in range(self._maxsize)])
 | 
			
		||||
                else:
 | 
			
		||||
                    if self._meta.get(name, None) is None:
 | 
			
		||||
                        self._meta[name] = [
 | 
			
		||||
                            '_' + name + '@' + k for k in inst.keys()]
 | 
			
		||||
                        self._meta[name] = list(inst.keys())
 | 
			
		||||
                    for k in inst.keys():
 | 
			
		||||
                        k_ = '_' + name + '@' + k
 | 
			
		||||
                        self._add_to_buffer(k_, inst[k])
 | 
			
		||||
@ -160,7 +159,7 @@ class ReplayBuffer(object):
 | 
			
		||||
                "Cannot add data to a buffer with different shape, "
 | 
			
		||||
                f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
 | 
			
		||||
                f"given shape: {inst.shape}.")
 | 
			
		||||
        if name not in self._meta.keys():
 | 
			
		||||
        if name not in self._meta:
 | 
			
		||||
            self.__dict__[name][self._index] = inst
 | 
			
		||||
 | 
			
		||||
    def update(self, buffer):
 | 
			
		||||
@ -225,8 +224,12 @@ class ReplayBuffer(object):
 | 
			
		||||
                indice = np.array(indice)
 | 
			
		||||
            elif isinstance(indice, slice):
 | 
			
		||||
                indice = np.arange(
 | 
			
		||||
                    0 if indice.start is None else indice.start,
 | 
			
		||||
                    self._size if indice.stop is None else indice.stop,
 | 
			
		||||
                    0 if indice.start is None
 | 
			
		||||
                    else self._size - indice.start if indice.start < 0
 | 
			
		||||
                    else indice.start,
 | 
			
		||||
                    self._size if indice.stop is None
 | 
			
		||||
                    else self._size - indice.stop if indice.stop < 0
 | 
			
		||||
                    else indice.stop,
 | 
			
		||||
                    1 if indice.step is None else indice.step)
 | 
			
		||||
        # set last frame done to True
 | 
			
		||||
        last_index = (self._index - 1 + self._size) % self._size
 | 
			
		||||
@ -238,21 +241,21 @@ class ReplayBuffer(object):
 | 
			
		||||
        if stack_num == 0:
 | 
			
		||||
            self.done[last_index] = last_done
 | 
			
		||||
            if key in self._meta:
 | 
			
		||||
                return {k.split('@')[-1]: self.__dict__[k][indice]
 | 
			
		||||
                return {k: self.__dict__['_' + key + '@' + k][indice]
 | 
			
		||||
                        for k in self._meta[key]}
 | 
			
		||||
            else:
 | 
			
		||||
                return self.__dict__[key][indice]
 | 
			
		||||
        if key in self._meta:
 | 
			
		||||
            many_keys = self._meta[key]
 | 
			
		||||
            stack = {k.split('@')[-1]: [] for k in self._meta[key]}
 | 
			
		||||
            stack = {k: [] for k in self._meta[key]}
 | 
			
		||||
        else:
 | 
			
		||||
            stack = []
 | 
			
		||||
            many_keys = None
 | 
			
		||||
        for i in range(stack_num):
 | 
			
		||||
            if many_keys is not None:
 | 
			
		||||
                for k_ in many_keys:
 | 
			
		||||
                    k = k_.split('@')[-1]
 | 
			
		||||
                    stack[k] = [self.__dict__[k_][indice]] + stack[k]
 | 
			
		||||
                    k__ = '_' + key + '@' + k_
 | 
			
		||||
                    stack[k_] = [self.__dict__[k__][indice]] + stack[k_]
 | 
			
		||||
            else:
 | 
			
		||||
                stack = [self.__dict__[key][indice]] + stack
 | 
			
		||||
            pre_indice = indice - 1
 | 
			
		||||
@ -263,6 +266,7 @@ class ReplayBuffer(object):
 | 
			
		||||
        if many_keys is not None:
 | 
			
		||||
            for k in stack:
 | 
			
		||||
                stack[k] = np.stack(stack[k], axis=1)
 | 
			
		||||
            stack = Batch(**stack)
 | 
			
		||||
        else:
 | 
			
		||||
            stack = np.stack(stack, axis=1)
 | 
			
		||||
        return stack
 | 
			
		||||
@ -305,14 +309,18 @@ class ListReplayBuffer(ReplayBuffer):
 | 
			
		||||
 | 
			
		||||
    def reset(self):
 | 
			
		||||
        self._index = self._size = 0
 | 
			
		||||
        for k in list(self.__dict__.keys()):
 | 
			
		||||
            if not k.startswith('_'):
 | 
			
		||||
        for k in list(self.__dict__):
 | 
			
		||||
            if isinstance(self.__dict__[k], list):
 | 
			
		||||
                self.__dict__[k] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
    """Prioritized replay buffer implementation.
 | 
			
		||||
 | 
			
		||||
    :param float alpha: the prioritization exponent.
 | 
			
		||||
    :param float beta: the importance sample soft coefficient.
 | 
			
		||||
    :param str mode: defaults to ``weight``.
 | 
			
		||||
 | 
			
		||||
    .. seealso::
 | 
			
		||||
 | 
			
		||||
        Please refer to :class:`~tianshou.data.ReplayBuffer` for more
 | 
			
		||||
@ -324,8 +332,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        if mode != 'weight':
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
        super().__init__(size, **kwargs)
 | 
			
		||||
        self._alpha = alpha  # prioritization exponent
 | 
			
		||||
        self._beta = beta  # importance sample soft coefficient
 | 
			
		||||
        self._alpha = alpha
 | 
			
		||||
        self._beta = beta
 | 
			
		||||
        self._weight_sum = 0.0
 | 
			
		||||
        self.weight = np.zeros(size, dtype=np.float64)
 | 
			
		||||
        self._amortization_freq = 50
 | 
			
		||||
@ -382,8 +390,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
    def update_weight(self, indice, new_weight: np.ndarray):
 | 
			
		||||
        """Update priority weight by indice in this buffer.
 | 
			
		||||
 | 
			
		||||
        :param indice: indice you want to update weight
 | 
			
		||||
        :param new_weight: new priority weight you wangt to update
 | 
			
		||||
        :param np.ndarray indice: indice you want to update weight
 | 
			
		||||
        :param np.ndarray new_weight: new priority weight you wangt to update
 | 
			
		||||
        """
 | 
			
		||||
        self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
 | 
			
		||||
            - self.weight[indice].sum()
 | 
			
		||||
@ -402,7 +410,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_weight_sum(self):
 | 
			
		||||
        # keep a accurate _weight_sum
 | 
			
		||||
        # keep an accurate _weight_sum
 | 
			
		||||
        self._amortization_counter += 1
 | 
			
		||||
        if self._amortization_counter % self._amortization_freq == 0:
 | 
			
		||||
            self._weight_sum = np.sum(self.weight)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user