import numpy as np
from typing import List, Tuple, Union, Optional
from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager


class CachedReplayBuffer(ReplayBufferManager):
"""CachedReplayBuffer contains a given main buffer and n cached buffers, \
``cached_buffer_num * ReplayBuffer(size=max_episode_length)``.
The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ...
| cached_buffers[cached_buffer_num - 1] |``.
The data is first stored in cached buffers. When an episode is terminated, the data
will move to the main buffer and the corresponding cached buffer will be reset.
:param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function
behaves normally.
:param int cached_buffer_num: number of ReplayBuffer needs to be created for cached
buffer.
:param int max_episode_length: the maximum length of one episode, used in each
cached buffer's maxsize.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""

    def __init__(
        self,
        main_buffer: ReplayBuffer,
        cached_buffer_num: int,
        max_episode_length: int,
    ) -> None:
        assert cached_buffer_num > 0 and max_episode_length > 0
        assert type(main_buffer) == ReplayBuffer
        kwargs = main_buffer.options
        # cached buffers inherit the main buffer's options (stack_num, etc.)
        buffers = [main_buffer] + [
            ReplayBuffer(max_episode_length, **kwargs)
            for _ in range(cached_buffer_num)
        ]
        super().__init__(buffer_list=buffers)
        self.main_buffer = self.buffers[0]
        self.cached_buffers = self.buffers[1:]
        self.cached_buffer_num = cached_buffer_num

    def add(
        self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into CachedReplayBuffer.

        Each of the data's length (first dimension) must equal the length of
        buffer_ids. By default buffer_ids is [0, 1, ..., cached_buffer_num - 1].

        Return (current_index, episode_reward, episode_length, episode_start_index),
        each with shape (len(buffer_ids), ...), where (current_index[i],
        episode_reward[i], episode_length[i], episode_start_index[i]) refers to the
        episode result of the buffer_ids[i]-th cached buffer.
        """
        if buffer_ids is None:
            buf_arr = np.arange(1, 1 + self.cached_buffer_num)
        else:  # make sure it is np.ndarray
            buf_arr = np.asarray(buffer_ids) + 1
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr)
        # find the terminated episodes, move data from cached buf to main buf
        updated_ptr, updated_ep_idx = [], []
        done = batch.done.astype(bool)
        for buffer_idx in buf_arr[done]:
            index = self.main_buffer.update(self.buffers[buffer_idx])
            if len(index) == 0:  # unsuccessful move, replace with -1
                index = [-1]
            updated_ep_idx.append(index[0])
            updated_ptr.append(index[-1])
            self.buffers[buffer_idx].reset()
            self._lengths[0] = len(self.main_buffer)
            self._lengths[buffer_idx] = 0
            self.last_index[0] = index[-1]
            self.last_index[buffer_idx] = self._offset[buffer_idx]
        ptr[done] = updated_ptr
        ep_idx[done] = updated_ep_idx
        return ptr, ep_rew, ep_len, ep_idx
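

# A minimal usage sketch, not part of the original module: it exercises the class
# above with two cached buffers and an assumed 4-dimensional observation; the
# values are purely illustrative.
if __name__ == "__main__":
    buf = CachedReplayBuffer(
        main_buffer=ReplayBuffer(size=20),
        cached_buffer_num=2,
        max_episode_length=5,
    )
    batch = Batch(
        obs=np.zeros((2, 4)),
        act=np.zeros(2),
        rew=np.ones(2),
        done=np.array([False, True]),  # the second episode terminates here
        obs_next=np.zeros((2, 4)),
    )
    # The second cached buffer's episode is done: its data is flushed into the
    # main buffer and the cached buffer is reset, so ptr[1] and ep_idx[1] point
    # into the main buffer (or are -1 if the move stored nothing).
    ptr, ep_rew, ep_len, ep_idx = buf.add(batch)
    print(ptr, ep_rew, ep_len, ep_idx)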