Closes #914

Additional changes:

- Deprecate Python below 3.11
- Remove 3rd-party and throughput tests; this simplifies the install and test pipeline
- Remove gym compatibility and shimmy
- Format with 3.11 conventions; in particular, add `zip(..., strict=True/False)` where possible

Since the additional tests and gym were complicating the CI pipeline (flaky and dist-dependent), it didn't make sense to fix the current tests in this PR only to delete them in the next one. So this PR changes the build and removes those tests at the same time.
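For reference, a minimal illustration of the `strict` flag this PR adopts (available since Python 3.10): `strict=True` turns silent truncation of unequal-length iterables into an error, while `strict=False` makes the old truncating behavior explicit.

```python
a = [1, 2, 3]
b = ["x", "y"]

# strict=False keeps the old behavior, truncating to the shorter iterable:
list(zip(a, b, strict=False))  # [(1, 'x'), (2, 'y')]

# strict=True raises instead of silently dropping trailing elements:
list(zip(a, b, strict=True))  # ValueError: zip() argument 2 is shorter than argument 1
```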
import numpy as np

from tianshou.data import ReplayBuffer, ReplayBufferManager
from tianshou.data.types import RolloutBatchProtocol


class CachedReplayBuffer(ReplayBufferManager):
    """CachedReplayBuffer contains a given main buffer and n cached buffers, ``cached_buffer_num * ReplayBuffer(size=max_episode_length)``.

    The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ...
    | cached_buffers[cached_buffer_num - 1] |``.

    The data is first stored in the cached buffers. When an episode is terminated, the
    data is moved to the main buffer and the corresponding cached buffer is reset.

    :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function
        behaves normally.
    :param int cached_buffer_num: the number of cached ReplayBuffers to create.
    :param int max_episode_length: the maximum length of one episode, used as each
        cached buffer's maxsize.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(
        self,
        main_buffer: ReplayBuffer,
        cached_buffer_num: int,
        max_episode_length: int,
    ) -> None:
        assert cached_buffer_num > 0
        assert max_episode_length > 0
        assert isinstance(main_buffer, ReplayBuffer)
        kwargs = main_buffer.options
        buffers = [main_buffer] + [
            ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num)
        ]
        super().__init__(buffer_list=buffers)
        self.main_buffer = self.buffers[0]
        self.cached_buffers = self.buffers[1:]
        self.cached_buffer_num = cached_buffer_num

    def add(
        self,
        batch: RolloutBatchProtocol,
        buffer_ids: np.ndarray | list[int] | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into CachedReplayBuffer.

        The length (first dimension) of each data field must equal the length of
        buffer_ids. By default, buffer_ids is [0, 1, ..., cached_buffer_num - 1].

        Return (current_index, episode_reward, episode_length, episode_start_index)
        with each of the shape (len(buffer_ids), ...), where (current_index[i],
        episode_reward[i], episode_length[i], episode_start_index[i]) refers to the
        episode result of the cached_buffer_ids[i]th cached buffer.
        """
        if buffer_ids is None:
            buf_arr = np.arange(1, 1 + self.cached_buffer_num)
        else:  # make sure it is np.ndarray
            buf_arr = np.asarray(buffer_ids) + 1
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr)
        # find terminated episodes and move their data from the cached buffers
        # to the main buffer
        updated_ptr, updated_ep_idx = [], []
        done = np.logical_or(batch.terminated, batch.truncated)
        for buffer_idx in buf_arr[done]:
            index = self.main_buffer.update(self.buffers[buffer_idx])
            if len(index) == 0:  # unsuccessful move, replace with -1
                index = [-1]
            updated_ep_idx.append(index[0])
            updated_ptr.append(index[-1])
            self.buffers[buffer_idx].reset()
            self._lengths[0] = len(self.main_buffer)
            self._lengths[buffer_idx] = 0
            self.last_index[0] = index[-1]
            self.last_index[buffer_idx] = self._offset[buffer_idx]
        ptr[done] = updated_ptr
        ep_idx[done] = updated_ep_idx
        return ptr, ep_rew, ep_len, ep_idx
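For illustration, a minimal usage sketch of the buffer above. The `Batch` field names and shapes here follow tianshou's usual rollout conventions but are assumptions for this example, not taken from this file; see `RolloutBatchProtocol` for the actual contract.

```python
import numpy as np

from tianshou.data import Batch, CachedReplayBuffer, ReplayBuffer

# One main buffer plus 4 cached buffers, each sized for episodes up to 100 steps.
buf = CachedReplayBuffer(
    main_buffer=ReplayBuffer(size=1000),
    cached_buffer_num=4,
    max_episode_length=100,
)

# One transition per cached buffer (first dimension == cached_buffer_num).
# Field names and shapes are illustrative.
batch = Batch(
    obs=np.zeros((4, 3)),
    act=np.zeros(4),
    rew=np.ones(4),
    terminated=np.array([False, False, True, False]),  # episode in buffer 2 ends
    truncated=np.array([False, False, False, False]),
    done=np.array([False, False, True, False]),
    obs_next=np.zeros((4, 3)),
    info=np.array([{}] * 4, dtype=object),
)

ptr, ep_rew, ep_len, ep_idx = buf.add(batch)
# The finished episode is copied from cached buffer 2 into the main buffer,
# and that cached buffer is reset for the next episode.
```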