| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  | import pprint | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  | from typing import Any, Tuple, Union, Optional | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | from tianshou.data.batch import Batch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ReplayBuffer(object): | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |     """:class:`~tianshou.data.ReplayBuffer` stores data generated from
 | 
					
						
							| 
									
										
										
										
											2020-04-30 16:31:40 +08:00
										 |  |  |     interaction between the policy and environment. It stores basically 7 types | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |     of data, as mentioned in :class:`~tianshou.data.Batch`, based on | 
					
						
							|  |  |  |     ``numpy.ndarray``. Here is the usage: | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |     :: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-09 21:36:53 +08:00
										 |  |  |         >>> import numpy as np | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |         >>> from tianshou.data import ReplayBuffer | 
					
						
							|  |  |  |         >>> buf = ReplayBuffer(size=20) | 
					
						
							|  |  |  |         >>> for i in range(3): | 
					
						
							|  |  |  |         ...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |         >>> len(buf) | 
					
						
							|  |  |  |         3 | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |         >>> buf.obs | 
					
						
							|  |  |  |         # since we set size = 20, len(buf.obs) == 20. | 
					
						
							|  |  |  |         array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., | 
					
						
							|  |  |  |                0., 0., 0., 0.]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         >>> buf2 = ReplayBuffer(size=10) | 
					
						
							|  |  |  |         >>> for i in range(15): | 
					
						
							|  |  |  |         ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |         >>> len(buf2) | 
					
						
							|  |  |  |         10 | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |         >>> buf2.obs | 
					
						
							|  |  |  |         # since its size = 10, it only stores the last 10 steps' result. | 
					
						
							|  |  |  |         array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |         >>> # move buf2's result into buf (meanwhile keep it chronologically) | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |         >>> buf.update(buf2) | 
							|  |  |  |         >>> buf.obs | 
					
						
							|  |  |  |         array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., | 
					
						
							|  |  |  |                 0.,  0.,  0.,  0.,  0.,  0.,  0.]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         >>> # get a random sample from buffer | 
					
						
							|  |  |  |         >>> # the batch_data is equal to buf[indice]. | 
					
						
							|  |  |  |         >>> batch_data, indice = buf.sample(batch_size=4) | 
					
						
							|  |  |  |         >>> batch_data.obs == buf[indice].obs | 
					
						
							|  |  |  |         array([ True,  True,  True,  True]) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling | 
					
						
							|  |  |  |     (typically for RNN usage, see issue#19), ignoring storing the next | 
					
						
							|  |  |  |     observation (save memory in atari tasks), and multi-modal observation (see | 
					
						
							| 
									
										
										
										
											2020-06-08 21:53:00 +08:00
										 |  |  |     issue#38): | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     :: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |         >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         >>> for i in range(16): | 
					
						
							|  |  |  |         ...     done = i % 5 == 0 | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |         ...     buf.add(obs={'id': i}, act=i, rew=i, done=done, | 
					
						
							|  |  |  |         ...             obs_next={'id': i + 1}) | 
					
						
							|  |  |  |         >>> print(buf)  # you can see obs_next is not saved in buf | 
					
						
							| 
									
										
										
										
											2020-04-09 21:36:53 +08:00
										 |  |  |         ReplayBuffer( | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |             act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]), | 
					
						
							|  |  |  |             done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]), | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |             info: Batch(), | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |             obs: Batch( | 
					
						
							|  |  |  |                      id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]), | 
					
						
							|  |  |  |                  ), | 
					
						
							|  |  |  |             policy: Batch(), | 
					
						
							|  |  |  |             rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]), | 
					
						
							| 
									
										
										
										
											2020-04-09 21:36:53 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         >>> index = np.arange(len(buf)) | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |         >>> print(buf.get(index, 'obs').id) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         [[ 7.  7.  8.  9.] | 
					
						
							|  |  |  |          [ 7.  8.  9. 10.] | 
					
						
							|  |  |  |          [11. 11. 11. 11.] | 
					
						
							|  |  |  |          [11. 11. 11. 12.] | 
					
						
							|  |  |  |          [11. 11. 12. 13.] | 
					
						
							|  |  |  |          [11. 12. 13. 14.] | 
					
						
							|  |  |  |          [12. 13. 14. 15.] | 
					
						
							|  |  |  |          [ 7.  7.  7.  7.] | 
					
						
							|  |  |  |          [ 7.  7.  7.  8.]] | 
					
						
							|  |  |  |         >>> # here is another way to get the stacked data | 
					
						
							|  |  |  |         >>> # (stack only for obs and obs_next) | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |         >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         0.0 | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |         >>> # we can get obs_next through __getitem__, even if it doesn't exist | 
					
						
							|  |  |  |         >>> print(buf[:].obs_next.id) | 
					
						
							| 
									
										
										
										
											2020-04-11 16:54:27 +08:00
										 |  |  |         [[ 7.  8.  9. 10.] | 
					
						
							|  |  |  |          [ 7.  8.  9. 10.] | 
					
						
							|  |  |  |          [11. 11. 11. 12.] | 
					
						
							|  |  |  |          [11. 11. 12. 13.] | 
					
						
							|  |  |  |          [11. 12. 13. 14.] | 
					
						
							|  |  |  |          [12. 13. 14. 15.] | 
					
						
							|  |  |  |          [12. 13. 14. 15.] | 
					
						
							|  |  |  |          [ 7.  7.  7.  8.] | 
					
						
							|  |  |  |          [ 7.  7.  8.  9.]] | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __init__(self, size: int, stack_num: Optional[int] = 0, | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |                  ignore_obs_next: bool = False, **kwargs) -> None: | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         super().__init__() | 
					
						
							|  |  |  |         self._maxsize = size | 
					
						
							| 
									
										
										
										
											2020-04-08 21:13:15 +08:00
										 |  |  |         self._stack = stack_num | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |         self._save_s_ = not ignore_obs_next | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         self._meta = {} | 
					
						
							| 
									
										
										
										
											2020-03-11 17:28:51 +08:00
										 |  |  |         self.reset() | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __len__(self) -> int: | 
					
						
							| 
									
										
										
										
											2020-04-04 21:02:06 +08:00
										 |  |  |         """Return len(self).""" | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         return self._size | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __repr__(self) -> str: | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         """Return str(self).""" | 
					
						
							|  |  |  |         s = self.__class__.__name__ + '(\n' | 
					
						
							|  |  |  |         flag = False | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         for k in sorted(list(self.__dict__) + list(self._meta)): | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             if k[0] != '_' and (self.__dict__.get(k, None) is not None or | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |                                 k in self._meta): | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |                 rpl = '\n' + ' ' * (6 + len(k)) | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |                 obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |                 s += f'    {k}: {obj},\n' | 
					
						
							|  |  |  |                 flag = True | 
					
						
							|  |  |  |         if flag: | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |             s += ')' | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |             s = self.__class__.__name__ + '()' | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         return s | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __getattr__(self, key: str) -> Union[Batch, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         """Return self.key""" | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         if key not in self._meta: | 
					
						
							|  |  |  |             if key not in self.__dict__: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |                 raise AttributeError(key) | 
					
						
							|  |  |  |             return self.__dict__[key] | 
					
						
							|  |  |  |         d = {} | 
					
						
							|  |  |  |         for k_ in self._meta[key]: | 
					
						
							|  |  |  |             k__ = '_' + key + '@' + k_ | 
					
						
							| 
									
										
										
										
											2020-06-09 18:46:14 +08:00
										 |  |  |             if k__ in self.__dict__: | 
					
						
							|  |  |  |                 d[k_] = self.__dict__[k__] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 d[k_] = self.__getattr__(k__) | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         return Batch(**d) | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |     def _add_to_buffer(self, name: str, inst: Any) -> None: | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         if inst is None: | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |             if getattr(self, name, None) is None: | 
					
						
							|  |  |  |                 self.__dict__[name] = None | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |             return | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         if name in self._meta: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             for k in inst.keys(): | 
					
						
							|  |  |  |                 self._add_to_buffer('_' + name + '@' + k, inst[k]) | 
					
						
							|  |  |  |             return | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         if self.__dict__.get(name, None) is None: | 
					
						
							|  |  |  |             if isinstance(inst, np.ndarray): | 
					
						
							| 
									
										
										
										
											2020-05-29 16:27:03 +02:00
										 |  |  |                 self.__dict__[name] = np.zeros( | 
					
						
							|  |  |  |                     (self._maxsize, *inst.shape), dtype=inst.dtype) | 
					
						
							| 
									
										
										
										
											2020-05-27 11:02:23 +08:00
										 |  |  |             elif isinstance(inst, (dict, Batch)): | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |                 if self._meta.get(name, None) is None: | 
					
						
							|  |  |  |                     self._meta[name] = list(inst.keys()) | 
					
						
							|  |  |  |                 for k in inst.keys(): | 
					
						
							|  |  |  |                     k_ = '_' + name + '@' + k | 
					
						
							|  |  |  |                     self._add_to_buffer(k_, inst[k]) | 
					
						
							|  |  |  |             elif np.isscalar(inst): | 
					
						
							| 
									
										
										
										
											2020-05-29 16:27:03 +02:00
										 |  |  |                 self.__dict__[name] = np.zeros( | 
					
						
							|  |  |  |                     (self._maxsize,), dtype=np.asarray(inst).dtype) | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |             else:  # fall back to np.object | 
					
						
							|  |  |  |                 self.__dict__[name] = np.array( | 
					
						
							|  |  |  |                     [None for _ in range(self._maxsize)]) | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  |         if isinstance(inst, np.ndarray) and \ | 
					
						
							|  |  |  |                 self.__dict__[name].shape[1:] != inst.shape: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 "Cannot add data to a buffer with different shape, " | 
					
						
							|  |  |  |                 f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, " | 
					
						
							|  |  |  |                 f"given shape: {inst.shape}.") | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         if name not in self._meta: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             self.__dict__[name][self._index] = inst | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def update(self, buffer: 'ReplayBuffer') -> None: | 
					
						
							| 
									
										
										
										
											2020-04-04 21:02:06 +08:00
										 |  |  |         """Move the data from the given buffer to self.""" | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |         i = begin = buffer._index % len(buffer) | 
					
						
							|  |  |  |         while True: | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |             self.add( | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i], | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |                 buffer.obs_next[i] if self._save_s_ else None, | 
					
						
							| 
									
										
										
										
											2020-04-29 17:48:48 +08:00
										 |  |  |                 buffer.info[i], buffer.policy[i]) | 
					
						
							| 
									
										
										
										
											2020-03-16 11:11:29 +08:00
										 |  |  |             i = (i + 1) % len(buffer) | 
					
						
							|  |  |  |             if i == begin: | 
					
						
							|  |  |  |                 break | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def add(self, | 
					
						
							|  |  |  |             obs: Union[dict, np.ndarray], | 
					
						
							|  |  |  |             act: Union[np.ndarray, float], | 
					
						
							|  |  |  |             rew: float, | 
					
						
							|  |  |  |             done: bool, | 
					
						
							|  |  |  |             obs_next: Optional[Union[dict, np.ndarray]] = None, | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |             info: dict = {}, | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |             policy: Optional[Union[dict, Batch]] = {}, | 
					
						
							|  |  |  |             **kwargs) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-04 21:02:06 +08:00
										 |  |  |         """Add a batch of data into replay buffer.""" | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |         assert isinstance(info, (dict, Batch)), \ | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |             'You should return a dict in the last argument of env.step().' | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         self._add_to_buffer('obs', obs) | 
					
						
							|  |  |  |         self._add_to_buffer('act', act) | 
					
						
							|  |  |  |         self._add_to_buffer('rew', rew) | 
					
						
							|  |  |  |         self._add_to_buffer('done', done) | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |         if self._save_s_: | 
					
						
							|  |  |  |             self._add_to_buffer('obs_next', obs_next) | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         self._add_to_buffer('info', info) | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         self._add_to_buffer('policy', policy) | 
					
						
							| 
									
										
										
										
											2020-03-28 15:14:41 +08:00
										 |  |  |         if self._maxsize > 0: | 
					
						
							|  |  |  |             self._size = min(self._size + 1, self._maxsize) | 
					
						
							|  |  |  |             self._index = (self._index + 1) % self._maxsize | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self._size = self._index = self._index + 1 | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def reset(self) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-04 21:02:06 +08:00
										 |  |  |         """Clear all the data in replay buffer.""" | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  |         self._index = self._size = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |         """Get a random sample from buffer with size equal to batch_size. \
 | 
					
						
							|  |  |  |         Return all the data in the buffer if batch_size is ``0``. | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         :return: Sample data and its corresponding index inside the buffer. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |         if batch_size > 0: | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |             indice = np.random.choice(self._size, batch_size) | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |             indice = np.concatenate([ | 
					
						
							|  |  |  |                 np.arange(self._index, self._size), | 
					
						
							|  |  |  |                 np.arange(0, self._index), | 
					
						
							|  |  |  |             ]) | 
					
						
							| 
									
										
										
										
											2020-03-30 22:52:25 +08:00
										 |  |  |         return self[indice], indice | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def get(self, indice: Union[slice, np.ndarray], key: str, | 
					
						
							|  |  |  |             stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
 | 
					
						
							|  |  |  |         where s is self.key, t is indice. The stack_num (here equals to 4) is | 
					
						
							|  |  |  |         given from buffer initialization procedure. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         if stack_num is None: | 
					
						
							|  |  |  |             stack_num = self._stack | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |         if not isinstance(indice, np.ndarray): | 
					
						
							|  |  |  |             if np.isscalar(indice): | 
					
						
							|  |  |  |                 indice = np.array(indice) | 
					
						
							|  |  |  |             elif isinstance(indice, slice): | 
					
						
							|  |  |  |                 indice = np.arange( | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |                     0 if indice.start is None | 
					
						
							|  |  |  |                     else self._size - indice.start if indice.start < 0 | 
					
						
							|  |  |  |                     else indice.start, | 
					
						
							|  |  |  |                     self._size if indice.stop is None | 
					
						
							|  |  |  |                     else self._size - indice.stop if indice.stop < 0 | 
					
						
							|  |  |  |                     else indice.stop, | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |                     1 if indice.step is None else indice.step) | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         # set last frame done to True | 
					
						
							|  |  |  |         last_index = (self._index - 1 + self._size) % self._size | 
					
						
							|  |  |  |         last_done, self.done[last_index] = self.done[last_index], True | 
					
						
							| 
									
										
										
										
											2020-06-08 21:53:00 +08:00
										 |  |  |         if key == 'obs_next' and (not self._save_s_ or self.obs_next is None): | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |             indice += 1 - self.done[indice].astype(np.int) | 
					
						
							|  |  |  |             indice[indice == self._size] = 0 | 
					
						
							|  |  |  |             key = 'obs' | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         if stack_num == 0: | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |             self.done[last_index] = last_done | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             if key in self._meta: | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |                 return {k: self.__dict__['_' + key + '@' + k][indice] | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |                         for k in self._meta[key]} | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 return self.__dict__[key][indice] | 
					
						
							|  |  |  |         if key in self._meta: | 
					
						
							|  |  |  |             many_keys = self._meta[key] | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |             stack = {k: [] for k in self._meta[key]} | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             stack = [] | 
					
						
							|  |  |  |             many_keys = None | 
					
						
							|  |  |  |         for i in range(stack_num): | 
					
						
							|  |  |  |             if many_keys is not None: | 
					
						
							|  |  |  |                 for k_ in many_keys: | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |                     k__ = '_' + key + '@' + k_ | 
					
						
							|  |  |  |                     stack[k_] = [self.__dict__[k__][indice]] + stack[k_] | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 stack = [self.__dict__[key][indice]] + stack | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |             pre_indice = indice - 1 | 
					
						
							|  |  |  |             pre_indice[pre_indice == -1] = self._size - 1 | 
					
						
							|  |  |  |             indice = pre_indice + self.done[pre_indice].astype(np.int) | 
					
						
							|  |  |  |             indice[indice == self._size] = 0 | 
					
						
							|  |  |  |         self.done[last_index] = last_done | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         if many_keys is not None: | 
					
						
							|  |  |  |             for k in stack: | 
					
						
							|  |  |  |                 stack[k] = np.stack(stack[k], axis=1) | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |             stack = Batch(**stack) | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             stack = np.stack(stack, axis=1) | 
					
						
							|  |  |  |         return stack | 
					
						
							| 
									
										
										
										
											2020-04-08 21:13:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch: | 
					
						
							| 
									
										
										
										
											2020-04-08 21:13:15 +08:00
										 |  |  |         """Return a data batch: self[index]. If stack_num is set to be > 0,
 | 
					
						
							|  |  |  |         return the stacked obs and obs_next with shape [batch, len, ...]. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         return Batch( | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |             obs=self.get(index, 'obs'), | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |             act=self.act[index], | 
					
						
							| 
									
										
										
										
											2020-04-30 16:31:40 +08:00
										 |  |  |             # act_=self.get(index, 'act'),  # stacked action, for RNN | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |             rew=self.rew[index], | 
					
						
							|  |  |  |             done=self.done[index], | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |             obs_next=self.get(index, 'obs_next'), | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             info=self.info[index], | 
					
						
							|  |  |  |             policy=self.get(index, 'policy'), | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-28 15:14:41 +08:00
										 |  |  | class ListReplayBuffer(ReplayBuffer): | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |     """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
 | 
					
						
							|  |  |  |     same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |     :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. | 
					
						
							| 
									
										
										
										
											2020-04-09 21:36:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     .. seealso:: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         Please refer to :class:`~tianshou.data.ReplayBuffer` for more | 
					
						
							| 
									
										
										
										
											2020-04-09 21:36:53 +08:00
										 |  |  |         detailed explanation. | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __init__(self, **kwargs) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |         super().__init__(size=0, ignore_obs_next=False, **kwargs) | 
					
						
							| 
									
										
										
										
											2020-03-28 15:14:41 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def _add_to_buffer( | 
					
						
							|  |  |  |             self, name: str, | 
					
						
							|  |  |  |             inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: | 
					
						
							| 
									
										
										
										
											2020-03-28 15:14:41 +08:00
										 |  |  |         if inst is None: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  |         if self.__dict__.get(name, None) is None: | 
					
						
							|  |  |  |             self.__dict__[name] = [] | 
					
						
							|  |  |  |         self.__dict__[name].append(inst) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def reset(self) -> None: | 
					
						
							| 
									
										
										
										
											2020-03-28 15:14:41 +08:00
										 |  |  |         self._index = self._size = 0 | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         for k in list(self.__dict__): | 
					
						
							|  |  |  |             if isinstance(self.__dict__[k], list): | 
					
						
							| 
									
										
										
										
											2020-03-28 15:14:41 +08:00
										 |  |  |                 self.__dict__[k] = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | class PrioritizedReplayBuffer(ReplayBuffer): | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |     """Prioritized replay buffer implementation.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |     :param float alpha: the prioritization exponent. | 
					
						
							|  |  |  |     :param float beta: the importance sample soft coefficient. | 
					
						
							|  |  |  |     :param str mode: defaults to ``weight``. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |     .. seealso:: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Please refer to :class:`~tianshou.data.ReplayBuffer` for more | 
					
						
							|  |  |  |         detailed explanation. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __init__(self, size: int, alpha: float, beta: float, | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |                  mode: str = 'weight', **kwargs) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         if mode != 'weight': | 
					
						
							|  |  |  |             raise NotImplementedError | 
					
						
							| 
									
										
										
										
											2020-04-10 09:01:17 +08:00
										 |  |  |         super().__init__(size, **kwargs) | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         self._alpha = alpha | 
					
						
							|  |  |  |         self._beta = beta | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         self._weight_sum = 0.0 | 
					
						
							|  |  |  |         self.weight = np.zeros(size, dtype=np.float64) | 
					
						
							|  |  |  |         self._amortization_freq = 50 | 
					
						
							|  |  |  |         self._amortization_counter = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def add(self, | 
					
						
							|  |  |  |             obs: Union[dict, np.ndarray], | 
					
						
							|  |  |  |             act: Union[np.ndarray, float], | 
					
						
							|  |  |  |             rew: float, | 
					
						
							|  |  |  |             done: bool, | 
					
						
							|  |  |  |             obs_next: Optional[Union[dict, np.ndarray]] = None, | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |             info: dict = {}, | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |             policy: Optional[Union[dict, Batch]] = {}, | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |             weight: float = 1.0, | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |             **kwargs) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         """Add a batch of data into replay buffer.""" | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |         self._weight_sum += np.abs(weight) ** self._alpha - \ | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |             self.weight[self._index] | 
					
						
							|  |  |  |         # we have to sacrifice some convenience for speed :( | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         self._add_to_buffer('weight', np.abs(weight) ** self._alpha) | 
					
						
							|  |  |  |         super().add(obs, act, rew, done, obs_next, info, policy) | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         self._check_weight_sum() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |     def sample(self, batch_size: int, | 
					
						
							|  |  |  |                importance_sample: bool = True | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |                ) -> Tuple[Batch, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         """Get a random sample from buffer with priority probability. \
 | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         Return all the data in the buffer if batch_size is ``0``. | 
					
						
							| 
									
										
										
										
											2020-03-11 09:38:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         :return: Sample data and its corresponding index inside the buffer. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if batch_size > 0 and batch_size <= self._size: | 
					
						
							|  |  |  |             # Multiple sampling of the same sample | 
					
						
							|  |  |  |             # will cause weight update conflict | 
					
						
							|  |  |  |             indice = np.random.choice( | 
					
						
							|  |  |  |                 self._size, batch_size, | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |                 p=(self.weight / self.weight.sum())[:self._size], | 
					
						
							|  |  |  |                 replace=False) | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |             # self._weight_sum is not work for the accuracy issue | 
					
						
							|  |  |  |             # p=(self.weight/self._weight_sum)[:self._size], replace=False) | 
					
						
							|  |  |  |         elif batch_size == 0: | 
					
						
							|  |  |  |             indice = np.concatenate([ | 
					
						
							|  |  |  |                 np.arange(self._index, self._size), | 
					
						
							|  |  |  |                 np.arange(0, self._index), | 
					
						
							|  |  |  |             ]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             # if batch_size larger than len(self), | 
					
						
							|  |  |  |             # it will lead to a bug in update weight | 
					
						
							|  |  |  |             raise ValueError("batch_size should be less than len(self)") | 
					
						
							|  |  |  |         batch = self[indice] | 
					
						
							|  |  |  |         if importance_sample: | 
					
						
							|  |  |  |             impt_weight = Batch( | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |                 impt_weight=1 / np.power( | 
					
						
							|  |  |  |                     self._size * (batch.weight / self._weight_sum), | 
					
						
							|  |  |  |                     self._beta)) | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |             batch.append(impt_weight) | 
					
						
							|  |  |  |         self._check_weight_sum() | 
					
						
							|  |  |  |         return batch, indice | 
					
						
							| 
									
										
										
										
											2020-03-11 09:09:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def reset(self) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         self._amortization_counter = 0 | 
					
						
							|  |  |  |         super().reset() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def update_weight(self, indice: Union[slice, np.ndarray], | 
					
						
							|  |  |  |                       new_weight: np.ndarray) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         """Update priority weight by indice in this buffer.
 | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         :param np.ndarray indice: indice you want to update weight | 
					
						
							|  |  |  |         :param np.ndarray new_weight: new priority weight you wangt to update | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \ | 
					
						
							|  |  |  |             - self.weight[indice].sum() | 
					
						
							|  |  |  |         self.weight[indice] = np.power(np.abs(new_weight), self._alpha) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch: | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         return Batch( | 
					
						
							|  |  |  |             obs=self.get(index, 'obs'), | 
					
						
							|  |  |  |             act=self.act[index], | 
					
						
							| 
									
										
										
										
											2020-04-30 16:31:40 +08:00
										 |  |  |             # act_=self.get(index, 'act'),  # stacked action, for RNN | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |             rew=self.rew[index], | 
					
						
							|  |  |  |             done=self.done[index], | 
					
						
							|  |  |  |             obs_next=self.get(index, 'obs_next'), | 
					
						
							|  |  |  |             info=self.info[index], | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |             weight=self.weight[index], | 
					
						
							|  |  |  |             policy=self.get(index, 'policy'), | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |     def _check_weight_sum(self) -> None: | 
					
						
							| 
									
										
										
										
											2020-04-29 12:14:53 +08:00
										 |  |  |         # keep an accurate _weight_sum | 
					
						
							| 
									
										
										
										
											2020-04-26 12:05:58 +08:00
										 |  |  |         self._amortization_counter += 1 | 
					
						
							|  |  |  |         if self._amortization_counter % self._amortization_freq == 0: | 
					
						
							|  |  |  |             self._weight_sum = np.sum(self.weight) | 
					
						
							|  |  |  |             self._amortization_counter = 0 |