Improve docstrings

This commit is contained in:
Dominik Jain 2024-04-29 17:35:14 +02:00
parent 250a129cc4
commit 8ac6bf5fbb

View File

@ -172,18 +172,19 @@ class ReplayBuffer:
return np.array([last] if not self.done[last] and self._size else [], int) return np.array([last] if not self.done[last] and self._size else [], int)
def prev(self, index: int | np.ndarray) -> np.ndarray: def prev(self, index: int | np.ndarray) -> np.ndarray:
"""Return the index of previous transition. """Return the index of preceding step within the same episode if it exists.
If it does not exist (because it is the first index within the episode),
The index won't be modified if it is the beginning of an episode. the index remains unmodified.
""" """
index = (index - 1) % self._size index = (index - 1) % self._size # compute preceding index with wrap-around
# end_flag will be 1 if the previous index is the last step of an episode or
# if it is the very last index of the buffer (wrap-around case), and 0 otherwise
end_flag = self.done[index] | (index == self.last_index[0]) end_flag = self.done[index] | (index == self.last_index[0])
return (index + end_flag) % self._size return (index + end_flag) % self._size
def next(self, index: int | np.ndarray) -> np.ndarray: def next(self, index: int | np.ndarray) -> np.ndarray:
"""Return the index of next transition. """Return the index of next step if there is a next step within the episode.
If there isn't a next step, the index remains unmodified.
The index won't be modified if it is the end of an episode.
""" """
end_flag = self.done[index] | (index == self.last_index[0]) end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size return (index + (1 - end_flag)) % self._size