27 lines
890 B
Python
27 lines
890 B
Python
|
|
|
|
class DataBufferBase(object):
    """
    Base class for data buffers, covering both replay buffers (as in DQN) and
    batched datasets (as in on-policy algorithms).

    Subclasses implement :meth:`add` and :meth:`clear`. The default
    :meth:`sample` reads two attributes that a subclass is expected to
    maintain:

    - ``self.index``: a sequence with one entry per episode; only its length
      is used here.
    - ``self.index_lengths``: the number of sampleable data points in each
      episode, aligned with ``self.index``.
    """

    def add(self, frame):
        """
        Add one frame of data to the buffer. Must be implemented by subclasses.

        :param frame: the data to store; its format is defined by the subclass.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def clear(self):
        """
        Remove all data from the buffer. Must be implemented by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def sample(self, batch_size):
        """
        Draw ``batch_size`` data points with replacement, uniformly over all
        data points in the buffer.

        An episode is chosen with probability proportional to its length, and
        then an offset within that episode is chosen uniformly. As a result,
        every data point has the same chance of being picked.

        :param batch_size: number of draws. A value <= 0 draws nothing.
        :return: a list with ``len(self.index)`` inner lists. Inner list ``i``
            holds the offsets (each in ``[0, self.index_lengths[i])``) drawn
            from episode ``i``, in draw order, possibly with repeats.
        :raises ValueError: if ``batch_size > 0`` but the buffer holds no data.
        """
        num_episodes = len(self.index)
        sampled_index = [[] for _ in range(num_episodes)]
        if batch_size <= 0:
            # Nothing to draw. Returning early also avoids computing 0/0
            # probabilities on an empty buffer.
            return sampled_index

        lengths = np.asarray(self.index_lengths, dtype=float)
        total = lengths.sum()
        if total <= 0:
            raise ValueError('cannot sample from an empty buffer')
        # Normalize by the lengths themselves (not ``self.size``) so that the
        # probabilities always sum to 1, as np.random.choice requires.
        prob_episode = lengths / total

        for _ in range(batch_size):
            # sample which episode
            episode_i = int(np.random.choice(num_episodes, p=prob_episode))
            # sample which data point within the sampled episode
            frame_i = int(np.random.randint(self.index_lengths[episode_i]))
            sampled_index[episode_i].append(frame_i)

        return sampled_index
|