27 lines
890 B
Python

class DataBufferBase(object):
    """Base class for data buffers.

    Covers both replay buffers (as in DQN) and batched datasets (as in
    on-policy algorithms).

    ``add`` and ``clear`` are abstract and must be overridden.  ``sample``
    is implemented here, but it reads three attributes that this class
    never sets -- ``index``, ``index_lengths`` and ``size`` -- so a
    subclass has to provide them.
    NOTE(review): presumably ``index`` holds one entry per episode,
    ``index_lengths[i]`` is the number of frames in episode ``i`` and
    ``size`` is the total frame count -- confirm against the subclasses.
    """

    def add(self, frame):
        """Store ``frame`` in the buffer; must be implemented by subclasses."""
        raise NotImplementedError()

    def clear(self):
        """Empty the buffer; must be implemented by subclasses."""
        raise NotImplementedError()

    def sample(self, batch_size):
        """Draw ``batch_size`` frame positions, with replacement.

        Each draw first picks episode ``i`` with probability
        ``index_lengths[i] / size`` and then picks a position uniformly
        from ``range(index_lengths[i])``.

        Args:
            batch_size: number of draws; a value <= 0 yields no draws.

        Returns:
            A list with ``len(self.index)`` inner lists; inner list ``i``
            holds the positions (plain ``int``) drawn from episode ``i``,
            in the order they were drawn.

        Raises:
            ValueError: if ``batch_size`` is positive and the buffer is
                empty, or if numpy rejects the probability vector (e.g.
                ``index_lengths`` does not sum to ``size``).
        """
        lengths = self.index_lengths
        num_episodes = len(self.index)
        sampled_index = [[] for _ in range(num_episodes)]

        # Nothing to draw: skip the probability computation so that an
        # empty buffer does not trigger a divide-by-zero warning.
        if batch_size <= 0:
            return sampled_index

        # Fail with a readable message instead of an opaque numpy error.
        if num_episodes == 0 or self.size == 0:
            raise ValueError("cannot sample from an empty buffer")

        # "* 1." forces float division even under Python 2 semantics.
        prob_episode = np.array(lengths) * 1. / self.size

        for _ in range(batch_size):
            # Sample which episode, weighted by its share of all frames.
            episode_i = int(np.random.choice(num_episodes, p=prob_episode))
            # Sample which data point within the sampled episode.
            frame_i = int(np.random.randint(lengths[episode_i]))
            sampled_index[episode_i].append(frame_i)
        return sampled_index