import logging
import numpy as np
from .replay_buffer_base import ReplayBufferBase

__all__ = [
    'VanillaReplayBuffer',
]
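
# Positions of the fields within each frame tuple
# (observation, action, reward, done_flag) passed to ``add``.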
STATE = 0
ACTION = 1
REWARD = 2
DONE = 3


# TODO: valid data points could span fewer than `nstep` timesteps. Check the prioritized replay paper!
class VanillaReplayBuffer(ReplayBufferBase):
"""
Class for vanilla replay buffer as used in (Mnih, et al., 2015).
Frames are always continuous in temporal order. They are only removed from the beginning
and added at the tail.
This continuity in `self.data` could be exploited only in vanilla replay buffer.
:param capacity: An int. The capacity of the buffer.
:param nstep: An int defaulting to 1. The number of timesteps to lookahead for temporal difference computation.
Only continuous data pieces longer than this number or already terminated ones are
considered valid data points.
"""
    def __init__(self, capacity, nstep=1):
        assert capacity > 0
        self.capacity = int(capacity)
        self.nstep = nstep
        self.data = [[]]  # list of episodes, each of which is a list of frames
        self.index = [[]]  # per episode, the indexes of frames that are valid data points
        self.candidate_index = 0  # next frame of the last episode that is not yet a valid data point
        self.size = 0  # number of valid data points (not frames)
        self.index_lengths = [0]  # number of valid data points per episode, used for sampling
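
    # Illustrative layout (an assumed example with ``nstep=1``): after one terminated episode
    # of three frames f0, f1, f2 and one fresh frame f3 of the next episode,
    #   self.data  == [[f0, f1, f2], [f3]]
    #   self.index == [[0, 1, 2], []]  # every frame of the finished episode is a valid data point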

    def add(self, frame):
        """
        Adds one frame of data to the buffer.

        :param frame: A tuple of (observation, action, reward, done_flag).
        """
        self.data[-1].append(frame)
        has_enough_frames = len(self.data[-1]) > self.nstep
        if frame[DONE]:  # the episode terminates, so all trailing frames become valid data points
            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
            self.index[-1] += trailing_index
            self.size += len(trailing_index)
            self.index_lengths[-1] += len(trailing_index)
            # prepare for the next episode
            self.data.append([])
            self.index.append([])
            self.candidate_index = 0
            self.index_lengths.append(0)
        elif has_enough_frames:  # add one valid data point
            self.index[-1].append(self.candidate_index)
            self.candidate_index += 1
            self.size += 1
            self.index_lengths[-1] += 1
        # automatic removal to keep the buffer within capacity
        if self.size > self.capacity:
            self.remove()

    def remove(self):
        """
        Removes data from the buffer until ``self.size <= self.capacity``.
        """
        if self.size:
            while self.size > self.capacity:
                self._remove_oldest()
        else:
            logging.warning('Attempting to remove from an empty buffer!')

    def _remove_oldest(self):
        """
        Removes the oldest data point, which for this buffer is just the oldest frame. If the
        removal leaves the first episode empty, that episode is removed as well.
        """
        # since the indexes are contiguous from 0, popping the last entry is equivalent to
        # shifting every frame index of the first episode forward by 1
        self.index[0].pop()
        if self.index[0]:  # the first episode still has valid data points
            self.data[0].pop(0)
            if len(self.data) == 1:  # otherwise self.candidate_index refers to another episode
                self.candidate_index -= 1
            self.index_lengths[0] -= 1
        else:  # the first episode becomes empty
            self.data.pop(0)
            self.index.pop(0)
            if len(self.data) == 0:  # otherwise self.candidate_index refers to another episode
                self.candidate_index = 0
            self.index_lengths.pop(0)
        self.size -= 1

    def sample(self, batch_size):
        """
        Performs uniform random sampling over ``self.index``. For simplicity, sampling is done
        with replacement, in O(``batch_size``) time. The fastest sampling without replacement
        seems to require O(``batch_size`` * log(num_episodes)) time.

        :param batch_size: An int. The size of the minibatch.
        :return: A list of lists of the sampled indexes. Episodes from which no data point was
            sampled correspond to empty sub-lists.
        """
        prob_episode = np.array(self.index_lengths) * 1. / self.size
        num_episodes = len(self.index)
        sampled_index = [[] for _ in range(num_episodes)]
        for _ in range(batch_size):
            # sample which episode to draw from, weighted by its number of valid data points
            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
            # sample which data point within the sampled episode
            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
            sampled_index[sampled_episode_i].append(sampled_frame_i)
        return sampled_index
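

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module. It assumes frames are
    # (observation, action, reward, done_flag) tuples and that the module is run within its
    # package so that the relative import of ReplayBufferBase resolves; the dummy episode
    # below is purely illustrative.
    buffer = VanillaReplayBuffer(capacity=100, nstep=1)
    for step in range(5):
        observation = np.zeros(4)  # placeholder observation
        action = 0                 # placeholder action
        reward = 1.0               # placeholder reward
        done = (step == 4)         # terminate the episode on the last step
        buffer.add((observation, action, reward, done))
    # With nstep=1, all five frames become valid data points once the episode terminates,
    # so sampling only ever draws from the finished first episode.
    print(buffer.sample(batch_size=3))  # e.g. [[2, 0, 4], []]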