From 8ac6bf5fbb2dca4a41ae04caeaf40fc7dc0d885f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 29 Apr 2024 17:35:14 +0200 Subject: [PATCH] Improve docstrings --- tianshou/data/buffer/base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index a34964f..b1719be 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -172,18 +172,19 @@ class ReplayBuffer: return np.array([last] if not self.done[last] and self._size else [], int) def prev(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of previous transition. - - The index won't be modified if it is the beginning of an episode. + """Return the index of preceding step within the same episode if it exists. + If it does not exist (because it is the first index within the episode), + the index remains unmodified. """ - index = (index - 1) % self._size + index = (index - 1) % self._size # compute preceding index with wrap-around + # end_flag will be 1 if the previous index is the last step of an episode or + # if it is the very last index of the buffer (wrap-around case), and 0 otherwise end_flag = self.done[index] | (index == self.last_index[0]) return (index + end_flag) % self._size def next(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of next transition. - - The index won't be modified if it is the end of an episode. + """Return the index of next step if there is a next step within the episode. + If there isn't a next step, the index remains unmodified. """ end_flag = self.done[index] | (index == self.last_index[0]) return (index + (1 - end_flag)) % self._size