Use lower-level API to reduce overhead. (#97)
* Use lower-level API to reduce overhead. * Further improvements. * Buffer _add_to_buffer improvement. * Do not use _data field to store Batch data to avoid overhead. Add back _meta field in Buffer. * Restore metadata attribute to store batch in Buffer. * Move out nested methods. * Update try/catch instead of actual check to efficiency. * Remove unsed branches for efficiency. * Use np.array over list when possible for efficiency. * Final performance improvement. * Add unit tests for Batch size method. * Add missing stack unit tests. * Enforce Buffer initialization to zero. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
		
							parent
							
								
									5ac9f9b144
								
							
						
					
					
						commit
						70aa7bf93e
					
				@ -39,12 +39,17 @@ def test_batch():
 | 
			
		||||
        'c': np.zeros(1),
 | 
			
		||||
        'd': Batch(e=np.array(3.0))}])
 | 
			
		||||
    assert len(batch2) == 1
 | 
			
		||||
    assert Batch().size == 0
 | 
			
		||||
    assert batch2.size == 1
 | 
			
		||||
    with pytest.raises(IndexError):
 | 
			
		||||
        batch2[-2]
 | 
			
		||||
    with pytest.raises(IndexError):
 | 
			
		||||
        batch2[1]
 | 
			
		||||
    assert batch2[0].size == 1
 | 
			
		||||
    with pytest.raises(TypeError):
 | 
			
		||||
        batch2[0][0]
 | 
			
		||||
    with pytest.raises(TypeError):
 | 
			
		||||
        len(batch2[0])
 | 
			
		||||
    assert isinstance(batch2[0].a.c, np.ndarray)
 | 
			
		||||
    assert isinstance(batch2[0].a.b, np.float64)
 | 
			
		||||
    assert isinstance(batch2[0].a.d.e, np.float64)
 | 
			
		||||
@ -72,7 +77,7 @@ def test_batch():
 | 
			
		||||
    assert batch3.a.d.e[0] == 4.0
 | 
			
		||||
    batch3.a.d[0] = Batch(f=5.0)
 | 
			
		||||
    assert batch3.a.d.f[0] == 5.0
 | 
			
		||||
    with pytest.raises(ValueError):
 | 
			
		||||
    with pytest.raises(KeyError):
 | 
			
		||||
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -112,10 +117,15 @@ def test_batch_cat_and_stack():
 | 
			
		||||
    b12_stack = Batch.stack((b1, b2))
 | 
			
		||||
    assert isinstance(b12_stack.a.d.e, np.ndarray)
 | 
			
		||||
    assert b12_stack.a.d.e.ndim == 2
 | 
			
		||||
    b3 = Batch(a=np.zeros((3, 4)))
 | 
			
		||||
    b4 = Batch(a=np.ones((3, 4)))
 | 
			
		||||
    b3 = Batch(a=np.zeros((3, 4)),
 | 
			
		||||
               b=torch.ones((2, 5)),
 | 
			
		||||
               c=Batch(d=[[1], [2]]))
 | 
			
		||||
    b4 = Batch(a=np.ones((3, 4)),
 | 
			
		||||
               b=torch.ones((2, 5)),
 | 
			
		||||
               c=Batch(d=[[0], [3]]))
 | 
			
		||||
    b34_stack = Batch.stack((b3, b4), axis=1)
 | 
			
		||||
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
 | 
			
		||||
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_over_batch_to_torch():
 | 
			
		||||
 | 
			
		||||
@ -27,6 +27,16 @@ def test_replaybuffer(size=10, bufsize=20):
 | 
			
		||||
    assert len(buf) == len(buf2)
 | 
			
		||||
    assert buf2[0].obs == buf[5].obs
 | 
			
		||||
    assert buf2[-1].obs == buf[4].obs
 | 
			
		||||
    b = ReplayBuffer(size=10)
 | 
			
		||||
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
 | 
			
		||||
    assert b.obs[0] == 1
 | 
			
		||||
    assert b.done[0] == 'str'
 | 
			
		||||
    assert np.all(b.obs[1:] == 0)
 | 
			
		||||
    assert np.all(b.done[1:] == np.array(None))
 | 
			
		||||
    assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
 | 
			
		||||
    assert np.all(b.info.a[1:] == 0)
 | 
			
		||||
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
 | 
			
		||||
    assert np.all(np.isnan(b.info.b.c[1:]))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_ignore_obs_next(size=10):
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,35 @@ warnings.filterwarnings(
 | 
			
		||||
    "ignore", message="pickle support for Storage will be removed in 1.5.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _is_batch_set(data: Any) -> bool:
 | 
			
		||||
    if isinstance(data, (list, tuple)):
 | 
			
		||||
        if len(data) > 0 and isinstance(data[0], (dict, Batch)):
 | 
			
		||||
            return True
 | 
			
		||||
    elif isinstance(data, np.ndarray):
 | 
			
		||||
        if isinstance(data.item(0), (dict, Batch)):
 | 
			
		||||
            return True
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _valid_bounds(length: int, index: Union[
 | 
			
		||||
        slice, int, np.integer, np.ndarray, List[int]]) -> bool:
 | 
			
		||||
    if isinstance(index, (int, np.integer)):
 | 
			
		||||
        return -length <= index and index < length
 | 
			
		||||
    elif isinstance(index, (list, np.ndarray)):
 | 
			
		||||
        return _valid_bounds(length, np.min(index)) and \
 | 
			
		||||
            _valid_bounds(length, np.max(index))
 | 
			
		||||
    elif isinstance(index, slice):
 | 
			
		||||
        if index.start is not None:
 | 
			
		||||
            start_valid = _valid_bounds(length, index.start)
 | 
			
		||||
        else:
 | 
			
		||||
            start_valid = True
 | 
			
		||||
        if index.stop is not None:
 | 
			
		||||
            stop_valid = _valid_bounds(length, index.stop - 1)
 | 
			
		||||
        else:
 | 
			
		||||
            stop_valid = True
 | 
			
		||||
        return start_valid and stop_valid
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Batch:
 | 
			
		||||
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
 | 
			
		||||
    structure to pass any kind of data to other methods, for example, a
 | 
			
		||||
@ -75,46 +104,30 @@ class Batch:
 | 
			
		||||
        [11 22] [6 6]
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __new__(cls, *args, **kwargs) -> 'Batch':
 | 
			
		||||
        self = super().__new__(cls)
 | 
			
		||||
        self.__dict__['_data'] = {}
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 batch_dict: Optional[Union[
 | 
			
		||||
                     dict, 'Batch', Tuple[Union[dict, 'Batch']],
 | 
			
		||||
                     List[Union[dict, 'Batch']], np.ndarray]] = None,
 | 
			
		||||
                 **kwargs) -> None:
 | 
			
		||||
        def _is_batch_set(data: Any) -> bool:
 | 
			
		||||
            if isinstance(data, (list, tuple)):
 | 
			
		||||
                if len(data) > 0 and isinstance(data[0], (dict, Batch)):
 | 
			
		||||
                    return True
 | 
			
		||||
            elif isinstance(data, np.ndarray):
 | 
			
		||||
                if isinstance(data.item(0), (dict, Batch)):
 | 
			
		||||
                    return True
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        if isinstance(batch_dict, np.ndarray) and batch_dict.ndim == 0:
 | 
			
		||||
            batch_dict = batch_dict[()]
 | 
			
		||||
        if _is_batch_set(batch_dict):
 | 
			
		||||
            for k, v in zip(batch_dict[0].keys(),
 | 
			
		||||
                            zip(*[e.values() for e in batch_dict])):
 | 
			
		||||
                if isinstance(v[0], dict) or _is_batch_set(v[0]):
 | 
			
		||||
                    self[k] = Batch(v)
 | 
			
		||||
                    self.__dict__[k] = Batch(v)
 | 
			
		||||
                elif isinstance(v[0], (np.generic, np.ndarray)):
 | 
			
		||||
                    self[k] = np.stack(v, axis=0)
 | 
			
		||||
                    self.__dict__[k] = np.stack(v, axis=0)
 | 
			
		||||
                elif isinstance(v[0], torch.Tensor):
 | 
			
		||||
                    self[k] = torch.stack(v, dim=0)
 | 
			
		||||
                    self.__dict__[k] = torch.stack(v, dim=0)
 | 
			
		||||
                elif isinstance(v[0], Batch):
 | 
			
		||||
                    self[k] = Batch.stack(v)
 | 
			
		||||
                    self.__dict__[k] = Batch.stack(v)
 | 
			
		||||
                else:
 | 
			
		||||
                    self[k] = np.array(v)  # fall back to np.object
 | 
			
		||||
                    self.__dict__[k] = np.array(v)
 | 
			
		||||
        elif isinstance(batch_dict, (dict, Batch)):
 | 
			
		||||
            for k, v in batch_dict.items():
 | 
			
		||||
                if isinstance(v, dict) or _is_batch_set(v):
 | 
			
		||||
                    self[k] = Batch(v)
 | 
			
		||||
                    self.__dict__[k] = Batch(v)
 | 
			
		||||
                else:
 | 
			
		||||
                    self[k] = v
 | 
			
		||||
                    self.__dict__[k] = v
 | 
			
		||||
        if len(kwargs) > 0:
 | 
			
		||||
            self.__init__(kwargs)
 | 
			
		||||
 | 
			
		||||
@ -123,8 +136,7 @@ class Batch:
 | 
			
		||||
        for both efficiency and simplicity.
 | 
			
		||||
        """
 | 
			
		||||
        state = {}
 | 
			
		||||
        for k in self.keys():
 | 
			
		||||
            v = self[k]
 | 
			
		||||
        for k, v in self.items():
 | 
			
		||||
            if isinstance(v, Batch):
 | 
			
		||||
                v = v.__getstate__()
 | 
			
		||||
            state[k] = v
 | 
			
		||||
@ -140,26 +152,8 @@ class Batch:
 | 
			
		||||
    def __getitem__(self, index: Union[
 | 
			
		||||
            str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
 | 
			
		||||
        """Return self[index]."""
 | 
			
		||||
        def _valid_bounds(length: int, index: Union[
 | 
			
		||||
                slice, int, np.integer, np.ndarray, List[int]]) -> bool:
 | 
			
		||||
            if isinstance(index, (int, np.integer)):
 | 
			
		||||
                return -length <= index and index < length
 | 
			
		||||
            elif isinstance(index, (list, np.ndarray)):
 | 
			
		||||
                return _valid_bounds(length, np.min(index)) and \
 | 
			
		||||
                    _valid_bounds(length, np.max(index))
 | 
			
		||||
            elif isinstance(index, slice):
 | 
			
		||||
                if index.start is not None:
 | 
			
		||||
                    start_valid = _valid_bounds(length, index.start)
 | 
			
		||||
                else:
 | 
			
		||||
                    start_valid = True
 | 
			
		||||
                if index.stop is not None:
 | 
			
		||||
                    stop_valid = _valid_bounds(length, index.stop - 1)
 | 
			
		||||
                else:
 | 
			
		||||
                    stop_valid = True
 | 
			
		||||
                return start_valid and stop_valid
 | 
			
		||||
 | 
			
		||||
        if isinstance(index, str):
 | 
			
		||||
            return getattr(self, index)
 | 
			
		||||
            return self.__dict__[index]
 | 
			
		||||
 | 
			
		||||
        if not _valid_bounds(len(self), index):
 | 
			
		||||
            raise IndexError(
 | 
			
		||||
@ -167,61 +161,57 @@ class Batch:
 | 
			
		||||
        else:
 | 
			
		||||
            b = Batch()
 | 
			
		||||
            for k, v in self.items():
 | 
			
		||||
                if isinstance(v, Batch) and v.size == 0:
 | 
			
		||||
                    b[k] = Batch()
 | 
			
		||||
                elif hasattr(v, '__len__') and (not isinstance(
 | 
			
		||||
                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
 | 
			
		||||
                    if isinstance(index, (int, np.integer)) or \
 | 
			
		||||
                            (isinstance(index, np.ndarray) and
 | 
			
		||||
                                index.ndim == 0) or not isinstance(v, list):
 | 
			
		||||
                        b[k] = v[index]
 | 
			
		||||
                if isinstance(v, Batch) and len(v.__dict__) == 0:
 | 
			
		||||
                    b.__dict__[k] = Batch()
 | 
			
		||||
                else:
 | 
			
		||||
                        b[k] = [v[i] for i in index]
 | 
			
		||||
                    b.__dict__[k] = v[index]
 | 
			
		||||
            return b
 | 
			
		||||
 | 
			
		||||
    def __setitem__(self, index: Union[
 | 
			
		||||
                        str, slice, int, np.integer, np.ndarray, List[int]],
 | 
			
		||||
                    value: Any) -> None:
 | 
			
		||||
        if isinstance(index, str):
 | 
			
		||||
            return setattr(self, index, value)
 | 
			
		||||
        if value is None:
 | 
			
		||||
            value = Batch()
 | 
			
		||||
            self.__dict__[index] = value
 | 
			
		||||
            return
 | 
			
		||||
        if not isinstance(value, (dict, Batch)):
 | 
			
		||||
            raise TypeError("Batch does not supported value type "
 | 
			
		||||
                            f"{type(value)} for item assignment.")
 | 
			
		||||
        if not set(value.keys()).issubset(self.keys()):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
        if not set(value.keys()).issubset(self.__dict__.keys()):
 | 
			
		||||
            raise KeyError(
 | 
			
		||||
                "Creating keys is not supported by item assignment.")
 | 
			
		||||
        for key in self.keys():
 | 
			
		||||
            if isinstance(self[key], Batch):
 | 
			
		||||
                default = Batch()
 | 
			
		||||
            elif isinstance(self[key], np.ndarray) and \
 | 
			
		||||
                    self[key].dtype == np.integer:
 | 
			
		||||
        for key, val in self.items():
 | 
			
		||||
            try:
 | 
			
		||||
                self.__dict__[key][index] = value[key]
 | 
			
		||||
            except KeyError:
 | 
			
		||||
                if isinstance(val, Batch):
 | 
			
		||||
                    self.__dict__[key][index] = Batch()
 | 
			
		||||
                elif isinstance(val, np.ndarray) and \
 | 
			
		||||
                        val.dtype == np.integer:
 | 
			
		||||
                    # Fallback for np.array of integer,
 | 
			
		||||
                    # since neither None or nan is supported.
 | 
			
		||||
                default = 0
 | 
			
		||||
                    self.__dict__[key][index] = 0
 | 
			
		||||
                else:
 | 
			
		||||
                default = None
 | 
			
		||||
            self[key][index] = value.get(key, default)
 | 
			
		||||
                    self.__dict__[key][index] = None
 | 
			
		||||
 | 
			
		||||
    def __iadd__(self, val: Union['Batch', Number]):
 | 
			
		||||
        if isinstance(val, Batch):
 | 
			
		||||
            for k, r, v in zip(self.keys(), self.values(), val.values()):
 | 
			
		||||
            for (k, r), v in zip(self.__dict__.items(),
 | 
			
		||||
                                 val.__dict__.values()):
 | 
			
		||||
                if r is None:
 | 
			
		||||
                    self[k] = r
 | 
			
		||||
                    continue
 | 
			
		||||
                elif isinstance(r, list):
 | 
			
		||||
                    self[k] = [r_ + v_ for r_, v_ in zip(r, v)]
 | 
			
		||||
                    self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
 | 
			
		||||
                else:
 | 
			
		||||
                    self[k] = r + v
 | 
			
		||||
                    self.__dict__[k] += v
 | 
			
		||||
            return self
 | 
			
		||||
        elif isinstance(val, Number):
 | 
			
		||||
            for k, r in zip(self.keys(), self.values()):
 | 
			
		||||
            for k, r in self.items():
 | 
			
		||||
                if r is None:
 | 
			
		||||
                    self[k] = r
 | 
			
		||||
                    continue
 | 
			
		||||
                elif isinstance(r, list):
 | 
			
		||||
                    self[k] = [r_ + val for r_ in r]
 | 
			
		||||
                    self.__dict__[k] = [r_ + val for r_ in r]
 | 
			
		||||
                else:
 | 
			
		||||
                    self[k] = r + val
 | 
			
		||||
                    self.__dict__[k] += val
 | 
			
		||||
            return self
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError("Only addition of Batch or number is supported.")
 | 
			
		||||
@ -229,37 +219,25 @@ class Batch:
 | 
			
		||||
    def __add__(self, val: Union['Batch', Number]):
 | 
			
		||||
        return copy.deepcopy(self).__iadd__(val)
 | 
			
		||||
 | 
			
		||||
    def __mul__(self, val: Number):
 | 
			
		||||
    def __imul__(self, val: Number):
 | 
			
		||||
        assert isinstance(val, Number), \
 | 
			
		||||
            "Only multiplication by a number is supported."
 | 
			
		||||
        result = self.__class__()
 | 
			
		||||
        for k, r in zip(self.keys(), self.values()):
 | 
			
		||||
            result[k] = r * val
 | 
			
		||||
        return result
 | 
			
		||||
        for k in self.__dict__.keys():
 | 
			
		||||
            self.__dict__[k] *= val
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __truediv__(self, val: Number):
 | 
			
		||||
    def __mul__(self, val: Number):
 | 
			
		||||
        return copy.deepcopy(self).__imul__(val)
 | 
			
		||||
 | 
			
		||||
    def __itruediv__(self, val: Number):
 | 
			
		||||
        assert isinstance(val, Number), \
 | 
			
		||||
            "Only division by a number is supported."
 | 
			
		||||
        result = self.__class__()
 | 
			
		||||
        for k, r in zip(self.keys(), self.values()):
 | 
			
		||||
            result[k] = r / val
 | 
			
		||||
        return result
 | 
			
		||||
        for k in self.__dict__.keys():
 | 
			
		||||
            self.__dict__[k] /= val
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, key: str) -> Union['Batch', Any]:
 | 
			
		||||
        """Return self.key"""
 | 
			
		||||
        if key in self.__dict__.keys():
 | 
			
		||||
            return self.__dict__[key]
 | 
			
		||||
        elif key in self._data.keys():
 | 
			
		||||
            return self._data[key]
 | 
			
		||||
        raise AttributeError(key)
 | 
			
		||||
 | 
			
		||||
    def __setattr__(self, key, value):
 | 
			
		||||
        if key in self._data.keys():
 | 
			
		||||
            self._data[key] = value
 | 
			
		||||
        elif key in self.__dict__.keys():
 | 
			
		||||
            self.__dict__[key] = value
 | 
			
		||||
        else:
 | 
			
		||||
            self._data[key] = value
 | 
			
		||||
    def __truediv__(self, val: Number):
 | 
			
		||||
        return copy.deepcopy(self).__itruediv__(val)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        """Return str(self)."""
 | 
			
		||||
@ -278,21 +256,19 @@ class Batch:
 | 
			
		||||
 | 
			
		||||
    def keys(self) -> List[str]:
 | 
			
		||||
        """Return self.keys()."""
 | 
			
		||||
        return self._data.keys()
 | 
			
		||||
        return self.__dict__.keys()
 | 
			
		||||
 | 
			
		||||
    def values(self) -> List[Any]:
 | 
			
		||||
        """Return self.values()."""
 | 
			
		||||
        return self._data.values()
 | 
			
		||||
        return self.__dict__.values()
 | 
			
		||||
 | 
			
		||||
    def items(self) -> Any:
 | 
			
		||||
    def items(self) -> List[Tuple[str, Any]]:
 | 
			
		||||
        """Return self.items()."""
 | 
			
		||||
        return self._data.items()
 | 
			
		||||
        return self.__dict__.items()
 | 
			
		||||
 | 
			
		||||
    def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
 | 
			
		||||
        """Return self[k] if k in self else d. d defaults to None."""
 | 
			
		||||
        if k in self.keys():
 | 
			
		||||
            return self[k]
 | 
			
		||||
        return d
 | 
			
		||||
        return self.__dict__.get(k, d)
 | 
			
		||||
 | 
			
		||||
    def to_numpy(self) -> None:
 | 
			
		||||
        """Change all torch.Tensor to numpy.ndarray. This is an in-place
 | 
			
		||||
@ -300,7 +276,7 @@ class Batch:
 | 
			
		||||
        """
 | 
			
		||||
        for k, v in self.items():
 | 
			
		||||
            if isinstance(v, torch.Tensor):
 | 
			
		||||
                self[k] = v.detach().cpu().numpy()
 | 
			
		||||
                self.__dict__[k] = v.detach().cpu().numpy()
 | 
			
		||||
            elif isinstance(v, Batch):
 | 
			
		||||
                v.to_numpy()
 | 
			
		||||
 | 
			
		||||
@ -319,7 +295,7 @@ class Batch:
 | 
			
		||||
                v = torch.from_numpy(v).to(device)
 | 
			
		||||
                if dtype is not None:
 | 
			
		||||
                    v = v.type(dtype)
 | 
			
		||||
                self[k] = v
 | 
			
		||||
                self.__dict__[k] = v
 | 
			
		||||
            if isinstance(v, torch.Tensor):
 | 
			
		||||
                if dtype is not None and v.dtype != dtype:
 | 
			
		||||
                    must_update_tensor = True
 | 
			
		||||
@ -333,7 +309,7 @@ class Batch:
 | 
			
		||||
                if must_update_tensor:
 | 
			
		||||
                    if dtype is not None:
 | 
			
		||||
                        v = v.type(dtype)
 | 
			
		||||
                    self[k] = v.to(device)
 | 
			
		||||
                    self.__dict__[k] = v.to(device)
 | 
			
		||||
            elif isinstance(v, Batch):
 | 
			
		||||
                v.to_torch(dtype, device)
 | 
			
		||||
 | 
			
		||||
@ -351,16 +327,16 @@ class Batch:
 | 
			
		||||
        for k, v in batch.items():
 | 
			
		||||
            if v is None:
 | 
			
		||||
                continue
 | 
			
		||||
            if not hasattr(self, k) or self[k] is None:
 | 
			
		||||
                self[k] = copy.deepcopy(v)
 | 
			
		||||
            if not hasattr(self, k) or self.__dict__[k] is None:
 | 
			
		||||
                self.__dict__[k] = copy.deepcopy(v)
 | 
			
		||||
            elif isinstance(v, np.ndarray) and v.ndim > 0:
 | 
			
		||||
                self[k] = np.concatenate([self[k], v])
 | 
			
		||||
                self.__dict__[k] = np.concatenate([self.__dict__[k], v])
 | 
			
		||||
            elif isinstance(v, torch.Tensor):
 | 
			
		||||
                self[k] = torch.cat([self[k], v])
 | 
			
		||||
                self.__dict__[k] = torch.cat([self.__dict__[k], v])
 | 
			
		||||
            elif isinstance(v, list):
 | 
			
		||||
                self[k] = self[k] + copy.deepcopy(v)
 | 
			
		||||
                self.__dict__[k] += copy.deepcopy(v)
 | 
			
		||||
            elif isinstance(v, Batch):
 | 
			
		||||
                self[k].cat_(v)
 | 
			
		||||
                self.__dict__[k].cat_(v)
 | 
			
		||||
            else:
 | 
			
		||||
                s = 'No support for method "cat" with type '\
 | 
			
		||||
                    f'{type(v)} in class Batch.'
 | 
			
		||||
@ -394,11 +370,11 @@ class Batch:
 | 
			
		||||
            for k, v in zip(batches[0].keys(),
 | 
			
		||||
                            zip(*[e.values() for e in batches])):
 | 
			
		||||
                if isinstance(v[0], (np.generic, np.ndarray, list)):
 | 
			
		||||
                    batch[k] = np.stack(v, axis)
 | 
			
		||||
                    batch.__dict__[k] = np.stack(v, axis)
 | 
			
		||||
                elif isinstance(v[0], torch.Tensor):
 | 
			
		||||
                    batch[k] = torch.stack(v, axis)
 | 
			
		||||
                    batch.__dict__[k] = torch.stack(v, axis)
 | 
			
		||||
                elif isinstance(v[0], Batch):
 | 
			
		||||
                    batch[k] = Batch.stack(v, axis)
 | 
			
		||||
                    batch.__dict__[k] = Batch.stack(v, axis)
 | 
			
		||||
                else:
 | 
			
		||||
                    s = 'No support for method "stack" with type '\
 | 
			
		||||
                        f'{type(v[0])} in class Batch and axis != 0.'
 | 
			
		||||
@ -408,10 +384,8 @@ class Batch:
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        """Return len(self)."""
 | 
			
		||||
        r = []
 | 
			
		||||
        for v in self.values():
 | 
			
		||||
            if isinstance(v, Batch) and v.size == 0:
 | 
			
		||||
                continue
 | 
			
		||||
            elif isinstance(v, list) and len(v) == 0:
 | 
			
		||||
        for v in self.__dict__.values():
 | 
			
		||||
            if isinstance(v, Batch) and len(v.__dict__) == 0:
 | 
			
		||||
                continue
 | 
			
		||||
            elif hasattr(v, '__len__') and (not isinstance(
 | 
			
		||||
                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
 | 
			
		||||
@ -425,11 +399,11 @@ class Batch:
 | 
			
		||||
    @property
 | 
			
		||||
    def size(self) -> int:
 | 
			
		||||
        """Return self.size."""
 | 
			
		||||
        if len(self.keys()) == 0:
 | 
			
		||||
        if len(self.__dict__.keys()) == 0:
 | 
			
		||||
            return 0
 | 
			
		||||
        else:
 | 
			
		||||
            r = []
 | 
			
		||||
            for v in self.values():
 | 
			
		||||
            for v in self.__dict__.values():
 | 
			
		||||
                if isinstance(v, Batch):
 | 
			
		||||
                    r.append(v.size)
 | 
			
		||||
                elif hasattr(v, '__len__') and (not isinstance(
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,23 @@ from typing import Any, Tuple, Union, Optional
 | 
			
		||||
from .batch import Batch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplayBuffer(Batch):
 | 
			
		||||
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
 | 
			
		||||
    if isinstance(inst, np.ndarray):
 | 
			
		||||
        return np.full(shape=(size, *inst.shape),
 | 
			
		||||
                       fill_value=None if inst.dtype == np.inexact else 0,
 | 
			
		||||
                       dtype=inst.dtype)
 | 
			
		||||
    elif isinstance(inst, (dict, Batch)):
 | 
			
		||||
        zero_batch = Batch()
 | 
			
		||||
        for key, val in inst.items():
 | 
			
		||||
            zero_batch.__dict__[key] = _create_value(val, size)
 | 
			
		||||
        return zero_batch
 | 
			
		||||
    elif isinstance(inst, (np.generic, Number)):
 | 
			
		||||
        return _create_value(np.asarray(inst), size)
 | 
			
		||||
    else:  # fall back to np.object
 | 
			
		||||
        return np.array([None for _ in range(size)])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplayBuffer:
 | 
			
		||||
    """:class:`~tianshou.data.ReplayBuffer` stores data generated from
 | 
			
		||||
    interaction between the policy and environment. It stores basically 7 types
 | 
			
		||||
    of data, as mentioned in :class:`~tianshou.data.Batch`, based on
 | 
			
		||||
@ -93,50 +109,46 @@ class ReplayBuffer(Batch):
 | 
			
		||||
         [ 7.  7.  7.  8.]
 | 
			
		||||
         [ 7.  7.  8.  9.]]
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, size: int, stack_num: Optional[int] = 0,
 | 
			
		||||
                 ignore_obs_next: bool = False, **kwargs) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.__dict__['_maxsize'] = size
 | 
			
		||||
        self.__dict__['_stack'] = stack_num
 | 
			
		||||
        self.__dict__['_save_s_'] = not ignore_obs_next
 | 
			
		||||
        self.__dict__['_index'] = 0
 | 
			
		||||
        self.__dict__['_size'] = 0
 | 
			
		||||
        self._maxsize = size
 | 
			
		||||
        self._stack = stack_num
 | 
			
		||||
        self._save_s_ = not ignore_obs_next
 | 
			
		||||
        self._index = 0
 | 
			
		||||
        self._size = 0
 | 
			
		||||
        self._meta = Batch()
 | 
			
		||||
        self.reset()
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        """Return len(self)."""
 | 
			
		||||
        return self._size
 | 
			
		||||
 | 
			
		||||
    def _add_to_buffer(self, name: str, inst: Any) -> None:
 | 
			
		||||
        def _create_value(inst: Any) -> Union['Batch', np.ndarray]:
 | 
			
		||||
            if isinstance(inst, np.ndarray):
 | 
			
		||||
                return np.zeros(
 | 
			
		||||
                    (self._maxsize, *inst.shape), dtype=inst.dtype)
 | 
			
		||||
            elif isinstance(inst, (dict, Batch)):
 | 
			
		||||
                return Batch([Batch(inst) for _ in range(self._maxsize)])
 | 
			
		||||
            elif isinstance(inst, (np.generic, Number)):
 | 
			
		||||
                return np.zeros(
 | 
			
		||||
                    (self._maxsize,), dtype=np.asarray(inst).dtype)
 | 
			
		||||
            else:  # fall back to np.object
 | 
			
		||||
                return np.array([None for _ in range(self._maxsize)])
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return self.__class__.__name__ + self._meta.__repr__()[5:]
 | 
			
		||||
 | 
			
		||||
        if inst is None:
 | 
			
		||||
            inst = Batch()
 | 
			
		||||
        if name not in self.keys():
 | 
			
		||||
            self[name] = _create_value(inst)
 | 
			
		||||
    def __getattr__(self, key: str) -> Union['Batch', Any]:
 | 
			
		||||
        """Return self.key"""
 | 
			
		||||
        return self._meta.__dict__[key]
 | 
			
		||||
 | 
			
		||||
    def _add_to_buffer(self, name: str, inst: Any) -> None:
 | 
			
		||||
        try:
 | 
			
		||||
            value = self._meta.__dict__[name]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            self._meta.__dict__[name] = _create_value(inst, self._maxsize)
 | 
			
		||||
            value = self._meta.__dict__[name]
 | 
			
		||||
        if isinstance(inst, np.ndarray) and \
 | 
			
		||||
                self[name].shape[1:] != inst.shape:
 | 
			
		||||
                value.shape[1:] != inst.shape:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Cannot add data to a buffer with different shape, "
 | 
			
		||||
                f"key: {name}, expect shape: {self[name].shape[1:]}"
 | 
			
		||||
                "Cannot add data to a buffer with different shape, key: "
 | 
			
		||||
                f"{name}, expect shape: {value.shape[1:]}"
 | 
			
		||||
                f", given shape: {inst.shape}.")
 | 
			
		||||
        if isinstance(self[name], Batch):
 | 
			
		||||
            field_keys = self[name].keys()
 | 
			
		||||
            for key, val in inst.items():
 | 
			
		||||
                if key not in field_keys:
 | 
			
		||||
                    self[name][key] = _create_value(val)
 | 
			
		||||
        self[name][self._index] = inst
 | 
			
		||||
        try:
 | 
			
		||||
            value[self._index] = inst
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            for key in set(inst.keys()).difference(value.__dict__.keys()):
 | 
			
		||||
                value.__dict__[key] = _create_value(inst[key], self._maxsize)
 | 
			
		||||
            value[self._index] = inst
 | 
			
		||||
 | 
			
		||||
    def update(self, buffer: 'ReplayBuffer') -> None:
 | 
			
		||||
        """Move the data from the given buffer to self."""
 | 
			
		||||
@ -148,11 +160,11 @@ class ReplayBuffer(Batch):
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
    def add(self,
 | 
			
		||||
            obs: Union[dict, np.ndarray],
 | 
			
		||||
            obs: Union[dict, Batch, np.ndarray],
 | 
			
		||||
            act: Union[np.ndarray, float],
 | 
			
		||||
            rew: float,
 | 
			
		||||
            done: bool,
 | 
			
		||||
            obs_next: Optional[Union[dict, np.ndarray]] = None,
 | 
			
		||||
            obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
 | 
			
		||||
            info: dict = {},
 | 
			
		||||
            policy: Optional[Union[dict, Batch]] = {},
 | 
			
		||||
            **kwargs) -> None:
 | 
			
		||||
@ -164,6 +176,8 @@ class ReplayBuffer(Batch):
 | 
			
		||||
        self._add_to_buffer('rew', rew)
 | 
			
		||||
        self._add_to_buffer('done', done)
 | 
			
		||||
        if self._save_s_:
 | 
			
		||||
            if obs_next is None:
 | 
			
		||||
                obs_next = Batch()
 | 
			
		||||
            self._add_to_buffer('obs_next', obs_next)
 | 
			
		||||
        self._add_to_buffer('info', info)
 | 
			
		||||
        self._add_to_buffer('policy', policy)
 | 
			
		||||
@ -210,6 +224,7 @@ class ReplayBuffer(Batch):
 | 
			
		||||
                else self._size - indice.stop if indice.stop < 0
 | 
			
		||||
                else indice.stop,
 | 
			
		||||
                1 if indice.step is None else indice.step)
 | 
			
		||||
        else:
 | 
			
		||||
            indice = np.array(indice, copy=True)
 | 
			
		||||
        # set last frame done to True
 | 
			
		||||
        last_index = (self._index - 1 + self._size) % self._size
 | 
			
		||||
@ -218,21 +233,9 @@ class ReplayBuffer(Batch):
 | 
			
		||||
            indice += 1 - self.done[indice].astype(np.int)
 | 
			
		||||
            indice[indice == self._size] = 0
 | 
			
		||||
            key = 'obs'
 | 
			
		||||
        if stack_num == 0:
 | 
			
		||||
            self.done[last_index] = last_done
 | 
			
		||||
            val = self[key]
 | 
			
		||||
            if isinstance(val, Batch) and val.size == 0:
 | 
			
		||||
                return val
 | 
			
		||||
            else:
 | 
			
		||||
                if isinstance(indice, (int, np.integer)) or \
 | 
			
		||||
                        (isinstance(indice, np.ndarray) and
 | 
			
		||||
                            indice.ndim == 0) or not isinstance(val, list):
 | 
			
		||||
                    return val[indice]
 | 
			
		||||
                else:
 | 
			
		||||
                    return [val[i] for i in indice]
 | 
			
		||||
        else:
 | 
			
		||||
            val = self[key]
 | 
			
		||||
            if not isinstance(val, Batch) or val.size > 0:
 | 
			
		||||
        val = self._meta.__dict__[key]
 | 
			
		||||
        try:
 | 
			
		||||
            if stack_num > 0:
 | 
			
		||||
                stack = []
 | 
			
		||||
                for _ in range(stack_num):
 | 
			
		||||
                    stack = [val[indice]] + stack
 | 
			
		||||
@ -241,11 +244,13 @@ class ReplayBuffer(Batch):
 | 
			
		||||
                    indice = np.asarray(
 | 
			
		||||
                        pre_indice + self.done[pre_indice].astype(np.int))
 | 
			
		||||
                    indice[indice == self._size] = 0
 | 
			
		||||
                if isinstance(stack[0], Batch):
 | 
			
		||||
                if isinstance(val, Batch):
 | 
			
		||||
                    stack = Batch.stack(stack, axis=indice.ndim)
 | 
			
		||||
                else:
 | 
			
		||||
                    stack = np.stack(stack, axis=indice.ndim)
 | 
			
		||||
            else:
 | 
			
		||||
                stack = val[indice]
 | 
			
		||||
        except TypeError:
 | 
			
		||||
            stack = Batch()
 | 
			
		||||
        self.done[last_index] = last_done
 | 
			
		||||
        return stack
 | 
			
		||||
@ -255,17 +260,15 @@ class ReplayBuffer(Batch):
 | 
			
		||||
        """Return a data batch: self[index]. If stack_num is set to be > 0,
 | 
			
		||||
        return the stacked obs and obs_next with shape [batch, len, ...].
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(index, str):
 | 
			
		||||
            return getattr(self, index)
 | 
			
		||||
        return Batch(
 | 
			
		||||
            obs=self.get(index, 'obs'),
 | 
			
		||||
            act=self.get(index, 'act', stack_num=0),
 | 
			
		||||
            act=self.act[index],
 | 
			
		||||
            # act_=self.get(index, 'act'),  # stacked action, for RNN
 | 
			
		||||
            rew=self.get(index, 'rew', stack_num=0),
 | 
			
		||||
            done=self.get(index, 'done', stack_num=0),
 | 
			
		||||
            rew=self.rew[index],
 | 
			
		||||
            done=self.done[index],
 | 
			
		||||
            obs_next=self.get(index, 'obs_next'),
 | 
			
		||||
            info=self.get(index, 'info', stack_num=0),
 | 
			
		||||
            policy=self.get(index, 'policy'),
 | 
			
		||||
            policy=self.get(index, 'policy')
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -288,15 +291,15 @@ class ListReplayBuffer(ReplayBuffer):
 | 
			
		||||
            inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
 | 
			
		||||
        if inst is None:
 | 
			
		||||
            return
 | 
			
		||||
        if self._data.get(name, None) is None:
 | 
			
		||||
            self._data[name] = []
 | 
			
		||||
        self._data[name].append(inst)
 | 
			
		||||
        if self._meta.__dict__.get(name, None) is None:
 | 
			
		||||
            self._meta.__dict__[name] = []
 | 
			
		||||
        self._meta.__dict__[name].append(inst)
 | 
			
		||||
 | 
			
		||||
    def reset(self) -> None:
 | 
			
		||||
        self._index = self._size = 0
 | 
			
		||||
        for k in list(self._data):
 | 
			
		||||
            if isinstance(self._data[k], list):
 | 
			
		||||
                self._data[k] = []
 | 
			
		||||
        for k in list(self._meta.__dict__.keys()):
 | 
			
		||||
            if isinstance(self._meta.__dict__[k], list):
 | 
			
		||||
                self._meta.__dict__[k] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
@ -322,10 +325,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        self._alpha = alpha
 | 
			
		||||
        self._beta = beta
 | 
			
		||||
        self._weight_sum = 0.0
 | 
			
		||||
        self.weight = np.zeros(size, dtype=np.float64)
 | 
			
		||||
        self._amortization_freq = 50
 | 
			
		||||
        self._amortization_counter = 0
 | 
			
		||||
        self._replace = replace
 | 
			
		||||
        self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
 | 
			
		||||
 | 
			
		||||
    def add(self,
 | 
			
		||||
            obs: Union[dict, np.ndarray],
 | 
			
		||||
@ -338,9 +341,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
            weight: float = 1.0,
 | 
			
		||||
            **kwargs) -> None:
 | 
			
		||||
        """Add a batch of data into replay buffer."""
 | 
			
		||||
        # we have to sacrifice some convenience for speed
 | 
			
		||||
        self._weight_sum += np.abs(weight) ** self._alpha - \
 | 
			
		||||
            self.weight[self._index]
 | 
			
		||||
        # we have to sacrifice some convenience for speed :(
 | 
			
		||||
            self._meta.__dict__['weight'][self._index]
 | 
			
		||||
        self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
 | 
			
		||||
        super().add(obs, act, rew, done, obs_next, info, policy)
 | 
			
		||||
        self._check_weight_sum()
 | 
			
		||||
@ -414,18 +417,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
            - self.weight[indice].sum()
 | 
			
		||||
        self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index: Union[str, slice, np.ndarray]) -> Batch:
 | 
			
		||||
        if isinstance(index, str):
 | 
			
		||||
            return getattr(self, index)
 | 
			
		||||
    def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
 | 
			
		||||
        return Batch(
 | 
			
		||||
            obs=self.get(index, 'obs'),
 | 
			
		||||
            act=self.get(index, 'act', stack_num=0),
 | 
			
		||||
            act=self.act[index],
 | 
			
		||||
            # act_=self.get(index, 'act'),  # stacked action, for RNN
 | 
			
		||||
            rew=self.get(index, 'rew', stack_num=0),
 | 
			
		||||
            done=self.get(index, 'done', stack_num=0),
 | 
			
		||||
            rew=self.rew[index],
 | 
			
		||||
            done=self.done[index],
 | 
			
		||||
            obs_next=self.get(index, 'obs_next'),
 | 
			
		||||
            info=self.get(index, 'info'),
 | 
			
		||||
            weight=self.get(index, 'weight', stack_num=0),
 | 
			
		||||
            weight=self.weight[index],
 | 
			
		||||
            policy=self.get(index, 'policy'),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user